Feature/vit attention unification# 23880 (#23978)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -170,6 +170,7 @@ class Idefics2VisionAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
# Use unified MultiHeadAttention with Flash Attention support
|
||||
self.attn = MultiHeadAttention(self.num_heads_per_partition,
|
||||
self.head_dim, self.scale)
|
||||
|
||||
@@ -181,6 +182,8 @@ class Idefics2VisionAttention(nn.Module):
|
||||
hidden_states
|
||||
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
||||
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
||||
|
||||
# Use unified MultiHeadAttention implementation
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
return attn_output
|
||||
|
||||
Reference in New Issue
Block a user