Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)

This commit is contained in:
Roger Wang
2025-04-04 14:50:57 -07:00
committed by GitHub
parent f5722a5052
commit af51d80fa1
42 changed files with 942 additions and 496 deletions

View File

@@ -37,7 +37,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
@@ -46,7 +46,8 @@ from vllm.transformers_utils.tokenizer import (MistralTokenizer,
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, resolve_visual_encoder_outputs,
scatter_patch_features, select_patch_features)
try:
from xformers import ops as xops
@@ -67,6 +68,14 @@ class PixtralImagePixelInputs(TypedDict):
The result of stacking :attr:`ImageEncoding.tokens` from each prompt.
"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size * num_images, num_embeds)`
"""
class PixtralProcessorAdapter:
"""
@@ -135,8 +144,11 @@ class PixtralProcessorAdapter:
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411.")
image_token_id = self.image_token_id
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
images_embed_is_patch = list[torch.Tensor]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
@@ -145,10 +157,12 @@ class PixtralProcessorAdapter:
images_processed.append(image_processed)
images_tokens.append(image_tokens)
images_embed_is_patch.append(image_tokens == image_token_id)
return {
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
"embed_is_patch": images_embed_is_patch,
}
@@ -199,7 +213,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height)))
return ncols * nrows
return (ncols + 1) * nrows
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_hf_processor().image_processor
@@ -249,7 +263,10 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
hf_inputs: Mapping[str, NestedTensors],
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(images=MultiModalFieldConfig.batched("image"))
return dict(
images=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
self,
@@ -273,7 +290,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
tokens[-1] = image_end_id
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
return tokens
return [
PromptReplacement(
@@ -364,9 +381,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
raise ValueError("Incorrect type of images. "
f"Got type: {type(images)}")
embed_is_patch = kwargs.pop("embed_is_patch")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
embed_is_patch = flatten_bn(embed_is_patch)
return PixtralImagePixelInputs(
type="pixel_values",
images=flatten_bn(images),
embed_is_patch=embed_is_patch,
)
def _process_image_input(
@@ -402,7 +427,12 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
if image_input is None:
return None
return self._process_image_input(image_input)
image_features = self._process_image_input(image_input)
return scatter_patch_features(
image_features,
image_input["embed_is_patch"],
)
def get_input_embeddings(
self,
@@ -414,7 +444,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal,
inputs_embeds = merge_multimodal_embeddings(
input_ids,
inputs_embeds,
multimodal_embeddings,
select_patch_features(multimodal_embeddings),
self.vision_args.image_token_id,
)
return inputs_embeds
@@ -933,7 +963,9 @@ class PixtralHFEncoderInfo(VisionEncoderInfo[PixtralVisionConfig]):
image_width=image_width,
image_height=image_height,
)
return ncols * nrows
# Consider the image_break_token
return (ncols + 1) * nrows
def get_max_image_tokens(self) -> int:
image_size = self.get_image_size()