Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)

This commit is contained in:
Roger Wang
2025-04-04 14:50:57 -07:00
committed by GitHub
parent f5722a5052
commit af51d80fa1
42 changed files with 942 additions and 496 deletions

View File

@@ -32,8 +32,7 @@ from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -43,7 +42,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
from .vision import (get_vision_encoder_info, scatter_patch_features,
select_patch_features)
class LlavaImagePixelInputs(TypedDict):
@@ -67,6 +67,14 @@ class PixtralHFImagePixelInputs(TypedDict):
in which case the data is passed as a list instead of a batched tensor.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class LlavaImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
@@ -335,6 +343,23 @@ class PixtralHFMultiModalProcessor(
for p, (h, w) in zip(pixel_values, image_sizes)
]
hf_config = self.info.get_hf_config()
vision_config = hf_config.vision_config
assert isinstance(vision_config, PixtralVisionConfig)
encoder_info = PixtralHFEncoderInfo(vision_config)
tile_sizes = [
encoder_info.get_patch_grid_size(
image_width=pixel_value.shape[-1],
image_height=pixel_value.shape[-2],
) for pixel_value in processed_outputs["pixel_values"]
]
embed_is_patch = [
torch.tensor(([True] * ncols + [False]) * nrows)
for ncols, nrows in tile_sizes
]
processed_outputs["embed_is_patch"] = embed_is_patch
return processed_outputs
def _get_mm_fields_config(
@@ -344,6 +369,7 @@ class PixtralHFMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
@@ -378,7 +404,7 @@ class PixtralHFMultiModalProcessor(
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return tokens
return [
PromptReplacement(
@@ -586,9 +612,17 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
f"Got type: {type(pixel_values)}")
if self.config.vision_config.model_type == "pixtral":
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralHFImagePixelInputs(
type="pixel_values_pixtral",
pixel_values=flatten_bn(pixel_values),
embed_is_patch=embed_is_patch,
)
return LlavaImagePixelInputs(
@@ -680,7 +714,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
if image_input is None:
return None
return self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
if image_input["type"] != "pixel_values_pixtral":
# The path is used for pixtral (V0 only) and llava (V0/V1)
return image_features
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@@ -692,7 +735,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds