[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-09-30 19:24:57 +08:00
committed by GitHub
parent ef6e0e7132
commit d7e34b4210
12 changed files with 155 additions and 179 deletions

View File

@@ -577,27 +577,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return mm_input_by_modality
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
if strategy == "default":
return image_features[:, 1:]
elif strategy == "full":
return image_features
raise ValueError(f"Unexpected select feature strategy: {strategy}")
def _image_pixels_to_features(
self,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
strategy=self.config.vision_feature_select_strategy,
return vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
@@ -750,13 +739,11 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
pixel_values: torch.Tensor,
) -> torch.Tensor:
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
video_features = vision_tower(pixel_values)
video_features = self._select_image_features(
video_features,
strategy=self.config.vision_feature_select_strategy,
video_features = vision_tower(
pixel_values,
feature_select_strategy=self.config.vision_feature_select_strategy,
)
video_features = self.multi_modal_projector(video_features)
video_features = self.apply_pooling(video_features)