[V1] Scatter and gather placeholders in the model runner (#16076)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
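The hunks below remove Gemma3's per-model `embed_is_patch` bookkeeping: the processor now marks which placeholder tokens carry embeddings via `PromptUpdateDetails.select_token_id` / `is_embed`, and the scatter/gather of patch features is handled by the V1 model runner instead of inside the model. A minimal sketch of the scatter/gather idea with made-up shapes (not vLLM's runner code):

```python
import torch

# Made-up sizes for illustration only.
hidden_size = 16
is_embed = torch.tensor([0, 1, 1, 1, 1, 1, 1, 0], dtype=torch.bool)  # placeholder mask
patch_features = torch.randn(int(is_embed.sum()), hidden_size)

# Scatter: place patch features at the masked positions of the full
# placeholder-length sequence (non-patch slots, e.g. BOI/EOI, stay zero here).
full_seq = patch_features.new_zeros(is_embed.numel(), hidden_size)
full_seq[is_embed] = patch_features

# Gather: recover exactly the patch features from the full sequence.
assert torch.equal(full_seq[is_embed], patch_features)
```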
@@ -25,7 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptTargetMatch,
                                         PromptUpdate, PromptUpdateDetails,
-                                        encode_tokens, find_mm_placeholders,
+                                        find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -36,7 +36,6 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
-from .vision import scatter_patch_features, select_patch_features
 
 logger = init_logger(__name__)
 
@@ -54,14 +53,6 @@ class Gemma3ImagePixelInputs(TypedDict):
     num_patches: torch.Tensor
     """Shape: `(batch_size * num_images)`"""
 
-    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    A boolean mask indicating which image embeddings correspond
-    to patch tokens.
-
-    Shape: `(batch_size * num_images, num_embeds)`
-    """
-
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
 
@@ -183,7 +174,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         if processor is None:
             processor = self.get_hf_processor()
 
-        image_token = processor.boi_token
+        boi_token = processor.boi_token
 
         num_crops = self.get_num_crops(
             image_width=image_width,
@@ -192,19 +183,21 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         )
 
         if num_crops == 0:
-            image_text = image_token
+            image_text = boi_token
         else:
-            crops_image_tokens = " ".join(image_token
-                                          for _ in range(num_crops))
+            crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
             image_text = (
-                f"Here is the original image {image_token} and here are some "
+                f"Here is the original image {boi_token} and here are some "
                 f"crops to help you see better {crops_image_tokens}")
 
-        repl_full = image_text.replace(image_token,
+        repl_full = image_text.replace(boi_token,
                                        processor.full_image_sequence)
-        repl_features = repl_full.strip("\n")
 
-        return PromptUpdateDetails(full=repl_full, features=repl_features)
+        tokenizer = processor.tokenizer
+        vocab = tokenizer.get_vocab()
+        image_token_id = vocab[tokenizer.image_token]
+
+        return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
 
     def get_num_image_tokens(
         self,
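`PromptUpdateDetails.select_token_id(repl_full, image_token_id)` marks only the occurrences of the image token inside the full replacement as embedding positions, replacing the old separate `features` string. A rough sketch of the mask this implies, with hypothetical token ids:

```python
# Hypothetical token ids, shown only to illustrate the selection semantics.
image_token_id = 262144
repl_token_ids = [255999, 262144, 262144, 262144, 256000]  # BOI, patches..., EOI

is_embed = [tok == image_token_id for tok in repl_token_ids]
print(is_embed)  # [False, True, True, True, False]
```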
@@ -213,19 +206,17 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         image_height: int,
         processor: Optional[Gemma3Processor],
     ) -> int:
-        tokenizer = self.get_tokenizer()
-        image_repl = self.get_image_repl(
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        num_crops = self.get_num_crops(
             image_width=image_width,
             image_height=image_height,
             processor=processor,
         )
+        image_seq_len = processor.image_seq_length
 
-        image_repl_tokens = encode_tokens(
-            tokenizer,
-            image_repl.features,
-            add_special_tokens=False,
-        )
-        return len(image_repl_tokens)
+        return (num_crops + 1) * image_seq_len
 
     def get_image_size_with_most_features(self) -> ImageSize:
         processor = self.get_hf_processor()
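`get_num_image_tokens` now uses a closed form instead of tokenizing the replacement text: one full image sequence for the base image plus one per pan-and-scan crop. A worked example with an assumed `image_seq_length` (the real value comes from the HF Gemma3 processor at runtime):

```python
# Assumed values for the arithmetic only.
image_seq_len = 256
num_crops = 2  # pan-and-scan produced two extra crops

assert (num_crops + 1) * image_seq_len == 768
```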
@@ -301,28 +292,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             ]
             hf_processor = self.info.get_hf_processor(**mm_kwargs)
-
-            image_repl_features = [
-                self.info.get_image_repl(image_width=size.width,
-                                         image_height=size.height,
-                                         processor=hf_processor).features
-                for size in image_sizes
-            ]
-
-            tokenizer = self.info.get_tokenizer()
-            image_repls_feature_tokens = [
-                tokenizer.encode(image_repl, add_special_tokens=False)
-                for image_repl in image_repl_features
-            ]
-
-            vocab = tokenizer.get_vocab()
-            image_token_id = vocab[tokenizer.image_token]
-
-            embed_is_patch = [
-                torch.tensor(image_repl_tokens) == image_token_id
-                for image_repl_tokens in image_repls_feature_tokens
-            ]
-            processed_outputs["embed_is_patch"] = embed_is_patch
 
             num_crops = [
                 self.info.get_num_crops(image_width=size.width,
                                         image_height=size.height,
@@ -344,7 +313,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
                 "image", num_crops + 1),
             num_crops=MultiModalFieldConfig.batched("image"),
-            embed_is_patch=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -454,6 +422,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
                     item_idx=p.item_idx,
                     start_idx=repl_orig_idxs[p.start_idx],
                     tokens=p.tokens,
+                    is_embed=p.is_embed,
                 ) for p in placeholders
             ]
             for modality, placeholders in repls.items()
@@ -572,7 +541,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         num_crops = kwargs.pop("num_crops", None)
-        embed_is_patch = kwargs.pop("embed_is_patch", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
@@ -586,19 +554,13 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
 
-        if not isinstance(embed_is_patch, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of embed_is_patch. "
-                             f"Got type: {type(embed_is_patch)}")
-
         pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
-        embed_is_patch = flatten_bn(embed_is_patch)
 
         return Gemma3ImagePixelInputs(
             type="pixel_values",
             pixel_values=self._validate_pixel_values(pixel_values),
             num_patches=num_crops + 1,
-            embed_is_patch=embed_is_patch,
         )
 
     def _image_pixels_to_features(
@@ -635,12 +597,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         if image_input is None:
             return None
 
-        image_features = self._process_image_input(image_input)
-
-        return scatter_patch_features(
-            image_features,
-            image_input["embed_is_patch"],
-        )
+        return self._process_image_input(image_input)
 
     def get_input_embeddings(
         self,
@@ -652,7 +609,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids,
                 inputs_embeds,
-                select_patch_features(multimodal_embeddings),
+                multimodal_embeddings,
                 self.config.image_token_index,
             )
         return inputs_embeds
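With `select_patch_features` gone, `merge_multimodal_embeddings` receives the multimodal embeddings directly and writes them at the image-token positions of the prompt. A toy illustration of that merge, with made-up ids and sizes (conceptual only, not the helper's actual implementation):

```python
import torch

# Made-up ids and sizes for illustration.
image_token_index = 9
input_ids = torch.tensor([1, 9, 9, 9, 2])
inputs_embeds = torch.zeros(len(input_ids), 4)
image_embeds = torch.ones(3, 4)  # one embedding per image token in the prompt

mask = input_ids == image_token_index
inputs_embeds[mask] = image_embeds  # conceptually what the merge does
```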