[bugfix] fix siglip batch text output error (#28365)
Signed-off-by: piood <2477084691@qq.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from transformers import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from vllm.attention.layer import MultiHeadAttention
|
from vllm.attention.layer import MultiHeadAttention
|
||||||
|
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.multimodal import BaseDummyOptions
|
from vllm.config.multimodal import BaseDummyOptions
|
||||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||||
@@ -379,6 +380,7 @@ class SiglipAttention(nn.Module):
|
|||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -413,8 +415,11 @@ class SiglipAttention(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||||
|
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = attn_cls(
|
||||||
self.num_heads_per_partition, self.head_dim, self.scale
|
self.num_heads_per_partition,
|
||||||
|
self.head_dim,
|
||||||
|
self.scale,
|
||||||
|
prefix=f"{prefix}.attn",
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -424,25 +429,7 @@ class SiglipAttention(nn.Module):
|
|||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||||
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
|
||||||
|
|
||||||
needs_unsqueeze = query_states.ndim == 2
|
|
||||||
if needs_unsqueeze:
|
|
||||||
query_states, key_states, value_states = (
|
|
||||||
query_states.unsqueeze(0),
|
|
||||||
key_states.unsqueeze(0),
|
|
||||||
value_states.unsqueeze(0),
|
|
||||||
)
|
|
||||||
|
|
||||||
out = self.attn(query_states, key_states, value_states)
|
out = self.attn(query_states, key_states, value_states)
|
||||||
|
|
||||||
if needs_unsqueeze:
|
|
||||||
out, query_states, key_states, value_states = (
|
|
||||||
out.squeeze(0),
|
|
||||||
query_states.squeeze(0),
|
|
||||||
key_states.squeeze(0),
|
|
||||||
value_states.squeeze(0),
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
|
|
||||||
return attn_output, None
|
return attn_output, None
|
||||||
@@ -495,6 +482,7 @@ class SiglipEncoderLayer(nn.Module):
|
|||||||
quant_config: QuantizationConfig | None = None,
|
quant_config: QuantizationConfig | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -504,6 +492,7 @@ class SiglipEncoderLayer(nn.Module):
|
|||||||
config,
|
config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
|
attn_cls=attn_cls,
|
||||||
)
|
)
|
||||||
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||||
self.mlp = SiglipMLP(
|
self.mlp = SiglipMLP(
|
||||||
@@ -539,6 +528,7 @@ class SiglipEncoder(nn.Module):
|
|||||||
num_hidden_layers_override: int | None = None,
|
num_hidden_layers_override: int | None = None,
|
||||||
*,
|
*,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -555,6 +545,7 @@ class SiglipEncoder(nn.Module):
|
|||||||
config,
|
config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.layers.{layer_idx}",
|
prefix=f"{prefix}.layers.{layer_idx}",
|
||||||
|
attn_cls=attn_cls,
|
||||||
)
|
)
|
||||||
for layer_idx in range(num_hidden_layers)
|
for layer_idx in range(num_hidden_layers)
|
||||||
]
|
]
|
||||||
@@ -598,6 +589,7 @@ class SiglipTextTransformer(nn.Module):
|
|||||||
config=config,
|
config=config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.encoder",
|
prefix=f"{prefix}.encoder",
|
||||||
|
attn_cls=EncoderOnlyAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
||||||
@@ -709,6 +701,7 @@ class SiglipVisionTransformer(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
num_hidden_layers_override=num_hidden_layers_override,
|
num_hidden_layers_override=num_hidden_layers_override,
|
||||||
prefix=f"{prefix}.encoder",
|
prefix=f"{prefix}.encoder",
|
||||||
|
attn_cls=MultiHeadAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_hidden_layers = config.num_hidden_layers
|
num_hidden_layers = config.num_hidden_layers
|
||||||
@@ -1034,10 +1027,56 @@ class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
|
|||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
text_features = self.text_model.head(last_hidden_state)
|
text_features = self.text_model.head(last_hidden_state)
|
||||||
# Flip to extract CLS token (first token after reversal) for pooling
|
|
||||||
text_features = text_features.flip(0)
|
# SigLIP uses reversed position_ids;
|
||||||
|
# flip sequences to move EOS token to first position
|
||||||
|
text_features = self._flip_sequences_by_position_ids(
|
||||||
|
text_features, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
return text_features
|
return text_features
|
||||||
|
|
||||||
|
def _flip_sequences_by_position_ids(
|
||||||
|
self,
|
||||||
|
features: torch.Tensor,
|
||||||
|
position_ids: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Flip sequences so EOS token moves to first position for CLS pooling.
|
||||||
|
|
||||||
|
SigLIP position_ids are reversed within each sequence. This method detects
|
||||||
|
sequence boundaries and flips each sequence individually.
|
||||||
|
"""
|
||||||
|
if len(features) == 1:
|
||||||
|
return features
|
||||||
|
|
||||||
|
# Detect sequence boundaries where position_ids decrease
|
||||||
|
position_diffs = position_ids[1:] - position_ids[:-1]
|
||||||
|
boundary_mask = position_diffs <= 0
|
||||||
|
|
||||||
|
boundary_indices = torch.cat(
|
||||||
|
[
|
||||||
|
torch.tensor([0], device=features.device),
|
||||||
|
torch.where(boundary_mask)[0] + 1,
|
||||||
|
torch.tensor([len(features)], device=features.device),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# For each sequence [start, end), position i flips to: start + end - 1 - i
|
||||||
|
lengths = boundary_indices[1:] - boundary_indices[:-1]
|
||||||
|
starts = boundary_indices[:-1]
|
||||||
|
ends = boundary_indices[1:]
|
||||||
|
|
||||||
|
# Assign sequence ID to each element
|
||||||
|
sequence_ids = torch.arange(
|
||||||
|
len(lengths), device=features.device
|
||||||
|
).repeat_interleave(lengths)
|
||||||
|
|
||||||
|
# Calculate flipped indices for all positions at once
|
||||||
|
current_positions = torch.arange(len(features), device=features.device)
|
||||||
|
flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions
|
||||||
|
|
||||||
|
return features[flip_indices]
|
||||||
|
|
||||||
def get_image_features(
|
def get_image_features(
|
||||||
self,
|
self,
|
||||||
pixel_values: torch.Tensor,
|
pixel_values: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user