[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -23,7 +23,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
|
||||
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
|
||||
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
|
||||
resolve_visual_encoder_outputs)
|
||||
|
||||
|
||||
class SiglipEncoderInfo(VisionEncoderInfo[SiglipVisionConfig]):
|
||||
@@ -415,28 +416,31 @@ class SiglipVisionTransformer(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
interpolate_pos_encoding: bool = True,
|
||||
feature_sample_layers: Optional[list[int]] = None,
|
||||
*,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
select_layers: Optional[list[int]] = None,
|
||||
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
)
|
||||
|
||||
return_all_hidden_states = feature_sample_layers is not None
|
||||
|
||||
# Produces either the last layer output or all of the hidden states,
|
||||
# depending on if we have feature_sample_layers or not
|
||||
# depending on if we have select_layers or not
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
return_all_hidden_states=return_all_hidden_states,
|
||||
return_all_hidden_states=select_layers is not None,
|
||||
)
|
||||
|
||||
# Handle post-norm (if applicable) and stacks feature layers if needed
|
||||
encoder_outputs = resolve_visual_encoder_outputs(
|
||||
encoder_outputs, feature_sample_layers, self.post_layernorm,
|
||||
self.config.num_hidden_layers)
|
||||
encoder_outputs,
|
||||
self.post_layernorm,
|
||||
select_layers=select_layers,
|
||||
max_possible_layers=self.config.num_hidden_layers,
|
||||
feature_select_strategy=feature_select_strategy,
|
||||
)
|
||||
|
||||
# TODO: add this back when pooled_output is used in inference.
|
||||
# if self.use_head:
|
||||
@@ -471,16 +475,22 @@ class SiglipVisionModel(nn.Module):
|
||||
def get_input_embeddings(self) -> nn.Module:
|
||||
return self.vision_model.embeddings.patch_embedding
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.get_input_embeddings().weight.dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
feature_sample_layers: Optional[list[int]] = None,
|
||||
select_layers: Optional[list[int]] = None,
|
||||
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
|
||||
) -> torch.Tensor:
|
||||
return self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
interpolate_pos_encoding=interpolate_pos_encoding,
|
||||
feature_sample_layers=feature_sample_layers,
|
||||
select_layers=select_layers,
|
||||
feature_select_strategy=feature_select_strategy,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
||||
Reference in New Issue
Block a user