[Models]: lfm2_siglip2 return intermediate encoder layers (#33370)

Signed-off-by: Eduardo Salinas <edus@microsoft.com>
Author: Eduardo Salinas
Date: 2026-02-01 01:17:49 -05:00
Committed by: GitHub
Parent: b6bb2842cf
Commit: 302ecf64ff


@@ -22,7 +22,11 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit
+from .vision import (
+    is_vit_use_data_parallel,
+    resolve_visual_encoder_outputs,
+    should_torch_compile_mm_vit,
+)


 class Siglip2VisionEmbeddings(nn.Module):
@@ -331,10 +335,17 @@ class Siglip2Encoder(nn.Module):
         self,
         config: Siglip2VisionConfig,
         quant_config: QuantizationConfig | None = None,
+        num_hidden_layers_override: int | None = None,
         prefix: str = "",
     ):
         super().__init__()
         self.config = config
+        if num_hidden_layers_override is None:
+            num_hidden_layers = config.num_hidden_layers
+        else:
+            num_hidden_layers = num_hidden_layers_override
         self.layers = nn.ModuleList(
             [
                 Siglip2EncoderLayer(
@@ -342,7 +353,7 @@ class Siglip2Encoder(nn.Module):
                     quant_config=quant_config,
                     prefix=f"{prefix}.layers.{idx}",
                 )
-                for idx in range(config.num_hidden_layers)
+                for idx in range(num_hidden_layers)
             ]
         )
@@ -351,15 +362,21 @@ class Siglip2Encoder(nn.Module):
         inputs_embeds: torch.Tensor,
         cu_seqlens: torch.Tensor,
         max_seqlen: int | torch.Tensor,
-    ) -> torch.Tensor:
+        return_all_hidden_states: bool = False,
+    ) -> torch.Tensor | list[torch.Tensor]:
+        hidden_states_pool = [inputs_embeds]
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
-            layer_outputs = encoder_layer(
+            hidden_states = encoder_layer(
                 hidden_states,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
-            hidden_states = layer_outputs
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        if return_all_hidden_states:
+            return hidden_states_pool
         return hidden_states
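The pooling above keeps one entry per layer plus the input embeddings, so callers can index intermediate states later. A minimal, self-contained sketch of the same pattern with toy linear layers (shapes are arbitrary):

```python
import torch
import torch.nn as nn

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
x = torch.randn(3, 8)

pool = [x]                  # index 0: the input embeddings
for layer in layers:
    x = layer(x)
    pool.append(x)          # index i+1: output of layer i

assert len(pool) == len(layers) + 1   # embeddings + every layer output
```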
@@ -368,6 +385,8 @@ class Siglip2VisionTransformer(nn.Module):
         self,
         config: Siglip2VisionConfig,
         quant_config: QuantizationConfig | None = None,
+        num_hidden_layers_override: int | None = None,
+        require_post_norm: bool | None = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -381,6 +400,7 @@ class Siglip2VisionTransformer(nn.Module):
         self.encoder = Siglip2Encoder(
             config,
             quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
         )
         num_hidden_layers = config.num_hidden_layers
@@ -390,7 +410,13 @@ class Siglip2VisionTransformer(nn.Module):
f"layers, but you requested {len(self.encoder.layers)} layers."
)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
if require_post_norm is None:
require_post_norm = len(self.encoder.layers) == num_hidden_layers
if require_post_norm:
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
else:
self.post_layernorm = None
def get_input_embeddings(self):
return self.embeddings
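The default ties the final LayerNorm to full encoder depth: a truncated encoder is being used as a feature extractor, and the post-norm was trained on the last layer's output only. A sketch of that decision rule (illustrative names, not vLLM API):

```python
def resolve_post_norm(built: int, total: int, require: bool | None) -> bool:
    # None means "decide from depth": keep the norm only at full depth.
    if require is None:
        return built == total
    return require

assert resolve_post_norm(27, 27, None) is True    # full depth: keep norm
assert resolve_post_norm(26, 27, None) is False   # truncated: drop norm
assert resolve_post_norm(26, 27, True) is True    # explicit override wins
```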
@@ -401,19 +427,34 @@ class Siglip2VisionTransformer(nn.Module):
         spatial_shapes: torch.LongTensor,
         cu_seqlens: torch.Tensor,
         max_seqlen: torch.Tensor,
+        select_layers: list[int] | None = None,
     ) -> torch.Tensor:
         r"""
         spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
             Tensor containing the spatial dimensions (height, width)
             of the input images.
+        select_layers (`list[int]` or `None`, defaults to `None`):
+            Layer indices to select hidden states from. Supports negative
+            indices (e.g., -1 for last layer, -2 for second-to-last).
+            If None, returns the last layer output.
         """
         hidden_states = self.embeddings(pixel_values_packed, spatial_shapes)
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
+            return_all_hidden_states=select_layers is not None,
         )
-        return self.post_layernorm(encoder_outputs)
+        encoder_outputs = resolve_visual_encoder_outputs(
+            encoder_outputs,
+            self.post_layernorm,
+            select_layers=select_layers,
+            max_possible_layers=self.config.num_hidden_layers,
+        )
+        return encoder_outputs


 class Siglip2Model(torch.nn.Module):
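`resolve_visual_encoder_outputs` (imported from `.vision`) does the actual layer selection. Roughly sketched below under the assumption of a full-depth encoder; the real helper also handles truncated encoders via `max_possible_layers`, and the `pick` function here is illustrative, not vLLM's implementation:

```python
import torch
import torch.nn as nn

def pick(
    outputs: torch.Tensor | list[torch.Tensor],
    post_norm: nn.LayerNorm | None,
    select_layers: list[int] | None,
    max_possible_layers: int,
) -> torch.Tensor:
    if select_layers is None:
        # Plain path: `outputs` is the last layer's tensor.
        return post_norm(outputs) if post_norm is not None else outputs
    # Pooled path: outputs == [embeddings, layer_1, ..., layer_N].
    picked = []
    for idx in select_layers:
        # Negative indices count back from the full-depth stack.
        pos = idx if idx >= 0 else idx + max_possible_layers + 1
        hs = outputs[pos]
        if post_norm is not None and pos == max_possible_layers:
            hs = post_norm(hs)  # the norm belongs to the final layer only
        picked.append(hs)
    # Multiple selected layers are concatenated on the feature dim.
    return torch.cat(picked, dim=-1)
```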
@@ -421,6 +462,8 @@ class Siglip2Model(torch.nn.Module):
         self,
         config: Siglip2VisionConfig,
         quant_config: QuantizationConfig | None = None,
+        num_hidden_layers_override: int | None = None,
+        require_post_norm: bool | None = None,
         prefix: str = "",
     ):
         super().__init__()
@@ -428,6 +471,8 @@ class Siglip2Model(torch.nn.Module):
         self.vision_model = Siglip2VisionTransformer(
             config,
             quant_config=quant_config,
+            num_hidden_layers_override=num_hidden_layers_override,
+            require_post_norm=require_post_norm,
             prefix=f"{prefix}.vision_model",
         )
@@ -437,12 +482,22 @@ class Siglip2Model(torch.nn.Module):
         spatial_shapes: torch.LongTensor,
         cu_seqlens: torch.Tensor,
         max_seqlen: torch.Tensor,
+        select_layers: list[int] | None = None,
     ) -> torch.Tensor:
-        """Forward pass through the vision model."""
+        """Forward pass through the vision model.
+
+        Args:
+            select_layers: Layer indices to select hidden states from.
+                Supports negative indices (e.g., [-2] for second-to-last).
+                If None, returns the last layer output with post_layernorm.
+                Multiple layers can be selected and will be concatenated.
+        """
         return self.vision_model(
             pixel_values_packed=pixel_values_packed,
             spatial_shapes=spatial_shapes,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
+            select_layers=select_layers,
         )
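A hypothetical call site, e.g. for an LFM2 connector that consumes penultimate-layer features (variable setup omitted; the keyword names follow the signature above):

```python
# select_layers=[-2]: second-to-last layer, no post_layernorm applied.
feats = siglip2_model(
    pixel_values_packed=pixel_values_packed,  # packed image patches
    spatial_shapes=spatial_shapes,            # (batch_size, 2) h/w grid
    cu_seqlens=cu_seqlens,
    max_seqlen=max_seqlen,
    select_layers=[-2],
)
```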
@@ -454,8 +509,22 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+        layer_count = len(self.vision_model.encoder.layers)
         for name, loaded_weight in weights:
+            # post_layernorm is optional in Siglip2Model
+            if (
+                name.startswith("vision_model.post_layernorm")
+                and self.vision_model.post_layernorm is None
+            ):
+                continue
+            # omit layers when num_hidden_layers_override is set
+            if name.startswith("vision_model.encoder.layers"):
+                layer_idx = int(name.split(".")[3])
+                if layer_idx >= layer_count:
+                    continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
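The `name.split(".")[3]` trick relies on checkpoint names of the form `vision_model.encoder.layers.<idx>....`. A quick check of the skip condition (example name; the override value is hypothetical):

```python
# e.g. with num_hidden_layers_override=26, layer_count == 26 and the
# full checkpoint's layer 26 (the 27th layer) is skipped at load time.
name = "vision_model.encoder.layers.26.mlp.fc1.weight"
layer_count = 26
layer_idx = int(name.split(".")[3])   # -> 26
assert layer_idx >= layer_count       # weight is skipped
```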