[Bugfix] Re-enable Gemma3 for V1 (#14980)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-19 14:58:22 +08:00
committed by GitHub
parent 05ccd0aa35
commit 61f412187d
8 changed files with 419 additions and 175 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
import torch
from transformers import PretrainedConfig
@@ -9,9 +9,12 @@ from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.jsontree import JSONTree, json_map_leaves
from vllm.logger import init_logger
from vllm.platforms import _Backend, current_platform
from .interfaces import MultiModalEmbeddings
logger = init_logger(__name__)
_C = TypeVar("_C", bound=PretrainedConfig)
@@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
if post_layer_norm is not None and uses_last_layer:
hs_pool[-1] = post_layer_norm(encoder_outputs)
return torch.cat(hs_pool, dim=-1)
def scatter_patch_features(
features: torch.Tensor,
num_embeds: torch.Tensor,
embed_is_patch: torch.Tensor,
) -> tuple[torch.Tensor, ...]:
"""
Scatter the patch features into a contiguous tensor that corresponds
to the embedding tokens defined by the multimodal processor.
The rest of the values in the tensor are set to NaN so that they
can be filtered out by :func`select_patch_features`.
Args:
features: The patch features, concatenated across each image.
Shape: `(num_patch, feature_depth)`
num_embeds: The number of image embeddings for each image.
Shape: `(num_images,)`
embed_is_patch: A boolean mask indicating which image embeddings
correspond to patch tokens for each image.
Shape: `(num_images, num_embeds)`
"""
num_embeds_per_image: list[int] = num_embeds.tolist()
embeds_flat = features.new_full(
(sum(num_embeds_per_image), features.shape[-1]),
fill_value=torch.nan,
)
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
return embeds_flat.split(num_embeds_per_image)
def select_patch_features(
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
"""
Given the outputs of :func:`scatter_patch_features`, return only
the values that correspond to patch features.
"""
selected_features = json_map_leaves(
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
cast(JSONTree[torch.Tensor], multimodal_embeddings),
)
return cast(MultiModalEmbeddings, selected_features)