[MM] Pass prefix parameter to MMEncoderAttention (#33674)
Signed-off-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@@ -127,7 +127,10 @@ class AIMv2Attention(nn.Module):
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn = MMEncoderAttention(
|
||||
- self.num_heads_per_partition, self.head_dim, self.scale
+ self.num_heads_per_partition,
+ self.head_dim,
+ self.scale,
+ prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@@ -123,7 +123,10 @@ class BlipAttention(nn.Module):
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn = MMEncoderAttention(
|
||||
- self.num_heads_per_partition, self.head_dim, self.scale
+ self.num_heads_per_partition,
+ self.head_dim,
+ self.scale,
+ prefix=prefix,
|
||||
)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
|
||||
@@ -296,6 +296,7 @@ class Glm4vVisionAttention(nn.Module):
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
+ prefix=prefix,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
@@ -136,7 +136,10 @@ class EVA2CLIPAttention(nn.Module):
|
||||
)
|
||||
|
||||
self.attn = MMEncoderAttention(
|
||||
- self.num_heads_per_rank, self.head_dim, self.scale
+ self.num_heads_per_rank,
+ self.head_dim,
+ self.scale,
+ prefix=prefix,
|
||||
)
|
||||
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
||||
|
||||
|
||||
@@ -163,7 +163,10 @@ class Idefics2VisionAttention(nn.Module):
|
||||
)
|
||||
# Use unified MMEncoderAttention with Flash Attention support
|
||||
self.attn = MMEncoderAttention(
|
||||
- self.num_heads_per_partition, self.head_dim, self.scale
+ self.num_heads_per_partition,
+ self.head_dim,
+ self.scale,
+ prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -212,7 +212,10 @@ class InternParallelAttention(nn.Module):
|
||||
)
|
||||
|
||||
self.attn = MMEncoderAttention(
|
||||
- self.num_heads_per_partition, self.head_dim, self.scale
+ self.num_heads_per_partition,
+ self.head_dim,
+ self.scale,
+ prefix=prefix,
|
||||
)
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor):
|
||||
|
||||
@@ -170,6 +170,7 @@ class InternSdpaAttention(nn.Module):
|
||||
config: PretrainedConfig,
|
||||
*,
|
||||
num_dummy_heads: int = 0,
|
||||
+ prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -215,7 +216,12 @@ class InternSdpaAttention(nn.Module):
|
||||
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
|
||||
|
||||
# Use unified MMEncoderAttention with automatic backend selection
|
||||
- self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
+ self.attn = MMEncoderAttention(
+     self.num_heads,
+     self.head_dim,
+     self.scale,
+     prefix=prefix,
+ )
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""x shape: (B, N, C)"""
|
||||
@@ -313,7 +319,11 @@ class InternS1VisionLayer(nn.Module):
|
||||
num_dummy_heads: int,
|
||||
prefix: str = "",
|
||||
):
|
||||
- return InternSdpaAttention(config, num_dummy_heads=num_dummy_heads)
+ return InternSdpaAttention(
+     config,
+     num_dummy_heads=num_dummy_heads,
+     prefix=prefix,
+ )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -254,7 +254,10 @@ class Llama4VisionAttention(nn.Module):
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.attn = MMEncoderAttention(
|
||||
- self.num_local_heads, self.head_dim, self.scaling
+ self.num_local_heads,
+ self.head_dim,
+ self.scaling,
+ prefix=prefix,
|
||||
)
|
||||
|
||||
if use_data_parallel:
|
||||
|
||||
@@ -231,7 +231,11 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attn = MMEncoderAttention(
|
||||
- self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
+ self.num_heads,
+ self.head_dim,
+ self.scale,
+ num_kv_heads=self.num_kv_heads,
+ prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -611,6 +611,7 @@ class ImagePoolingAttention(nn.Module):
|
||||
self.head_dim,
|
||||
self.scale,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
+ prefix=prefix,
|
||||
)
|
||||
|
||||
def forward_sdpa(
|
||||
|
||||
@@ -345,6 +345,7 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
+ prefix=prefix,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
@@ -319,6 +319,7 @@ class Qwen2VisionAttention(nn.Module):
|
||||
num_heads=self.num_attention_heads_per_partition,
|
||||
head_size=self.hidden_size_per_attention_head,
|
||||
scale=self.hidden_size_per_attention_head**-0.5,
|
||||
+ prefix=prefix,
|
||||
)
|
||||
|
||||
self.apply_rotary_emb = ApplyRotaryEmb(enforce_enable=True)
|
||||
|
||||
@@ -194,6 +194,7 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.head_dim,
|
||||
scale=self.scaling,
|
||||
+ prefix=prefix,
|
||||
)
|
||||
|
||||
def forward(
|
||||
|
||||
@@ -759,7 +759,12 @@ class Step3VisionAttention(nn.Module):
|
||||
)
|
||||
|
||||
# Use unified MMEncoderAttention with automatic backend selection
|
||||
- self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
+ self.attn = MMEncoderAttention(
+     self.num_heads,
+     self.head_dim,
+     self.scale,
+     prefix=prefix,
+ )
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -220,7 +220,12 @@ class PerceptionEncoderVisionAttention(nn.Module):
|
||||
prefix=f"{prefix}.out_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
- self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
+ self.attn = MMEncoderAttention(
+     self.num_heads,
+     self.head_dim,
+     self.scale,
+     prefix=prefix,
+ )
|
||||
self.rope = PerceptionEncoderRope2D(
|
||||
dim=self.head_dim,
|
||||
max_grid_height=max_grid_height,
|
||||
|
||||
Reference in New Issue
Block a user