[Model] Move vision_feature_select_strategy into resolve_visual_encoder_outputs (#25938)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-09-30 19:24:57 +08:00
committed by GitHub
parent ef6e0e7132
commit d7e34b4210
12 changed files with 155 additions and 179 deletions

View File

@@ -51,7 +51,8 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs
from .vision import (VisionEncoderInfo, VisionFeatureSelectStrategy,
resolve_visual_encoder_outputs)
try:
from xformers import ops as xops
@@ -1218,7 +1219,9 @@ class PixtralHFVisionModel(nn.Module):
def forward(
self,
pixel_values: list[torch.Tensor],
feature_sample_layers: Optional[list[int]] = None,
*,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
) -> tuple[torch.Tensor, ...]:
"""
Args:
@@ -1226,7 +1229,7 @@ class PixtralHFVisionModel(nn.Module):
in pixel_values. This means it will be a list of tensors
because multiple requests batched can have multiple images,
each with their own shape potentially
feature_sample_layers: Layer indices whose features should be
select_layers: Layer indices whose features should be
concatenated and used as the visual encoder output. If none
are provided, the last layer is used.
@@ -1267,15 +1270,20 @@ class PixtralHFVisionModel(nn.Module):
[p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
patch_embeds)
return_all_hidden_states = feature_sample_layers is not None
out = self.transformer(
patch_embeds,
attention_mask,
position_embedding,
return_all_hidden_states=return_all_hidden_states)
return_all_hidden_states=select_layers is not None,
)
out = resolve_visual_encoder_outputs(out, feature_sample_layers, None,
self.config.num_hidden_layers)
out = resolve_visual_encoder_outputs(
out,
None,
select_layers=select_layers,
max_possible_layers=self.config.num_hidden_layers,
feature_select_strategy=feature_select_strategy,
)
# squeeze dim 0 and split into separate tensors for each image
return torch.split(out.squeeze(0), embed_sizes)