diff --git a/vllm/attention/layers/mm_encoder_attention.py b/vllm/attention/layers/mm_encoder_attention.py
index 138fc9911..a7b442eec 100644
--- a/vllm/attention/layers/mm_encoder_attention.py
+++ b/vllm/attention/layers/mm_encoder_attention.py
@@ -133,6 +133,7 @@ class MMEncoderAttention(CustomOp):
             q=query,
             k=key,
             v=value,
+            scale=self.scale,
             cu_seqlens=cu_seqlens,
         )
         if is_reshaped:
@@ -167,6 +168,7 @@ class MMEncoderAttention(CustomOp):
             q=query,
             k=key,
             v=value,
+            scale=self.scale,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
             batch_size=bsz,
diff --git a/vllm/attention/ops/vit_attn_wrappers.py b/vllm/attention/ops/vit_attn_wrappers.py
index 2204382a3..80c4f1491 100644
--- a/vllm/attention/ops/vit_attn_wrappers.py
+++ b/vllm/attention/ops/vit_attn_wrappers.py
@@ -27,6 +27,7 @@ def flash_attn_maxseqlen_wrapper(
     batch_size: int,
     is_rocm_aiter: bool,
     fa_version: int | None,
+    scale: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
     max_seqlen: torch.Tensor | None = None,
 ) -> torch.Tensor:
@@ -57,6 +58,7 @@ def flash_attn_maxseqlen_wrapper(
         max_seqlen_k=max_seqlen,
         dropout_p=0.0,
         causal=False,
+        softmax_scale=scale,
         **kwargs,
     )
     context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
@@ -67,11 +69,12 @@
 def flash_attn_maxseqlen_wrapper_fake(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    cu_seqlens: torch.Tensor,
-    max_seqlen: torch.Tensor,
     batch_size: int,
     is_rocm_aiter: bool,
     fa_version: int | None,
+    scale: float | None,
+    cu_seqlens: torch.Tensor | None,
+    max_seqlen: torch.Tensor | None,
 ) -> torch.Tensor:
     return torch.empty_like(q)
@@ -90,6 +93,7 @@ def vit_flash_attn_wrapper(
     batch_size: int,
     is_rocm_aiter: bool,
     fa_version: int | None,
+    scale: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
     max_seqlen: torch.Tensor | None = None,
 ) -> torch.Tensor:
@@ -100,18 +104,24 @@ def vit_flash_attn_wrapper(
         batch_size,
         is_rocm_aiter,
         fa_version,
+        scale,
         cu_seqlens,
         max_seqlen,
     )


-def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+def apply_sdpa(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    scale: float | None = None,
+) -> torch.Tensor:
     """
     Input shape: (batch_size x seq_len x num_heads x head_size)
     """
     q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
-    output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
+    output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=scale)
     output = einops.rearrange(output, "b h s d -> b s h d ")
     return output

@@ -122,6 +132,7 @@ def torch_sdpa_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
+    scale: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
 ) -> torch.Tensor:
     # Never remove the contiguous logic for ROCm
@@ -132,7 +143,7 @@
     v = v.contiguous()

     if cu_seqlens is None:
-        return apply_sdpa(q, k, v)
+        return apply_sdpa(q, k, v, scale=scale)

     outputs = []

@@ -141,7 +152,7 @@
     k_chunks = torch.split(k, lens, dim=1)
     v_chunks = torch.split(v, lens, dim=1)
     for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
-        output_i = apply_sdpa(q_i, k_i, v_i)
+        output_i = apply_sdpa(q_i, k_i, v_i, scale=scale)
         outputs.append(output_i)
     context_layer = torch.cat(outputs, dim=1)
     return context_layer
@@ -151,7 +162,8 @@
 def torch_sdpa_wrapper_fake(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
-    cu_seqlens: torch.Tensor,
+    scale: float | None,
+    cu_seqlens: torch.Tensor | None,
 ) -> torch.Tensor:
     return torch.empty_like(q)
@@ -167,6 +179,7 @@ def vit_torch_sdpa_wrapper(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
+    scale: float | None = None,
     cu_seqlens: torch.Tensor | None = None,
 ) -> torch.Tensor:
-    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
+    return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens)
diff --git a/vllm/model_executor/models/dots_ocr.py b/vllm/model_executor/models/dots_ocr.py
index 990e5e4c5..c9e0dc8b9 100644
--- a/vllm/model_executor/models/dots_ocr.py
+++ b/vllm/model_executor/models/dots_ocr.py
@@ -271,6 +271,7 @@ class DotsVisionAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
             prefix=f"{prefix}.attn",
         )
diff --git a/vllm/model_executor/models/ernie45_vl.py b/vllm/model_executor/models/ernie45_vl.py
index 6b3a6bded..d47955ea3 100644
--- a/vllm/model_executor/models/ernie45_vl.py
+++ b/vllm/model_executor/models/ernie45_vl.py
@@ -152,6 +152,7 @@ class Ernie4_5_VisionAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
             prefix=f"{prefix}.attn",
         )
diff --git a/vllm/model_executor/models/glm4_1v.py b/vllm/model_executor/models/glm4_1v.py
index 3a06babf2..4c4347f5a 100644
--- a/vllm/model_executor/models/glm4_1v.py
+++ b/vllm/model_executor/models/glm4_1v.py
@@ -304,6 +304,7 @@ class Glm4vVisionAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
         )

diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py
index 42b8d54aa..27d408afd 100644
--- a/vllm/model_executor/models/glmasr.py
+++ b/vllm/model_executor/models/glmasr.py
@@ -188,6 +188,7 @@ class GlmAsrEncoderAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_heads_per_rank,
             head_size=self.head_dim,
+            scale=self.head_dim**-0.5,
             num_kv_heads=self.num_kv_heads_per_rank,
             prefix=f"{prefix}.attn",
         )
diff --git a/vllm/model_executor/models/isaac.py b/vllm/model_executor/models/isaac.py
index c95a57faf..e05df611f 100644
--- a/vllm/model_executor/models/isaac.py
+++ b/vllm/model_executor/models/isaac.py
@@ -984,6 +984,7 @@ class Siglip2VisionAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             prefix=f"{prefix}.attn",
             multimodal_config=multimodal_config,
         )
diff --git a/vllm/model_executor/models/moonvit.py b/vllm/model_executor/models/moonvit.py
index 99200068c..c785b9910 100644
--- a/vllm/model_executor/models/moonvit.py
+++ b/vllm/model_executor/models/moonvit.py
@@ -390,6 +390,7 @@ class MoonVitEncoderLayer(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
             prefix=f"{prefix}.attn",
         )
diff --git a/vllm/model_executor/models/paddleocr_vl.py b/vllm/model_executor/models/paddleocr_vl.py
index dc70c5a85..0e5537b86 100644
--- a/vllm/model_executor/models/paddleocr_vl.py
+++ b/vllm/model_executor/models/paddleocr_vl.py
@@ -564,6 +564,7 @@ class SiglipAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
             prefix=f"{prefix}.attn",
         )
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
index 0a38aa734..221e7bb06 100644
--- a/vllm/model_executor/models/qwen2_5_vl.py
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -352,6 +352,7 @@ class Qwen2_5_VisionAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
         )

diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 379e50742..ee2b6c22b 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -327,6 +327,7 @@ class Qwen2VisionAttention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_attention_heads_per_partition,
             head_size=self.hidden_size_per_attention_head,
+            scale=self.hidden_size_per_attention_head**-0.5,
             multimodal_config=multimodal_config,
         )
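
For context (not part of the patch): every call site above passes scale=head_size**-0.5, which is exactly the fallback that both F.scaled_dot_product_attention (when scale=None) and flash-attn's softmax_scale=None compute internally, so making the scale explicit should be behavior-preserving for these models. A minimal sketch to sanity-check that equivalence on the SDPA path; the tensor shapes are illustrative only:

# Sanity check: with scale=None, PyTorch SDPA already uses 1/sqrt(head_dim),
# so passing head_dim**-0.5 explicitly yields identical output.
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 16, 64
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

out_default = F.scaled_dot_product_attention(q, k, v)  # implicit 1/sqrt(head_dim)
out_explicit = F.scaled_dot_product_attention(q, k, v, scale=head_dim**-0.5)

torch.testing.assert_close(out_default, out_explicit)

The explicit parameter matters for models whose scale is not 1/sqrt(head_dim): those can now pass a custom value through MMEncoderAttention instead of relying on the backends' defaults.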