[Misc] Factor out common _apply_feature_select_strategy (#26003)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -9,7 +9,6 @@ from typing import (Callable, Final, Generic, Literal, Optional, Protocol,
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -22,9 +21,13 @@ logger = init_logger(__name__)
|
||||
_C = TypeVar("_C", bound=PretrainedConfig)
|
||||
|
||||
|
||||
class _RootConfig(Protocol[_C]):
|
||||
vision_config: _C
|
||||
|
||||
|
||||
class VisionEncoderInfo(ABC, Generic[_C]):
|
||||
|
||||
def __init__(self, hf_config: _C) -> None:
|
||||
def __init__(self, hf_config: _RootConfig[_C]) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.hf_config = hf_config
|
||||
@@ -95,7 +98,7 @@ VisionFeatureSelectStrategy = Union[
|
||||
|
||||
|
||||
def _get_vision_feature_selector(
|
||||
strategy: VisionFeatureSelectStrategy,
|
||||
strategy: Union[VisionFeatureSelectStrategy, str],
|
||||
) -> Callable[[torch.Tensor], torch.Tensor]:
|
||||
if callable(strategy):
|
||||
return strategy
|
||||
@@ -111,7 +114,28 @@ def _get_vision_feature_selector(
|
||||
if strategy == "full":
|
||||
return lambda feats: feats
|
||||
|
||||
assert_never(strategy)
|
||||
raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
|
||||
|
||||
|
||||
def get_num_selected_vision_tokens(
|
||||
num_vision_tokens: int,
|
||||
strategy: Union[VisionFeatureSelectStrategy, str],
|
||||
) -> int:
|
||||
if callable(strategy):
|
||||
dummy_features = torch.empty(1, num_vision_tokens, 64) # [B, L, D]
|
||||
dummy_selected_features = strategy(dummy_features)
|
||||
return dummy_selected_features.shape[1]
|
||||
|
||||
if strategy == "class":
|
||||
return 1
|
||||
|
||||
if strategy == "default":
|
||||
return num_vision_tokens - 1
|
||||
|
||||
if strategy == "full":
|
||||
return num_vision_tokens
|
||||
|
||||
raise ValueError(f"Unexpected feature select strategy: {strategy!r}")
|
||||
|
||||
|
||||
def resolve_visual_encoder_outputs(
|
||||
|
||||
Reference in New Issue
Block a user