Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)

This commit is contained in:
Roger Wang
2025-04-04 14:50:57 -07:00
committed by GitHub
parent f5722a5052
commit af51d80fa1
42 changed files with 942 additions and 496 deletions

View File

@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
encode_tokens)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -54,6 +54,7 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict):
@@ -68,6 +69,14 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -77,6 +86,14 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
@@ -258,16 +275,19 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
num_patches = self.get_num_patches(
tokenizer = self.get_tokenizer()
image_repl = self.get_image_repl(
image_width=image_width,
image_height=image_height,
processor=processor,
)
return num_patches * processor.image_seq_len
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
@@ -344,6 +364,28 @@ class Idefics3MultiModalProcessor(
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [
self.info.get_num_patches(
image_width=size.width,
@@ -373,6 +415,7 @@ class Idefics3MultiModalProcessor(
"image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -384,22 +427,17 @@ class Idefics3MultiModalProcessor(
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
def get_replacement_idefics3(item_idx: int) -> str:
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
image_repl = self.info.get_image_repl(
return self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return PromptUpdateDetails.select_text(
image_repl,
embed_text=image_token,
)
return [
PromptReplacement(
modality="image",
@@ -637,6 +675,13 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None and image_embeds is None:
return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
@@ -645,6 +690,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
)
if pixel_values is not None:
@@ -672,6 +718,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
@@ -707,7 +754,12 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None:
return None
return self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@@ -719,7 +771,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
select_patch_features(multimodal_embeddings),
self.config.image_token_id,
)
return inputs_embeds