[V1] Scatter and gather placeholders in the model runner (#16076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
This commit is contained in:
Roger Wang
2025-04-07 19:43:41 -07:00
committed by GitHub
parent 1d01211264
commit f2ebb6f541
41 changed files with 521 additions and 1020 deletions

View File

@@ -41,7 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems,
MultiModalFieldConfig,
PromptReplacement, PromptUpdate,
encode_tokens)
PromptUpdateDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -54,7 +54,6 @@ from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
from .llama import LlamaModel
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
class Idefics3ImagePixelInputs(TypedDict):
@@ -69,14 +68,6 @@ class Idefics3ImagePixelInputs(TypedDict):
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class Idefics3ImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -86,14 +77,6 @@ class Idefics3ImageEmbeddingInputs(TypedDict):
`hidden_size` must match the hidden size of language model backbone.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]
@@ -275,19 +258,16 @@ class Idefics3ProcessingInfo(BaseProcessingInfo):
image_height: int,
processor: Optional[Idefics3Processor],
) -> int:
tokenizer = self.get_tokenizer()
image_repl = self.get_image_repl(
if processor is None:
processor = self.get_hf_processor()
num_patches = self.get_num_patches(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
add_special_tokens=False,
)
return len(image_repl_tokens)
return num_patches * processor.image_seq_len
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
@@ -364,28 +344,6 @@ class Idefics3MultiModalProcessor(
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
vocab = tokenizer.get_vocab()
image_token_id = vocab[hf_processor.image_token.content]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_patches = [
self.info.get_num_patches(
image_width=size.width,
@@ -415,7 +373,6 @@ class Idefics3MultiModalProcessor(
"image", num_patches),
image_embeds=MultiModalFieldConfig.batched("image"),
num_patches=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -427,17 +384,22 @@ class Idefics3MultiModalProcessor(
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token = hf_processor.image_token.content
def get_replacement_idefics3(item_idx: int) -> str:
def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
image_repl = self.info.get_image_repl(
image_width=image_size.width,
image_height=image_size.height,
processor=hf_processor,
)
return PromptUpdateDetails.select_text(
image_repl,
embed_text=image_token,
)
return [
PromptReplacement(
modality="image",
@@ -675,13 +637,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None and image_embeds is None:
return None
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
@@ -690,7 +645,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
return Idefics3ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds, concat=True),
embed_is_patch=embed_is_patch,
)
if pixel_values is not None:
@@ -718,7 +672,6 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
pixel_values=self._validate_pixel_values(pixel_values),
pixel_attention_mask=pixel_attention_mask,
num_patches=num_patches,
embed_is_patch=embed_is_patch,
)
raise AssertionError("This line should be unreachable.")
@@ -754,12 +707,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None:
return None
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
return self._process_image_input(image_input)
def get_input_embeddings(
self,
@@ -771,7 +719,7 @@ class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
multimodal_embeddings,
self.config.image_token_id,
)
return inputs_embeds