[Model][QwenVL] Optimize Qwen2_5_VisionAttention q,k preparation (#28769)
Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -359,23 +359,6 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
AttentionBackendEnum.ROCM_AITER_FA,
|
||||
}
|
||||
|
||||
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
|
||||
# [s, b, 3 * head * head_dim]
|
||||
seq_len, bs, _ = qkv.shape
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
|
||||
q, k, v = qkv.chunk(3, dim=2)
|
||||
|
||||
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
new_shape = (
|
||||
seq_len,
|
||||
bs,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
q, k, v = (x.view(*new_shape) for x in (q, k, v))
|
||||
return q, k, v
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@@ -386,17 +369,32 @@ class Qwen2_5_VisionAttention(nn.Module):
|
||||
) -> torch.Tensor:
|
||||
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
||||
x, _ = self.qkv(x)
|
||||
seq_len, batch_size, _ = x.shape
|
||||
|
||||
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
|
||||
q, k, v = self.split_qkv(x)
|
||||
batch_size = q.shape[1]
|
||||
qkv = einops.rearrange(
|
||||
x,
|
||||
"s b (three head head_dim) -> b s three head head_dim",
|
||||
three=3,
|
||||
head=self.num_attention_heads_per_partition,
|
||||
)
|
||||
|
||||
q, k, v = (einops.rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
|
||||
if rotary_pos_emb is not None:
|
||||
# [2 * b, s, heads, head_dim]
|
||||
qk_concat = torch.cat([q, k], dim=0)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(qk_concat, rotary_pos_emb)
|
||||
q, k = torch.chunk(qk_rotated, 2, dim=0)
|
||||
qk, v = qkv[:, :, :2], qkv[:, :, 2]
|
||||
|
||||
qk_reshaped = einops.rearrange(
|
||||
qk, "b s two head head_dim -> (two b) s head head_dim", two=2
|
||||
)
|
||||
qk_rotated = apply_rotary_pos_emb_vision(qk_reshaped, rotary_pos_emb)
|
||||
qk_rotated = qk_rotated.view(
|
||||
2,
|
||||
batch_size,
|
||||
seq_len,
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
q, k = qk_rotated.unbind(dim=0)
|
||||
else:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
|
||||
if self.is_flash_attn_backend:
|
||||
context_layer = vit_flash_attn_wrapper(
|
||||
|
||||
Reference in New Issue
Block a user