Implement zero-copy GQA for multimodal and CPU (#33732)

Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>
Author: Taeksang Kim
Date:   2026-02-05 05:11:39 +09:00
Commit: 6e98f6d8b6 (parent 2f6d17cb2f)
4 changed files with 18 additions and 32 deletions

@@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module):
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
-        if self.num_heads != self.num_kv_heads:
-            key = torch.repeat_interleave(
-                key,
-                self.num_heads // self.num_kv_heads,
-                dim=2,
-            )
-            value = torch.repeat_interleave(
-                value,
-                self.num_heads // self.num_kv_heads,
-                dim=2,
-            )
         query, key, value = (x.transpose(1, 2) for x in (query, key, value))
         out = F.scaled_dot_product_attention(
@@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module):
             value,
             attn_mask=attn_mask,
             is_causal=False,
+            enable_gqa=self.num_heads > self.num_kv_heads,
         ).transpose(1, 2)
         return out.reshape(bsz, q_len, -1)
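
The change leans on the enable_gqa flag that PyTorch's scaled_dot_product_attention has accepted since 2.5: rather than materializing num_heads // num_kv_heads copies of the key and value tensors with torch.repeat_interleave, SDPA broadcasts the KV heads inside the kernel, hence "zero-copy". A minimal sketch of the equivalence follows; the shapes and names are illustrative, not from the patch. Note that the tensors here are already in (batch, heads, seq, head_dim) layout, so the repeat runs along dim=1, whereas the removed code repeated along dim=2 before its transpose(1, 2).

import torch
import torch.nn.functional as F

# Illustrative shapes; requires PyTorch >= 2.5 for the enable_gqa flag.
bsz, q_len, kv_len = 2, 16, 16
num_heads, num_kv_heads, head_dim = 8, 2, 64
n_rep = num_heads // num_kv_heads

# Tensors already transposed to (batch, heads, seq, head_dim),
# the layout SDPA sees in the patched code.
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_kv_heads, kv_len, head_dim)
value = torch.randn(bsz, num_kv_heads, kv_len, head_dim)

# Old path: materialize n_rep copies of each KV head.
out_copy = F.scaled_dot_product_attention(
    query,
    torch.repeat_interleave(key, n_rep, dim=1),
    torch.repeat_interleave(value, n_rep, dim=1),
    is_causal=False,
)

# New path: SDPA broadcasts the KV heads internally, no copies made.
out_gqa = F.scaled_dot_product_attention(
    query, key, value, is_causal=False, enable_gqa=True
)

torch.testing.assert_close(out_copy, out_gqa)

Guarding the flag with self.num_heads > self.num_kv_heads keeps standard multi-head attention (equal head counts) on the default code path.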