[V1] Scatter and gather placeholders in the model runner (#16076)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Jennifer Zhao <ai.jenniferzhao@gmail.com>
@@ -1,8 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
-from typing import Final, Generic, Optional, Protocol, TypeVar, Union, cast
+from typing import Final, Generic, Optional, Protocol, TypeVar, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -10,12 +9,9 @@ from transformers import PretrainedConfig
 import vllm.envs as envs
 from vllm.attention.selector import (backend_name_to_enum,
                                      get_global_forced_attn_backend)
-from vllm.jsontree import JSONTree, json_map_leaves
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 
-from .interfaces import MultiModalEmbeddings
-
 logger = init_logger(__name__)
 
 _C = TypeVar("_C", bound=PretrainedConfig)
@@ -155,74 +151,3 @@ def resolve_visual_encoder_outputs(
     if post_layer_norm is not None and uses_last_layer:
         hs_pool[-1] = post_layer_norm(encoder_outputs)
     return torch.cat(hs_pool, dim=-1)
-
-
-def scatter_patch_features(
-    patches: Union[torch.Tensor, Sequence[torch.Tensor]],
-    embed_is_patch: Union[torch.Tensor, Sequence[torch.Tensor]],
-) -> tuple[torch.Tensor, ...]:
-    """
-    Scatter the patch features into a contiguous tensor that corresponds
-    to the embedding tokens defined by the multimodal processor.
-
-    The rest of the values in the tensor are set to NaN so that they
-    can be filtered out by :func:`select_patch_features`.
-
-    Args:
-        patches: The patch features for each image.
-            Shape: `(num_images, <patch_dims>, feature_depth)`
-        embed_is_patch: A boolean mask indicating which image embeddings
-            correspond to patch tokens for each image.
-            Shape: `(num_images, num_embeds)`
-
-    Note:
-        The original code only considers patch tokens as feature
-        tokens, but our processor considers all image-related tokens
-        as feature tokens because the feature tokens need to be
-        consecutive in `input_ids`.
-
-    Example:
-        A simplified example for one image:
-
-        .. code-block::
-
-            Embedding tokens (from HF processor):
-            [<start> <patch> <patch> <col>  <patch> <patch> <col>  <end> ]
-
-            embed_is_patch (from HF processor):
-            [ False   True    True   False   True    True   False  False ]
-
-            Encoder outputs (from model):
-            [ p1 p2 p3 p4 ]
-
-            The resulting embedding tensor is:
-            [ nan p1 p2 nan p3 p4 nan nan ]
-    """
-    if len(patches) != len(embed_is_patch):
-        raise ValueError(f"Inconsistent num_images: {len(patches)=} vs. "
-                         f"{len(embed_is_patch)=}")
-
-    def get_embed_one(patches_one: torch.Tensor, e_is_patch: torch.Tensor):
-        embed_one = patches_one.new_full(
-            (e_is_patch.shape[0], patches_one.shape[-1]),
-            fill_value=torch.nan,
-        )
-        embed_one[e_is_patch] = patches_one
-        return embed_one
-
-    return tuple(
-        get_embed_one(patches_one, e_is_patch)
-        for patches_one, e_is_patch in zip(patches, embed_is_patch))
-
-
-def select_patch_features(
-        multimodal_embeddings: MultiModalEmbeddings) -> MultiModalEmbeddings:
-    """
-    Given the outputs of :func:`scatter_patch_features`, return only
-    the values that correspond to patch features.
-    """
-    selected_features = json_map_leaves(
-        lambda x: x[~x.isnan()].view(-1, *x.shape[1:]),
-        cast(JSONTree[torch.Tensor], multimodal_embeddings),
-    )
-    return cast(MultiModalEmbeddings, selected_features)
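For context, a minimal self-contained sketch of the scatter/select round trip that the removed helpers implemented, using the toy values from the docstring example above. The shapes, literal values, and the `recovered` name are illustrative assumptions, not part of this commit; the original helpers also looped over multiple images, while this sketch handles one.

import torch

# Toy example: one image, 8 embedding tokens, 4 of them patch tokens
# (p1..p4), feature_depth = 2. Values are illustrative only.
patches = torch.arange(8, dtype=torch.float32).reshape(4, 2)
embed_is_patch = torch.tensor(
    [False, True, True, False, True, True, False, False])

# scatter_patch_features: allocate a NaN-filled buffer over all embedding
# tokens, then write the patch features into the positions marked True.
embeds = patches.new_full((embed_is_patch.shape[0], patches.shape[-1]),
                          torch.nan)
embeds[embed_is_patch] = patches
# embeds is now [nan, p1, p2, nan, p3, p4, nan, nan] row-wise.

# select_patch_features: drop the NaN placeholders to recover the features.
recovered = embeds[~embeds.isnan()].view(-1, embeds.shape[-1])
assert torch.equal(recovered, patches)

NaN works as the placeholder because real feature values are never NaN, so a simple `~x.isnan()` gather recovers exactly the patch rows without any extra bookkeeping.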