Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)

This commit is contained in:
Roger Wang
2025-04-04 14:50:57 -07:00
committed by GitHub
parent f5722a5052
commit af51d80fa1
42 changed files with 942 additions and 496 deletions

View File

@@ -25,7 +25,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
find_mm_placeholders,
encode_tokens, find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
@@ -36,6 +36,7 @@ from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
@@ -53,6 +54,14 @@ class Gemma3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -174,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if processor is None:
processor = self.get_hf_processor()
boi_token = processor.boi_token
image_token = processor.boi_token
num_crops = self.get_num_crops(
image_width=image_width,
@@ -183,21 +192,19 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
)
if num_crops == 0:
image_text = boi_token
image_text = image_token
else:
crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
crops_image_tokens = " ".join(image_token
for _ in range(num_crops))
image_text = (
f"Here is the original image {boi_token} and here are some "
f"Here is the original image {image_token} and here are some "
f"crops to help you see better {crops_image_tokens}")
repl_full = image_text.replace(boi_token,
repl_full = image_text.replace(image_token,
processor.full_image_sequence)
repl_features = repl_full.strip("\n")
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
return PromptUpdateDetails(full=repl_full, features=repl_features)
def get_num_image_tokens(
self,
@@ -206,17 +213,19 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_height: int,
processor: Optional[Gemma3Processor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
num_crops = self.get_num_crops(
tokenizer = self.get_tokenizer()
image_repl = self.get_image_repl(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_seq_len = processor.image_seq_length
return (num_crops + 1) * image_seq_len
image_repl_tokens = encode_tokens(
tokenizer,
image_repl.features,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
@@ -292,6 +301,28 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor).features
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_crops = [
self.info.get_num_crops(image_width=size.width,
image_height=size.height,
@@ -313,6 +344,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -422,7 +454,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
is_embed=p.is_embed,
) for p in placeholders
]
for modality, placeholders in repls.items()
@@ -541,6 +572,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
@@ -554,13 +586,19 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
embed_is_patch = flatten_bn(embed_is_patch)
return Gemma3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
)
def _image_pixels_to_features(
@@ -597,7 +635,12 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
if image_input is None:
return None
return self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@@ -609,7 +652,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds