Implement zero-copy GQA for multimodal and CPU (#33732)

Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>

Author: Taeksang Kim
Date:   2026-02-05 05:11:39 +09:00
Committed by: GitHub
Parent: 2f6d17cb2f
Commit: 6e98f6d8b6

4 changed files with 18 additions and 32 deletions


@@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp):
     def enabled(cls) -> bool:
         return True
 
-    def maybe_reshape_qkv_to_4d(
+    def view_qkv_to_4d(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
@@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp):
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
-
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=2)
-            value = torch.repeat_interleave(value, num_repeat, dim=2)
         return query, key, value
 
     def _forward_sdpa(
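The removed branch handled GQA/MQA by expanding K and V eagerly: torch.repeat_interleave allocates a fresh tensor holding num_queries_per_kv copies of every KV head before attention runs. A minimal standalone sketch of that cost (illustrative shapes, not vLLM code):

import torch

bsz, kv_len, num_kv_heads, head_size = 2, 128, 4, 64
num_queries_per_kv = 4  # e.g. 16 query heads sharing 4 KV heads

key = torch.randn(bsz, kv_len, num_kv_heads, head_size)

# Eager GQA expansion: materializes a brand-new tensor, num_queries_per_kv
# times larger than `key`, just so every query head has its "own" KV head.
key_rep = torch.repeat_interleave(key, num_queries_per_kv, dim=2)

assert key_rep.shape == (bsz, kv_len, num_kv_heads * num_queries_per_kv, head_size)
assert key_rep.data_ptr() != key.data_ptr()  # a real copy, not a view
print(key_rep.nelement() // key.nelement())  # 4: four times the memory traffic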
@@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp):
         kv_len = key.size(1)
 
         is_reshaped = query.dim() != 4
-        query, key, value = self.maybe_reshape_qkv_to_4d(
-            query, key, value, bsz, q_len, kv_len
-        )
+        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
 
         output = vit_torch_sdpa_wrapper(
             q=query,
@@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp):
             v=value,
             scale=self.scale,
             cu_seqlens=cu_seqlens,
+            enable_gqa=self.num_heads > self.num_kv_heads,
         )
         if is_reshaped:
             output = output.reshape(bsz, q_len, -1)
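In place of the eager expansion, the SDPA call now passes enable_gqa whenever query heads outnumber KV heads. Since PyTorch 2.5, F.scaled_dot_product_attention accepts enable_gqa=True and broadcasts the KV heads inside the kernel, so no repeated copies are materialized. A minimal sketch of the equivalence, using SDPA's native (batch, heads, seq, head_dim) layout rather than vLLM's wrappers:

import torch
import torch.nn.functional as F

bsz, seq_len, head_size = 2, 128, 64
num_heads, num_kv_heads = 16, 4  # illustrative GQA config

q = torch.randn(bsz, num_heads, seq_len, head_size)
k = torch.randn(bsz, num_kv_heads, seq_len, head_size)
v = torch.randn(bsz, num_kv_heads, seq_len, head_size)

# Zero-copy path: the kernel broadcasts 4 KV heads across 16 query heads.
out = F.scaled_dot_product_attention(q, k, v, enable_gqa=num_heads > num_kv_heads)

# Reference path: eagerly repeat K/V along the head dim, as the removed code did.
rep = num_heads // num_kv_heads
ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(rep, dim=1),
    v.repeat_interleave(rep, dim=1),
)
torch.testing.assert_close(out, ref)

Passing the comparison num_heads > num_kv_heads rather than a literal True keeps the plain-MHA case on the unchanged default path.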
@@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp):
         kv_len = key.size(1)
 
         is_reshaped = query.dim() != 4
-        query, key, value = self.maybe_reshape_qkv_to_4d(
-            query, key, value, bsz, q_len, kv_len
-        )
+        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
 
         output = vit_flash_attn_wrapper(
             q=query,