[Models]: lfm2_siglip2 return intermediate encoder layers (#33370)
Signed-off-by: Eduardo Salinas <edus@microsoft.com>
This commit is contained in:
@@ -22,7 +22,11 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit
|
||||
from .vision import (
|
||||
is_vit_use_data_parallel,
|
||||
resolve_visual_encoder_outputs,
|
||||
should_torch_compile_mm_vit,
|
||||
)
|
||||
|
||||
|
||||
class Siglip2VisionEmbeddings(nn.Module):
|
||||
@@ -331,10 +335,17 @@ class Siglip2Encoder(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Siglip2EncoderLayer(
|
||||
@@ -342,7 +353,7 @@ class Siglip2Encoder(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.layers.{idx}",
|
||||
)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
for idx in range(num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
@@ -351,15 +362,21 @@ class Siglip2Encoder(nn.Module):
|
||||
inputs_embeds: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: int | torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return_all_hidden_states: bool = False,
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
hidden_states_pool = [inputs_embeds]
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for encoder_layer in self.layers:
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = layer_outputs
|
||||
if return_all_hidden_states:
|
||||
hidden_states_pool.append(hidden_states)
|
||||
if return_all_hidden_states:
|
||||
return hidden_states_pool
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -368,6 +385,8 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
require_post_norm: bool | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -381,6 +400,7 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
self.encoder = Siglip2Encoder(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
@@ -390,7 +410,13 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
f"layers, but you requested {len(self.encoder.layers)} layers."
|
||||
)
|
||||
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
if require_post_norm is None:
|
||||
require_post_norm = len(self.encoder.layers) == num_hidden_layers
|
||||
|
||||
if require_post_norm:
|
||||
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||
else:
|
||||
self.post_layernorm = None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
@@ -401,19 +427,34 @@ class Siglip2VisionTransformer(nn.Module):
|
||||
spatial_shapes: torch.LongTensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
select_layers: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`):
|
||||
Tensor containing the spatial dimensions (height, width)
|
||||
of the input images.
|
||||
select_layers (`list[int]` or `None`, defaults to `None`):
|
||||
Layer indices to select hidden states from. Supports negative
|
||||
indices (e.g., -1 for last layer, -2 for second-to-last).
|
||||
If None, returns the last layer output.
|
||||
"""
|
||||
hidden_states = self.embeddings(pixel_values_packed, spatial_shapes)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
return_all_hidden_states=select_layers is not None,
|
||||
)
|
||||
return self.post_layernorm(encoder_outputs)
|
||||
|
||||
encoder_outputs = resolve_visual_encoder_outputs(
|
||||
encoder_outputs,
|
||||
self.post_layernorm,
|
||||
select_layers=select_layers,
|
||||
max_possible_layers=self.config.num_hidden_layers,
|
||||
)
|
||||
|
||||
return encoder_outputs
|
||||
|
||||
|
||||
class Siglip2Model(torch.nn.Module):
|
||||
@@ -421,6 +462,8 @@ class Siglip2Model(torch.nn.Module):
|
||||
self,
|
||||
config: Siglip2VisionConfig,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
require_post_norm: bool | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
@@ -428,6 +471,8 @@ class Siglip2Model(torch.nn.Module):
|
||||
self.vision_model = Siglip2VisionTransformer(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
require_post_norm=require_post_norm,
|
||||
prefix=f"{prefix}.vision_model",
|
||||
)
|
||||
|
||||
@@ -437,12 +482,22 @@ class Siglip2Model(torch.nn.Module):
|
||||
spatial_shapes: torch.LongTensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
select_layers: list[int] | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through the vision model.
|
||||
|
||||
Args:
|
||||
select_layers: Layer indices to select hidden states from.
|
||||
Supports negative indices (e.g., [-2] for second-to-last).
|
||||
If None, returns the last layer output with post_layernorm.
|
||||
Multiple layers can be selected and will be concatenated.
|
||||
"""
|
||||
return self.vision_model(
|
||||
pixel_values_packed=pixel_values_packed,
|
||||
spatial_shapes=spatial_shapes,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
select_layers=select_layers,
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
@@ -454,8 +509,22 @@ class Siglip2Model(torch.nn.Module):
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: set[str] = set()
|
||||
layer_count = len(self.vision_model.encoder.layers)
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
# post_layernorm is optional in Siglip2Model
|
||||
if (
|
||||
name.startswith("vision_model.post_layernorm")
|
||||
and self.vision_model.post_layernorm is None
|
||||
):
|
||||
continue
|
||||
|
||||
# omit layers when num_hidden_layers_override is set
|
||||
if name.startswith("vision_model.encoder.layers"):
|
||||
layer_idx = int(name.split(".")[3])
|
||||
if layer_idx >= layer_count:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user