diff --git a/vllm/model_executor/models/lfm2_siglip2.py b/vllm/model_executor/models/lfm2_siglip2.py index d58e2ad85..92ea42f27 100644 --- a/vllm/model_executor/models/lfm2_siglip2.py +++ b/vllm/model_executor/models/lfm2_siglip2.py @@ -22,7 +22,11 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from .vision import is_vit_use_data_parallel, should_torch_compile_mm_vit +from .vision import ( + is_vit_use_data_parallel, + resolve_visual_encoder_outputs, + should_torch_compile_mm_vit, +) class Siglip2VisionEmbeddings(nn.Module): @@ -331,10 +335,17 @@ class Siglip2Encoder(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, prefix: str = "", ): super().__init__() self.config = config + + if num_hidden_layers_override is None: + num_hidden_layers = config.num_hidden_layers + else: + num_hidden_layers = num_hidden_layers_override + self.layers = nn.ModuleList( [ Siglip2EncoderLayer( @@ -342,7 +353,7 @@ class Siglip2Encoder(nn.Module): quant_config=quant_config, prefix=f"{prefix}.layers.{idx}", ) - for idx in range(config.num_hidden_layers) + for idx in range(num_hidden_layers) ] ) @@ -351,15 +362,21 @@ class Siglip2Encoder(nn.Module): inputs_embeds: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int | torch.Tensor, - ) -> torch.Tensor: + return_all_hidden_states: bool = False, + ) -> torch.Tensor | list[torch.Tensor]: + hidden_states_pool = [inputs_embeds] hidden_states = inputs_embeds + for encoder_layer in self.layers: - layer_outputs = encoder_layer( + hidden_states = encoder_layer( hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, ) - hidden_states = layer_outputs + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + if return_all_hidden_states: + return hidden_states_pool return hidden_states @@ -368,6 +385,8 @@ class Siglip2VisionTransformer(nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ): super().__init__() @@ -381,6 +400,7 @@ class Siglip2VisionTransformer(nn.Module): self.encoder = Siglip2Encoder( config, quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, prefix=f"{prefix}.encoder", ) num_hidden_layers = config.num_hidden_layers @@ -390,7 +410,13 @@ class Siglip2VisionTransformer(nn.Module): f"layers, but you requested {len(self.encoder.layers)} layers." ) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + if require_post_norm is None: + require_post_norm = len(self.encoder.layers) == num_hidden_layers + + if require_post_norm: + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + else: + self.post_layernorm = None def get_input_embeddings(self): return self.embeddings @@ -401,19 +427,34 @@ class Siglip2VisionTransformer(nn.Module): spatial_shapes: torch.LongTensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, + select_layers: list[int] | None = None, ) -> torch.Tensor: r""" spatial_shapes (`torch.LongTensor` of shape `(batch_size, 2)`): Tensor containing the spatial dimensions (height, width) of the input images. + select_layers (`list[int]` or `None`, defaults to `None`): + Layer indices to select hidden states from. Supports negative + indices (e.g., -1 for last layer, -2 for second-to-last). + If None, returns the last layer output. """ hidden_states = self.embeddings(pixel_values_packed, spatial_shapes) + encoder_outputs = self.encoder( inputs_embeds=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + return_all_hidden_states=select_layers is not None, ) - return self.post_layernorm(encoder_outputs) + + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, + self.post_layernorm, + select_layers=select_layers, + max_possible_layers=self.config.num_hidden_layers, + ) + + return encoder_outputs class Siglip2Model(torch.nn.Module): @@ -421,6 +462,8 @@ class Siglip2Model(torch.nn.Module): self, config: Siglip2VisionConfig, quant_config: QuantizationConfig | None = None, + num_hidden_layers_override: int | None = None, + require_post_norm: bool | None = None, prefix: str = "", ): super().__init__() @@ -428,6 +471,8 @@ class Siglip2Model(torch.nn.Module): self.vision_model = Siglip2VisionTransformer( config, quant_config=quant_config, + num_hidden_layers_override=num_hidden_layers_override, + require_post_norm=require_post_norm, prefix=f"{prefix}.vision_model", ) @@ -437,12 +482,22 @@ class Siglip2Model(torch.nn.Module): spatial_shapes: torch.LongTensor, cu_seqlens: torch.Tensor, max_seqlen: torch.Tensor, + select_layers: list[int] | None = None, ) -> torch.Tensor: + """Forward pass through the vision model. + + Args: + select_layers: Layer indices to select hidden states from. + Supports negative indices (e.g., [-2] for second-to-last). + If None, returns the last layer output with post_layernorm. + Multiple layers can be selected and will be concatenated. + """ return self.vision_model( pixel_values_packed=pixel_values_packed, spatial_shapes=spatial_shapes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + select_layers=select_layers, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: @@ -454,8 +509,22 @@ class Siglip2Model(torch.nn.Module): ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() + layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: + # post_layernorm is optional in Siglip2Model + if ( + name.startswith("vision_model.post_layernorm") + and self.vision_model.post_layernorm is None + ): + continue + + # omit layers when num_hidden_layers_override is set + if name.startswith("vision_model.encoder.layers"): + layer_idx = int(name.split(".")[3]) + if layer_idx >= layer_count: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue