[Bugfix] Re-enable Gemma3 for V1 (#14980)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union
|
||||
from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@@ -9,9 +9,12 @@ from transformers import PretrainedConfig
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.selector import (backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.jsontree import JSONTree, json_map_leaves
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
|
||||
from .interfaces import MultiModalEmbeddings
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_C = TypeVar("_C", bound=PretrainedConfig)
|
||||
@@ -148,3 +151,48 @@ def resolve_visual_encoder_outputs(
|
||||
if post_layer_norm is not None and uses_last_layer:
|
||||
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
||||
return torch.cat(hs_pool, dim=-1)
|
||||
|
||||
|
||||
def scatter_patch_features(
|
||||
features: torch.Tensor,
|
||||
num_embeds: torch.Tensor,
|
||||
embed_is_patch: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Scatter the patch features into a contiguous tensor that corresponds
|
||||
to the embedding tokens defined by the multimodal processor.
|
||||
|
||||
The rest of the values in the tensor are set to NaN so that they
|
||||
can be filtered out by :func`select_patch_features`.
|
||||
|
||||
Args:
|
||||
features: The patch features, concatenated across each image.
|
||||
Shape: `(num_patch, feature_depth)`
|
||||
num_embeds: The number of image embeddings for each image.
|
||||
Shape: `(num_images,)`
|
||||
embed_is_patch: A boolean mask indicating which image embeddings
|
||||
correspond to patch tokens for each image.
|
||||
Shape: `(num_images, num_embeds)`
|
||||
"""
|
||||
num_embeds_per_image: list[int] = num_embeds.tolist()
|
||||
|
||||
embeds_flat = features.new_full(
|
||||
(sum(num_embeds_per_image), features.shape[-1]),
|
||||
fill_value=torch.nan,
|
||||
)
|
||||
embeds_flat[embed_is_patch.view(-1)] = features.flatten(0, -2)
|
||||
|
||||
return embeds_flat.split(num_embeds_per_image)
|
||||
|
||||
|
||||
def select_patch_features(
|
||||
multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
|
||||
"""
|
||||
Given the outputs of :func:`scatter_patch_features`, return only
|
||||
the values that correspond to patch features.
|
||||
"""
|
||||
selected_features = json_map_leaves(
|
||||
lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
|
||||
cast(JSONTree[torch.Tensor], multimodal_embeddings),
|
||||
)
|
||||
return cast(MultiModalEmbeddings, selected_features)
|
||||
|
||||
Reference in New Issue
Block a user