[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-09-30 19:24:57 +08:00, committed by GitHub

parent ef6e0e7132
commit d7e34b4210

12 changed files with 155 additions and 179 deletions

vllm/model_executor/models/vision.py

@@ -4,10 +4,12 @@
 import itertools
 import math
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union
+from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
+                    TypeVar, Union)
 
 import torch
 from transformers import PretrainedConfig
+from typing_extensions import assert_never
 
 from vllm.distributed import (get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -86,11 +88,39 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
     return current_platform.get_vit_attn_backend(head_size, dtype)
 
 
+VisionFeatureSelectStrategy = Union[
+    Literal["class", "default", "full"],
+    Callable[[torch.Tensor], torch.Tensor],
+]
+
+
+def _get_vision_feature_selector(
+    strategy: VisionFeatureSelectStrategy,
+) -> Callable[[torch.Tensor], torch.Tensor]:
+    if callable(strategy):
+        return strategy
+
+    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
+    if strategy == "class":
+        return lambda feats: feats[:, 0, :]
+
+    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
+    if strategy == "default":
+        return lambda feats: feats[:, 1:, :]
+
+    if strategy == "full":
+        return lambda feats: feats
+
+    assert_never(strategy)
+
+
 def resolve_visual_encoder_outputs(
     encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
-    feature_sample_layers: Optional[list[int]],
     post_layer_norm: Optional[torch.nn.LayerNorm],
-    max_possible_layers: int,
+    *,
+    select_layers: Optional[list[int]] = None,
+    max_possible_layers: Optional[int] = None,
+    feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
 ) -> torch.Tensor:
     """Given the outputs of a visual encoder module that may correspond to the
     output of the last layer, or a list of hidden states to be stacked,
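
For reference, a minimal sketch (not part of this commit) of what the three
built-in strategies do to a CLIP-style hidden-state tensor of shape
(batch, 1 + num_patches, hidden), where token index 0 is the CLS embedding;
the shapes below are hypothetical toy sizes:

    import torch

    # Batch of 2, CLS token + 4 patch tokens, hidden size 8.
    feats = torch.randn(2, 5, 8)

    # "class": keep only the CLS token, as in CLIP pooling.
    assert feats[:, 0, :].shape == (2, 8)

    # "default": drop the CLS token and keep the patch tokens, as in LLaVA.
    assert feats[:, 1:, :].shape == (2, 4, 8)

    # "full": keep all tokens unchanged.
    assert feats.shape == (2, 5, 8)

A custom Callable[[torch.Tensor], torch.Tensor] strategy is returned as-is.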
@@ -98,17 +128,32 @@ def resolve_visual_encoder_outputs(
 
     Args:
         encoder_outputs: Output of encoder's last layer or all hidden states.
-        feature_sample_layers: Optional layer indices to grab from the encoder
-            outputs; if provided, encoder outputs must be a list.
         post_layer_norm: Post norm to apply to the output of the encoder.
+        select_layers: Optional layer indices to grab from the encoder
+            outputs; if provided, encoder outputs must be a list.
         max_possible_layers: Total layers in the fully loaded visual encoder.
+        feature_select_strategy: Defines how to select the hidden states
+            from each layer.
     """
-    if feature_sample_layers is None:
+    if select_layers is None:
+        if not isinstance(encoder_outputs, torch.Tensor):
+            raise ValueError("Expected only a single encoder output when "
+                             "`select_layers` is not provided")
+
+        if feature_select_strategy is not None:
+            select_features = _get_vision_feature_selector(
+                feature_select_strategy)
+            encoder_outputs = select_features(encoder_outputs)
+
         if post_layer_norm is not None:
             return post_layer_norm(encoder_outputs)
+
         return encoder_outputs
 
+    if max_possible_layers is None:
+        raise ValueError("`max_possible_layers` must be provided "
+                         "alongside `select_layers`")
+
     # Get the hidden states corresponding to the layer indices.
     # Negative values are relative to the full visual encoder,
     # so offset them depending on how many layers were loaded.
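
As a usage sketch of the single-tensor path (an assumed call site, not taken
from this diff; the module path vllm.model_executor.models.vision is assumed):

    import torch

    from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

    # Without select_layers, encoder_outputs must be a single tensor;
    # the feature selector runs before the optional post-norm.
    encoder_out = torch.randn(2, 5, 8)   # (batch, 1 + patches, hidden)

    resolved = resolve_visual_encoder_outputs(
        encoder_out,
        torch.nn.LayerNorm(8),
        feature_select_strategy="default",  # drops the CLS token
    )
    assert resolved.shape == (2, 4, 8)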
@@ -120,13 +165,18 @@ def resolve_visual_encoder_outputs(
     hs_pool = [
         encoder_outputs[layer_idx]
         if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
-        for layer_idx in feature_sample_layers
+        for layer_idx in select_layers
     ]
 
+    if feature_select_strategy is not None:
+        select_features = _get_vision_feature_selector(
+            feature_select_strategy)
+        hs_pool = [select_features(hs) for hs in hs_pool]
+
     # Apply post-norm on the final hidden state if we are using it
-    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
+    uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1)
     if post_layer_norm is not None and uses_last_layer:
-        hs_pool[-1] = post_layer_norm(encoder_outputs)
+        hs_pool[-1] = post_layer_norm(hs_pool[-1])
     return torch.cat(hs_pool, dim=-1)
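
And a sketch of the stacked-layers path under the new keyword-only API
(hypothetical shapes; module path assumed as above; with every layer loaded,
negative indices need no offset):

    import torch

    from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

    # Embeddings plus all 6 of a possible 6 layers; batch 2, CLS + 4 patches,
    # hidden size 8.
    hidden_states = [torch.randn(2, 5, 8) for _ in range(7)]

    fused = resolve_visual_encoder_outputs(
        hidden_states,
        torch.nn.LayerNorm(8),              # applied only to the last layer
        select_layers=[-2, -1],             # last two layers of the encoder
        max_possible_layers=6,
        feature_select_strategy="default",  # drop CLS from each layer
    )

    # Two selected layers concatenated along the hidden dim, CLS removed.
    assert fused.shape == (2, 4, 16)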