[MM] Align the prefix of MMEncoderAttention with Attention (#33750)

Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
Shanshan Shen
2026-02-04 12:07:30 +08:00
committed by GitHub
parent 4dffc5e044
commit 9fb27dd3b3
17 changed files with 17 additions and 15 deletions

View File

@@ -130,7 +130,7 @@ class AIMv2Attention(nn.Module):
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:

View File

@@ -126,7 +126,7 @@ class BlipAttention(nn.Module):
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):

View File

@@ -296,7 +296,7 @@ class Glm4vVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
prefix=prefix,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

View File

@@ -139,7 +139,7 @@ class EVA2CLIPAttention(nn.Module):
self.num_heads_per_rank,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
self.output_dropout = torch.nn.Dropout(config.dropout_prob)

View File

@@ -137,6 +137,7 @@ class GlmOcrVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

View File

@@ -166,7 +166,7 @@ class Idefics2VisionAttention(nn.Module):
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def forward(

View File

@@ -215,7 +215,7 @@ class InternParallelAttention(nn.Module):
self.num_heads_per_partition,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):

View File

@@ -220,7 +220,7 @@ class InternSdpaAttention(nn.Module):
self.num_heads,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def forward(self, x: torch.Tensor) -> torch.Tensor:

View File

@@ -257,7 +257,7 @@ class Llama4VisionAttention(nn.Module):
self.num_local_heads,
self.head_dim,
self.scaling,
prefix=prefix,
prefix=f"{prefix}.attn",
)
if use_data_parallel:

View File

@@ -235,7 +235,7 @@ class MultiHeadDotProductAttention(nn.Module):
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def forward(

View File

@@ -611,7 +611,7 @@ class ImagePoolingAttention(nn.Module):
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def forward_sdpa(

View File

@@ -125,6 +125,7 @@ class OpenPanguVisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

View File

@@ -345,7 +345,7 @@ class Qwen2_5_VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
prefix=prefix,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

View File

@@ -319,7 +319,7 @@ class Qwen2VisionAttention(nn.Module):
num_heads=self.num_attention_heads_per_partition,
head_size=self.hidden_size_per_attention_head,
scale=self.hidden_size_per_attention_head**-0.5,
prefix=prefix,
prefix=f"{prefix}.attn",
)
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)

View File

@@ -194,7 +194,7 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
num_heads=self.num_local_heads,
head_size=self.head_dim,
scale=self.scaling,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def forward(

View File

@@ -763,7 +763,7 @@ class Step3VisionAttention(nn.Module):
self.num_heads,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
def forward(

View File

@@ -224,7 +224,7 @@ class PerceptionEncoderVisionAttention(nn.Module):
self.num_heads,
self.head_dim,
self.scale,
prefix=prefix,
prefix=f"{prefix}.attn",
)
self.rope = PerceptionEncoderRope2D(
dim=self.head_dim,