[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -577,27 +577,16 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return mm_input_by_modality
|
||||
|
||||
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||
strategy: str) -> torch.Tensor:
|
||||
if strategy == "default":
|
||||
return image_features[:, 1:]
|
||||
elif strategy == "full":
|
||||
return image_features
|
||||
|
||||
raise ValueError(f"Unexpected select feature strategy: {strategy}")
|
||||
|
||||
def _image_pixels_to_features(
|
||||
self,
|
||||
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
|
||||
pixel_values: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
image_features = vision_tower(pixel_values)
|
||||
return self._select_image_features(
|
||||
image_features,
|
||||
strategy=self.config.vision_feature_select_strategy,
|
||||
return vision_tower(
|
||||
pixel_values,
|
||||
feature_select_strategy=self.config.vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
|
||||
@@ -750,13 +739,11 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
|
||||
pixel_values: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# NOTE: we skip the step to select the vision feature layer since
|
||||
# this is already done inside the vision tower
|
||||
video_features = vision_tower(pixel_values)
|
||||
video_features = self._select_image_features(
|
||||
video_features,
|
||||
strategy=self.config.vision_feature_select_strategy,
|
||||
video_features = vision_tower(
|
||||
pixel_values,
|
||||
feature_select_strategy=self.config.vision_feature_select_strategy,
|
||||
)
|
||||
video_features = self.multi_modal_projector(video_features)
|
||||
video_features = self.apply_pooling(video_features)
|
||||
|
||||
Reference in New Issue
Block a user