[Attention][Perf][Kernel] Replace torch.cat with a vectorized CUDA kernel for the MLA query concat - DeepSeek-V3.2 (#34917)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>

commit 580864d81e
parent 2b28b9b269
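
On the hot path, torch.cat((ql_nope, q_pe), dim=-1) allocates a fresh query tensor on every forward call. The new ops.concat_mla_q kernel instead writes both halves into a preallocated workspace buffer in a single vectorized pass. A minimal PyTorch sketch of the intended semantics (an illustration of what the kernel computes, not the CUDA implementation itself):

    import torch

    def concat_mla_q_reference(
        ql_nope: torch.Tensor,  # (num_tokens, num_heads, nope_dim)
        q_pe: torch.Tensor,     # (num_tokens, num_heads, rope_dim)
        out: torch.Tensor,      # (num_tokens, num_heads, nope_dim + rope_dim)
    ) -> None:
        # Same result as torch.cat((ql_nope, q_pe), dim=-1), but written into
        # a caller-provided buffer instead of a newly allocated tensor.
        nope_dim = ql_nope.shape[-1]
        out[..., :nope_dim].copy_(ql_nope)
        out[..., nope_dim:].copy_(q_pe)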
@@ -568,6 +568,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
         )
         self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
 
+        vllm_config = get_current_vllm_config()
+        max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        q_concat_shape = (max_tokens, num_heads, head_size)
         if kv_cache_dtype.startswith("fp8"):
             assert kv_cache_dtype == "fp8_ds_mla", (
                 "FlashMLA Sparse Attention backend fp8 only supports "
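
The concat buffer is sized once for the worst case: max_num_batched_tokens bounds how many query tokens any single step can carry, so one allocation covers every batch and each forward pass only slices it. A sketch of the pattern, assuming a plain torch.empty in place of the workspace manager:

    # One-time, worst-case allocation (illustrative; the real code goes
    # through the workspace manager rather than torch.empty):
    q_concat_buffer = torch.empty(
        (max_tokens, num_heads, head_size), dtype=torch.bfloat16, device="cuda"
    )

    # Per step: view only the rows actually needed -- no allocation, no copy.
    q = q_concat_buffer[:num_actual_tokens]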
@@ -576,17 +579,21 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
 
         if kv_cache_dtype == "fp8_ds_mla":
             # Reserve workspace during initialization
-            vllm_config = get_current_vllm_config()
             assert vllm_config is not None and vllm_config.model_config is not None
             prefill_workspace_size = get_prefill_workspace_size(
                 vllm_config.model_config.max_model_len
             )
             self.prefill_workspace_shape = (prefill_workspace_size, head_size)
-            (self.prefill_bf16_workspace,) = (
+            self.q_concat_buffer, self.prefill_bf16_workspace = (
                 current_workspace_manager().get_simultaneous(
-                    (self.prefill_workspace_shape, torch.bfloat16)
+                    (q_concat_shape, torch.bfloat16),
+                    (self.prefill_workspace_shape, torch.bfloat16),
                 )
             )
+        else:
+            (self.q_concat_buffer,) = current_workspace_manager().get_simultaneous(
+                (q_concat_shape, torch.bfloat16),
+            )
 
     def _forward_bf16_kv(
         self,
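
Both buffers are requested through a single get_simultaneous call, so they come out of one workspace reservation rather than separate allocations. The manager's internals are not shown in this diff; a hypothetical toy sketch of the call pattern the code above relies on (everything here except the (shape, dtype) request tuples is invented for illustration):

    import math
    import torch

    class ToyWorkspaceManager:
        # Hypothetical stand-in for the workspace manager, illustrating only
        # the get_simultaneous() call shape used in the diff above.
        def get_simultaneous(self, *requests):
            # Back all requested tensors with one flat byte buffer so they
            # can coexist without a separate allocation per request.
            nbytes = [math.prod(s) * d.itemsize for s, d in requests]
            backing = torch.empty(sum(nbytes), dtype=torch.uint8)
            views, offset = [], 0
            for (shape, dtype), n in zip(requests, nbytes):
                views.append(backing[offset : offset + n].view(dtype).view(shape))
                offset += n
            return tuple(views)

    # Usage mirroring the diff: one reservation, two tensors back.
    q_buf, prefill_buf = ToyWorkspaceManager().get_simultaneous(
        ((8, 16, 576), torch.bfloat16), ((1024, 576), torch.bfloat16)
    )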
@@ -828,7 +835,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
 
         # Concatenate q if it's a tuple (ql_nope, q_pe)
         if isinstance(q, tuple):
-            q = torch.cat(q, dim=-1)
+            ql_nope, q_pe = q
+            q = self.q_concat_buffer[: ql_nope.shape[0]]
+            ops.concat_mla_q(ql_nope, q_pe, q)
 
         num_actual_toks = q.shape[0]
 
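
The kernel must be interchangeable with the torch.cat it replaces. A quick sanity-check sketch, assuming a CUDA build of vLLM where ops resolves to vllm._custom_ops, and using illustrative head dimensions:

    import torch
    from vllm import _custom_ops as ops  # assumes a CUDA build of vLLM

    num_tokens, num_heads = 16, 128  # illustrative sizes
    nope_dim, rope_dim = 512, 64     # illustrative MLA head dims

    ql_nope = torch.randn(num_tokens, num_heads, nope_dim,
                          dtype=torch.bfloat16, device="cuda")
    q_pe = torch.randn(num_tokens, num_heads, rope_dim,
                       dtype=torch.bfloat16, device="cuda")
    out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim,
                      dtype=torch.bfloat16, device="cuda")

    # The fused kernel should match the old out-of-place concat exactly.
    ops.concat_mla_q(ql_nope, q_pe, out)
    torch.testing.assert_close(out, torch.cat((ql_nope, q_pe), dim=-1))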