[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,10 +4,12 @@
|
||||
import itertools
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Final, Generic, Literal, Optional, Protocol, TypeVar, Union
|
||||
from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
|
||||
TypeVar, Union)
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -86,11 +88,39 @@ def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
|
||||
return current_platform.get_vit_attn_backend(head_size, dtype)
|
||||
|
||||
|
||||
# A feature-selection strategy is either one of the built-in string modes
# or a user-supplied callable that maps a tensor to a tensor.
VisionFeatureSelectStrategy = Union[
    Literal["class", "default", "full"],
    Callable[[torch.Tensor], torch.Tensor],
]


def _get_vision_feature_selector(
    strategy: VisionFeatureSelectStrategy,
) -> Callable[[torch.Tensor], torch.Tensor]:
    """Turn a feature-selection strategy into a tensor -> tensor function.

    Args:
        strategy: Either a callable, which is returned untouched, or one of
            the string modes:

            - ``"class"``: keep only index 0 along dim 1.
            - ``"default"``: drop index 0 along dim 1 and keep the rest.
            - ``"full"``: return the input unchanged.

    Returns:
        A callable that applies the requested selection to a tensor.
        NOTE(review): the string modes index three dimensions, so they
        presumably expect (batch, tokens, hidden) inputs with the class
        token at position 0 -- confirm against the callers.
    """
    # User-supplied selectors take precedence over the built-in modes.
    if callable(strategy):
        return strategy

    # https://github.com/huggingface/transformers/blob/cd74917ffc3e8f84e4a886052c5ab32b7ac623cc/src/transformers/models/clip/modeling_clip.py#L762
    if strategy == "class":

        def _first_position_only(hidden: torch.Tensor) -> torch.Tensor:
            return hidden[:, 0, :]

        return _first_position_only

    # https://github.com/huggingface/transformers/blob/4a02bc7004285bdb12cc033e87ad2578ce2fa900/src/transformers/models/llava/modeling_llava.py#L196
    if strategy == "default":

        def _skip_first_position(hidden: torch.Tensor) -> torch.Tensor:
            return hidden[:, 1:, :]

        return _skip_first_position

    if strategy == "full":

        def _passthrough(hidden: torch.Tensor) -> torch.Tensor:
            return hidden

        return _passthrough

    # Every valid mode is handled above; anything else is a caller error.
    assert_never(strategy)
|
||||
|
||||
|
||||
def resolve_visual_encoder_outputs(
|
||||
encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
|
||||
feature_sample_layers: Optional[list[int]],
|
||||
post_layer_norm: Optional[torch.nn.LayerNorm],
|
||||
max_possible_layers: int,
|
||||
*,
|
||||
select_layers: Optional[list[int]] = None,
|
||||
max_possible_layers: Optional[int] = None,
|
||||
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Given the outputs a visual encoder module that may correspond to the
|
||||
output of the last layer, or a list of hidden states to be stacked,
|
||||
@@ -98,17 +128,32 @@ def resolve_visual_encoder_outputs(
|
||||
|
||||
Args:
|
||||
encoder_outputs: Output of encoder's last layer or all hidden states.
|
||||
feature_sample_layers: Optional layer indices to grab from the encoder
|
||||
outputs; if provided, encoder outputs must be a list.
|
||||
post_layer_norm: Post norm to apply to the output of the encoder.
|
||||
select_layers: Optional layer indices to grab from the encoder
|
||||
outputs; if provided, encoder outputs must be a list.
|
||||
max_possible_layers: Total layers in the fully loaded visual encoder.
|
||||
|
||||
feature_select_strategy: Defines how to select the hidden states
|
||||
from each layer.
|
||||
"""
|
||||
if feature_sample_layers is None:
|
||||
if select_layers is None:
|
||||
if not isinstance(encoder_outputs, torch.Tensor):
|
||||
raise ValueError("Expected only a single encoder output when "
|
||||
"`select_layers` is not provided")
|
||||
|
||||
if feature_select_strategy is not None:
|
||||
select_features = _get_vision_feature_selector(
|
||||
feature_select_strategy)
|
||||
encoder_outputs = select_features(encoder_outputs)
|
||||
|
||||
if post_layer_norm is not None:
|
||||
return post_layer_norm(encoder_outputs)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
if max_possible_layers is None:
|
||||
raise ValueError("`max_possible_layers` must be provided "
|
||||
"alongside `select_layers`")
|
||||
|
||||
# Get the hidden states corresponding to the layer indices.
|
||||
# Negative values are relative to the full visual encoder,
|
||||
# so offset them depending on how many layers were loaded.
|
||||
@@ -120,13 +165,18 @@ def resolve_visual_encoder_outputs(
|
||||
hs_pool = [
|
||||
encoder_outputs[layer_idx]
|
||||
if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
|
||||
for layer_idx in feature_sample_layers
|
||||
for layer_idx in select_layers
|
||||
]
|
||||
|
||||
if feature_select_strategy is not None:
|
||||
select_features = _get_vision_feature_selector(feature_select_strategy)
|
||||
hs_pool = [select_features(hs) for hs in hs_pool]
|
||||
|
||||
# Apply post-norm on the final hidden state if we are using it
|
||||
uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
|
||||
uses_last_layer = select_layers[-1] in (max_possible_layers - 1, -1)
|
||||
if post_layer_norm is not None and uses_last_layer:
|
||||
hs_pool[-1] = post_layer_norm(encoder_outputs)
|
||||
hs_pool[-1] = post_layer_norm(hs_pool[-1])
|
||||
|
||||
return torch.cat(hs_pool, dim=-1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user