Revert "[V1] Scatter and gather placeholders in the model runner" (#16075)
This commit is contained in:
@@ -27,8 +27,7 @@ from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
||||
MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, ProcessingCache,
|
||||
PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
PromptReplacement, PromptUpdate)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
@@ -36,7 +35,8 @@ from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import get_vision_encoder_info
|
||||
from .vision import (get_vision_encoder_info, scatter_patch_features,
|
||||
select_patch_features)
|
||||
|
||||
|
||||
class Mistral3ImagePixelInputs(TypedDict):
|
||||
@@ -49,6 +49,14 @@ class Mistral3ImagePixelInputs(TypedDict):
|
||||
in which case the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
A boolean mask indicating which image embeddings correspond
|
||||
to patch tokens.
|
||||
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
|
||||
class Mistral3PatchMerger(nn.Module):
|
||||
"""
|
||||
@@ -258,6 +266,23 @@ class Mistral3MultiModalProcessor(
|
||||
p[:, :h, :w] for p, (h, w) in zip(pixel_values, image_sizes)
|
||||
]
|
||||
|
||||
hf_config = self.info.get_hf_config()
|
||||
vision_config = hf_config.vision_config
|
||||
assert isinstance(vision_config, PixtralVisionConfig)
|
||||
encoder_info = PixtralHFEncoderInfo(vision_config)
|
||||
|
||||
tile_sizes = [
|
||||
encoder_info.get_patch_grid_size(
|
||||
image_width=pixel_value.shape[-1],
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
@@ -267,6 +292,7 @@ class Mistral3MultiModalProcessor(
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
@@ -301,7 +327,7 @@ class Mistral3MultiModalProcessor(
|
||||
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
||||
tokens[-1] = image_end_id
|
||||
|
||||
return PromptUpdateDetails.select_token_id(tokens, image_token_id)
|
||||
return tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@@ -392,6 +418,8 @@ def init_vision_tower_for_llava(
|
||||
)
|
||||
|
||||
|
||||
# TODO(mgoin): Support V1, there are issues with image batching/chunking
|
||||
# that need to be resolved first.
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
_build_mistral3_processor,
|
||||
info=_build_mistral3_info,
|
||||
@@ -481,9 +509,16 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
assert self.config.vision_config.model_type == "pixtral"
|
||||
embed_is_patch = kwargs.pop("embed_is_patch")
|
||||
if not isinstance(embed_is_patch, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
return Mistral3ImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
embed_is_patch=flatten_bn(embed_is_patch),
|
||||
)
|
||||
|
||||
def _process_image_input(
|
||||
@@ -522,7 +557,10 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
|
||||
return vision_embeddings
|
||||
return scatter_patch_features(
|
||||
vision_embeddings,
|
||||
image_input["embed_is_patch"],
|
||||
)
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
@@ -534,7 +572,7 @@ class Mistral3ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
multimodal_embeddings,
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
Reference in New Issue
Block a user