[Bugfix] Re-enable Gemma3 for V1 (#14980)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -18,7 +18,7 @@ from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.jsontree import json_map_leaves
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
@@ -27,8 +27,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputs, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
MultiModalInputs, MultiModalKwargs)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@@ -44,7 +43,8 @@ from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
from .vision import get_vision_encoder_info
|
||||
from .vision import (get_vision_encoder_info, scatter_patch_features,
|
||||
select_patch_features)
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
|
||||
@@ -76,7 +76,7 @@ class PixtralHFImagePixelInputs(TypedDict):
|
||||
Shape: `(batch_size, num_images, num_embeds)`
|
||||
"""
|
||||
|
||||
num_patches: Union[torch.Tensor, list[torch.Tensor]]
|
||||
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size, num_images)`"""
|
||||
|
||||
|
||||
@@ -352,15 +352,15 @@ class PixtralHFMultiModalProcessor(
|
||||
image_height=pixel_value.shape[-2],
|
||||
) for pixel_value in processed_outputs["pixel_values"]
|
||||
]
|
||||
num_patches = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
num_embeds = torch.tensor([(ncols + 1) * nrows
|
||||
for ncols, nrows in tile_sizes])
|
||||
# Each image may result to masks of different sizes, so we need to
|
||||
# later use `num_patches` to get per-image masks.
|
||||
# later use `num_embeds` to get per-image masks.
|
||||
embed_is_patch = [
|
||||
torch.tensor(([True] * ncols + [False]) * nrows)
|
||||
for ncols, nrows in tile_sizes
|
||||
]
|
||||
processed_outputs["num_patches"] = num_patches
|
||||
processed_outputs["num_embeds"] = num_embeds
|
||||
processed_outputs["embed_is_patch"] = embed_is_patch
|
||||
|
||||
return processed_outputs
|
||||
@@ -372,7 +372,7 @@ class PixtralHFMultiModalProcessor(
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
num_embeds=MultiModalFieldConfig.batched("image"),
|
||||
embed_is_patch=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
@@ -621,16 +621,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
raise ValueError("Incorrect type of embed_is_patch. "
|
||||
f"Got type: {type(embed_is_patch)}")
|
||||
|
||||
num_patches = kwargs.pop("num_patches")
|
||||
if not isinstance(num_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_patches. "
|
||||
f"Got type: {type(num_patches)}")
|
||||
num_embeds = kwargs.pop("num_embeds")
|
||||
if not isinstance(num_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_embeds. "
|
||||
f"Got type: {type(num_embeds)}")
|
||||
|
||||
return PixtralHFImagePixelInputs(
|
||||
type="pixel_values_pixtral",
|
||||
pixel_values=flatten_bn(pixel_values),
|
||||
embed_is_patch=embed_is_patch,
|
||||
num_patches=num_patches,
|
||||
num_embeds=num_embeds,
|
||||
)
|
||||
|
||||
return LlavaImagePixelInputs(
|
||||
@@ -716,33 +716,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_embeds = torch.split(image_embeds, feature_sizes)
|
||||
return image_embeds
|
||||
|
||||
def _get_mm_embeds(
|
||||
self,
|
||||
features: torch.Tensor, # Shape: (num_patch, d)
|
||||
num_patches: torch.Tensor, # Shape: (num_images,)
|
||||
embed_is_patch: torch.Tensor, # Shape: (num_images, num_embeds)
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
Mostly copied from `Molmo._get_mm_embeds`. See following fixme comment.
|
||||
"""
|
||||
# Insert columns of nan values according to `embed_is_patch`. This work
|
||||
# ideally should be done in `_process_image_input`, but
|
||||
# `_process_image_input` is used in both V0 and V1 path. It's safer to
|
||||
# put the logic here.
|
||||
# FIXME: Move this logic to `_process_image_input` when v0 is
|
||||
# deprecated. Merge this function with `Molmo._get_mm_embeds`.
|
||||
num_patches_per_image: list[int] = num_patches.tolist()
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_patches_per_image), *features.shape[1:]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features
|
||||
|
||||
return embeds_flat.split(num_patches_per_image)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
@@ -757,9 +730,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return vision_embeddings
|
||||
|
||||
return flatten_2d_lists(
|
||||
self._get_mm_embeds(*args) for args in zip(
|
||||
scatter_patch_features(*args) for args in zip(
|
||||
vision_embeddings,
|
||||
image_input["num_patches"],
|
||||
image_input["num_embeds"],
|
||||
image_input["embed_is_patch"],
|
||||
))
|
||||
|
||||
@@ -770,16 +743,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
) -> torch.Tensor:
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
# Extract the patch tokens
|
||||
patch_embeddings = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids,
|
||||
inputs_embeds,
|
||||
cast(NestedTensors, patch_embeddings),
|
||||
select_patch_features(multimodal_embeddings),
|
||||
self.config.image_token_index,
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
Reference in New Issue
Block a user