Feature/vit attention unification #23880 (#23978)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: baonudesifeizhai
Date: 2025-09-10 09:10:14 -04:00
Committed by: GitHub
Parent: 72d30108a0
Commit: 6cbd41909e
9 changed files with 68 additions and 56 deletions


@@ -255,6 +255,10 @@ class InternSdpaAttention(nn.Module):
         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
+        # Use unified MultiHeadAttention with automatic backend selection
+        self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
+                                       self.scale)
+
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x)
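
For reference, the layer added above is constructed once in __init__ and reused in forward. A minimal usage sketch follows; the import path vllm.attention.layer and the positional argument order (number of heads, per-head dim, softmax scale) are assumptions inferred from this hunk, not a definitive API reference.

# Assumed import path; the actual import statement is outside this hunk.
from vllm.attention.layer import MultiHeadAttention

# Mirrors the diff: heads, per-head dim, and softmax scale, passed positionally.
# The backend (FlashAttention, xFormers, torch SDPA, ...) is chosen inside the
# layer rather than by the ViT module.
attn = MultiHeadAttention(16, 64, 64 ** -0.5)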
@@ -268,12 +272,9 @@ class InternSdpaAttention(nn.Module):
             B_, N_, H_, D_ = q.shape
             q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
             k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
-        q = q.transpose(1, 2)
-        k = k.transpose(1, 2)
-        v = v.transpose(1, 2)
-        x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-        x = x.transpose(1, 2).reshape(B, N, -1)
+        # Use unified MultiHeadAttention with automatic backend selection
+        x = self.attn(q, k, v)
         x = self.proj(x)
         return x
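
The net effect of the second hunk is that the per-model transpose / scaled_dot_product_attention / reshape boilerplate moves behind the unified layer, and the ViT code hands over q, k, v in (batch, seq, heads, head_dim) layout. Below is a self-contained sketch of that equivalence; UnifiedMHA is a hypothetical stand-in for vLLM's MultiHeadAttention, with torch SDPA playing the role of whichever backend the real layer would select.

import torch
import torch.nn as nn
import torch.nn.functional as F


class UnifiedMHA(nn.Module):
    """Hypothetical stand-in for the unified MultiHeadAttention layer."""

    def __init__(self, num_heads: int, head_dim: int, scale: float) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale

    def forward(self, q: torch.Tensor, k: torch.Tensor,
                v: torch.Tensor) -> torch.Tensor:
        B, N, _, _ = q.shape
        # The transposes the ViT module no longer writes by hand.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        # The scale= keyword requires PyTorch 2.1 or newer.
        out = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
        return out.transpose(1, 2).reshape(B, N, -1)


if __name__ == "__main__":
    B, N, H, D = 2, 197, 16, 64
    q, k, v = (torch.randn(B, N, H, D) for _ in range(3))
    scale = D ** -0.5

    # Old path: explicit transposes + SDPA + reshape, as in the removed lines.
    q_t, k_t, v_t = (t.transpose(1, 2) for t in (q, k, v))
    ref = F.scaled_dot_product_attention(q_t, k_t, v_t, scale=scale)
    ref = ref.transpose(1, 2).reshape(B, N, -1)

    # New path: a single call into the (stand-in) unified layer.
    out = UnifiedMHA(H, D, scale)(q, k, v)
    assert torch.allclose(ref, out, atol=1e-6)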