Implement zero-copy GQA for multimodal and CPU (#33732)
Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>
This commit is contained in:
@@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp):
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
|
||||
def maybe_reshape_qkv_to_4d(
|
||||
def view_qkv_to_4d(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp):
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (num_repeat := self.num_queries_per_kv) > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
return query, key, value
|
||||
|
||||
def _forward_sdpa(
|
||||
@@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp):
|
||||
kv_len = key.size(1)
|
||||
is_reshaped = query.dim() != 4
|
||||
|
||||
query, key, value = self.maybe_reshape_qkv_to_4d(
|
||||
query, key, value, bsz, q_len, kv_len
|
||||
)
|
||||
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
output = vit_torch_sdpa_wrapper(
|
||||
q=query,
|
||||
@@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp):
|
||||
v=value,
|
||||
scale=self.scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
enable_gqa=self.num_heads > self.num_kv_heads,
|
||||
)
|
||||
if is_reshaped:
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
@@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp):
|
||||
kv_len = key.size(1)
|
||||
is_reshaped = query.dim() != 4
|
||||
|
||||
query, key, value = self.maybe_reshape_qkv_to_4d(
|
||||
query, key, value, bsz, q_len, kv_len
|
||||
)
|
||||
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
output = vit_flash_attn_wrapper(
|
||||
q=query,
|
||||
|
||||
Reference in New Issue
Block a user