Implement zero-copy GQA for multimodal and CPU (#33732)
Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>
@@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module):
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
 
-        if self.num_heads != self.num_kv_heads:
-            key = torch.repeat_interleave(
-                key,
-                self.num_heads // self.num_kv_heads,
-                dim=2,
-            )
-            value = torch.repeat_interleave(
-                value,
-                self.num_heads // self.num_kv_heads,
-                dim=2,
-            )
-
         query, key, value = (x.transpose(1, 2) for x in (query, key, value))
 
         out = F.scaled_dot_product_attention(
@@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module):
             value,
             attn_mask=attn_mask,
             is_causal=False,
+            enable_gqa=self.num_heads > self.num_kv_heads,
         ).transpose(1, 2)
 
         return out.reshape(bsz, q_len, -1)
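For context, a minimal standalone sketch (not part of the commit) of the equivalence this change relies on: since PyTorch 2.5, F.scaled_dot_product_attention accepts enable_gqa=True and broadcasts the smaller set of KV heads across the query-head groups internally, so the expanded key/value copies that repeat_interleave materialized are no longer needed. The sizes below are hypothetical, and the heads axis is dim=1 here because the tensors are already transposed, whereas the diff repeats along dim=2 before transposing.

import torch
import torch.nn.functional as F

# Hypothetical sizes: 4 query heads share each KV head (GQA group size 4).
bsz, q_len, kv_len = 2, 16, 16
num_heads, num_kv_heads, head_dim = 8, 2, 64

# (bsz, heads, seq_len, head_dim), i.e. already transposed as in the diff.
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_kv_heads, kv_len, head_dim)
value = torch.randn(bsz, num_kv_heads, kv_len, head_dim)

# Old path: materialize num_heads // num_kv_heads copies of every KV head.
key_rep = torch.repeat_interleave(key, num_heads // num_kv_heads, dim=1)
value_rep = torch.repeat_interleave(value, num_heads // num_kv_heads, dim=1)
out_copy = F.scaled_dot_product_attention(query, key_rep, value_rep, is_causal=False)

# New path: SDPA handles the grouped heads itself, with no KV copies.
out_zero_copy = F.scaled_dot_product_attention(
    query, key, value, is_causal=False, enable_gqa=num_heads > num_kv_heads
)

torch.testing.assert_close(out_copy, out_zero_copy)

The practical gain, presumably the "zero-copy" in the title, is that the KV tensors are never expanded by the group factor in memory, which matters most on the CPU and multimodal paths this commit targets.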