[MM Encoder]: Make MMEncoderAttention's scale take effect properly (#31950)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -133,6 +133,7 @@ class MMEncoderAttention(CustomOp):
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
scale=self.scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if is_reshaped:
|
||||
@@ -167,6 +168,7 @@ class MMEncoderAttention(CustomOp):
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
scale=self.scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
batch_size=bsz,
|
||||
|
||||
@@ -27,6 +27,7 @@ def flash_attn_maxseqlen_wrapper(
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int | None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -57,6 +58,7 @@ def flash_attn_maxseqlen_wrapper(
|
||||
max_seqlen_k=max_seqlen,
|
||||
dropout_p=0.0,
|
||||
causal=False,
|
||||
softmax_scale=scale,
|
||||
**kwargs,
|
||||
)
|
||||
context_layer = einops.rearrange(output, "(b s) h d -> b s h d", b=batch_size)
|
||||
@@ -67,11 +69,12 @@ def flash_attn_maxseqlen_wrapper_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
max_seqlen: torch.Tensor,
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int | None,
|
||||
scale: float | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
max_seqlen: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
@@ -90,6 +93,7 @@ def vit_flash_attn_wrapper(
|
||||
batch_size: int,
|
||||
is_rocm_aiter: bool,
|
||||
fa_version: int | None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
max_seqlen: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
@@ -100,18 +104,24 @@ def vit_flash_attn_wrapper(
|
||||
batch_size,
|
||||
is_rocm_aiter,
|
||||
fa_version,
|
||||
scale,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
|
||||
|
||||
def apply_sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
def apply_sdpa(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Input shape:
|
||||
(batch_size x seq_len x num_heads x head_size)
|
||||
"""
|
||||
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
|
||||
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
|
||||
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=scale)
|
||||
output = einops.rearrange(output, "b h s d -> b s h d ")
|
||||
return output
|
||||
|
||||
@@ -122,6 +132,7 @@ def torch_sdpa_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Never remove the contiguous logic for ROCm
|
||||
@@ -132,7 +143,7 @@ def torch_sdpa_wrapper(
|
||||
v = v.contiguous()
|
||||
|
||||
if cu_seqlens is None:
|
||||
return apply_sdpa(q, k, v)
|
||||
return apply_sdpa(q, k, v, scale=scale)
|
||||
|
||||
outputs = []
|
||||
|
||||
@@ -141,7 +152,7 @@ def torch_sdpa_wrapper(
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
output_i = apply_sdpa(q_i, k_i, v_i)
|
||||
output_i = apply_sdpa(q_i, k_i, v_i, scale=scale)
|
||||
outputs.append(output_i)
|
||||
context_layer = torch.cat(outputs, dim=1)
|
||||
return context_layer
|
||||
@@ -151,7 +162,8 @@ def torch_sdpa_wrapper_fake(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
scale: float | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
@@ -167,6 +179,7 @@ def vit_torch_sdpa_wrapper(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, cu_seqlens)
|
||||
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens)
|
||||
|
||||
@@ -271,6 +271,7 @@ class DotsVisionAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
@@ -152,6 +152,7 @@ class Ernie4_5_VisionAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
@@ -304,6 +304,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -188,6 +188,7 @@ class GlmAsrEncoderAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_heads_per_rank,
|
||||
head_size=self.head_dim,
|
||||
scale=self.head_dim**-0.5,
|
||||
num_kv_heads=self.num_kv_heads_per_rank,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
@@ -984,6 +984,7 @@ class Siglip2VisionAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
prefix=f"{prefix}.attn",
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
@@ -390,6 +390,7 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
@@ -564,6 +564,7 @@ class SiglipAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
multimodal_config=multimodal_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
@@ -352,6 +352,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -327,6 +327,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
self.attn = MMEncoderAttention(
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
multimodal_config=multimodal_config,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user