Implement zero-copy GQA for multimodal and CPU (#33732)
Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>
This commit is contained in:
@@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp):
|
||||
def enabled(cls) -> bool:
|
||||
return True
|
||||
|
||||
def maybe_reshape_qkv_to_4d(
|
||||
def view_qkv_to_4d(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
@@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp):
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if (num_repeat := self.num_queries_per_kv) > 1:
|
||||
# Handle MQA and GQA
|
||||
key = torch.repeat_interleave(key, num_repeat, dim=2)
|
||||
value = torch.repeat_interleave(value, num_repeat, dim=2)
|
||||
|
||||
return query, key, value
|
||||
|
||||
def _forward_sdpa(
|
||||
@@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp):
|
||||
kv_len = key.size(1)
|
||||
is_reshaped = query.dim() != 4
|
||||
|
||||
query, key, value = self.maybe_reshape_qkv_to_4d(
|
||||
query, key, value, bsz, q_len, kv_len
|
||||
)
|
||||
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
output = vit_torch_sdpa_wrapper(
|
||||
q=query,
|
||||
@@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp):
|
||||
v=value,
|
||||
scale=self.scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
enable_gqa=self.num_heads > self.num_kv_heads,
|
||||
)
|
||||
if is_reshaped:
|
||||
output = output.reshape(bsz, q_len, -1)
|
||||
@@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp):
|
||||
kv_len = key.size(1)
|
||||
is_reshaped = query.dim() != 4
|
||||
|
||||
query, key, value = self.maybe_reshape_qkv_to_4d(
|
||||
query, key, value, bsz, q_len, kv_len
|
||||
)
|
||||
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
|
||||
|
||||
output = vit_flash_attn_wrapper(
|
||||
q=query,
|
||||
|
||||
@@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module):
|
||||
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
|
||||
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
|
||||
|
||||
if self.num_heads != self.num_kv_heads:
|
||||
key = torch.repeat_interleave(
|
||||
key,
|
||||
self.num_heads // self.num_kv_heads,
|
||||
dim=2,
|
||||
)
|
||||
value = torch.repeat_interleave(
|
||||
value,
|
||||
self.num_heads // self.num_kv_heads,
|
||||
dim=2,
|
||||
)
|
||||
|
||||
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
|
||||
|
||||
out = F.scaled_dot_product_attention(
|
||||
@@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module):
|
||||
value,
|
||||
attn_mask=attn_mask,
|
||||
is_causal=False,
|
||||
enable_gqa=self.num_heads > self.num_kv_heads,
|
||||
).transpose(1, 2)
|
||||
|
||||
return out.reshape(bsz, q_len, -1)
|
||||
|
||||
@@ -398,10 +398,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
|
||||
key = key.movedim(0, key.dim() - 2)
|
||||
value = value.movedim(0, value.dim() - 2)
|
||||
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
|
||||
|
||||
causal_attn = attn_type == AttentionType.DECODER
|
||||
|
||||
sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore
|
||||
@@ -418,6 +414,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
|
||||
dropout_p=0.0,
|
||||
is_causal=causal_attn and mask is None,
|
||||
scale=self.scale,
|
||||
enable_gqa=self.num_heads > self.num_kv_heads,
|
||||
)
|
||||
.squeeze(0)
|
||||
.movedim(query.dim() - 2, 0)
|
||||
|
||||
@@ -115,13 +115,16 @@ def apply_sdpa(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Input shape:
|
||||
(batch_size x seq_len x num_heads x head_size)
|
||||
"""
|
||||
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
|
||||
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=scale)
|
||||
output = F.scaled_dot_product_attention(
|
||||
q, k, v, dropout_p=0.0, scale=scale, enable_gqa=enable_gqa
|
||||
)
|
||||
output = einops.rearrange(output, "b h s d -> b s h d ")
|
||||
return output
|
||||
|
||||
@@ -134,6 +137,7 @@ def torch_sdpa_wrapper(
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
enable_gqa: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# Never remove the contiguous logic for ROCm
|
||||
# Without it, hallucinations occur with the backend
|
||||
@@ -143,7 +147,7 @@ def torch_sdpa_wrapper(
|
||||
v = v.contiguous()
|
||||
|
||||
if cu_seqlens is None:
|
||||
return apply_sdpa(q, k, v, scale=scale)
|
||||
return apply_sdpa(q, k, v, scale=scale, enable_gqa=enable_gqa)
|
||||
|
||||
outputs = []
|
||||
|
||||
@@ -152,7 +156,7 @@ def torch_sdpa_wrapper(
|
||||
k_chunks = torch.split(k, lens, dim=1)
|
||||
v_chunks = torch.split(v, lens, dim=1)
|
||||
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
|
||||
output_i = apply_sdpa(q_i, k_i, v_i, scale=scale)
|
||||
output_i = apply_sdpa(q_i, k_i, v_i, scale=scale, enable_gqa=enable_gqa)
|
||||
outputs.append(output_i)
|
||||
context_layer = torch.cat(outputs, dim=1)
|
||||
return context_layer
|
||||
@@ -164,6 +168,7 @@ def torch_sdpa_wrapper_fake(
|
||||
v: torch.Tensor,
|
||||
scale: float | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
enable_gqa: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(q)
|
||||
|
||||
@@ -181,5 +186,8 @@ def vit_torch_sdpa_wrapper(
|
||||
v: torch.Tensor,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
enable_gqa: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens)
|
||||
return torch.ops.vllm.torch_sdpa_wrapper(
|
||||
q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user