Feature/vit attention unification #23880 (#23978)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -255,6 +255,10 @@ class InternSdpaAttention(nn.Module):
|
||||
|
||||
self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
|
||||
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
self.attn = MultiHeadAttention(self.num_heads, self.head_dim,
|
||||
self.scale)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x)
|
||||
@@ -268,12 +272,9 @@ class InternSdpaAttention(nn.Module):
|
||||
B_, N_, H_, D_ = q.shape
|
||||
q = self.q_norm(q.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||
k = self.k_norm(k.flatten(-2, -1)).view(B_, N_, H_, D_)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
x = x.transpose(1, 2).reshape(B, N, -1)
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
x = self.attn(q, k, v)
|
||||
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user