[Attention][Perf][Kernel] Replace torch.cat with vectorized CUDA kernel for MLA query concat - DeepSeek-V3.2 (#34917)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Author: Roberto L. Castro
Date: 2026-03-09 17:50:36 +01:00
Committed by: GitHub
Parent: 2b28b9b269
Commit: 580864d81e

10 changed files with 415 additions and 15 deletions

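In short: the hot path previously rebuilt the query tensor every step with `torch.cat((ql_nope, q_pe), dim=-1)`, which allocates a fresh tensor before copying both halves into it. The commit replaces that with a custom op, `ops.concat_mla_q`, writing directly into a workspace buffer reserved once at init. The CUDA kernel itself lives in changed files not shown on this page; below is a minimal pure-PyTorch sketch of the intended semantics (`concat_mla_q_ref` is a hypothetical stand-in, not the real op):

```python
import torch


def concat_mla_q_ref(
    ql_nope: torch.Tensor,  # [num_tokens, num_heads, d_nope]
    q_pe: torch.Tensor,     # [num_tokens, num_heads, d_rope]
    out: torch.Tensor,      # [num_tokens, num_heads, d_nope + d_rope], preallocated
) -> None:
    """Same result as torch.cat((ql_nope, q_pe), dim=-1), but written into
    caller-owned storage instead of a freshly allocated tensor."""
    d_nope = ql_nope.shape[-1]
    out[..., :d_nope] = ql_nope
    out[..., d_nope:] = q_pe
```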

@@ -568,6 +568,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
         )
         self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)
+        vllm_config = get_current_vllm_config()
+        max_tokens = vllm_config.scheduler_config.max_num_batched_tokens
+        q_concat_shape = (max_tokens, num_heads, head_size)
         if kv_cache_dtype.startswith("fp8"):
             assert kv_cache_dtype == "fp8_ds_mla", (
                 "FlashMLA Sparse Attention backend fp8 only supports "
@@ -576,17 +579,21 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
         if kv_cache_dtype == "fp8_ds_mla":
             # Reserve workspace during initialization
             vllm_config = get_current_vllm_config()
             assert vllm_config is not None and vllm_config.model_config is not None
             prefill_workspace_size = get_prefill_workspace_size(
                 vllm_config.model_config.max_model_len
             )
             self.prefill_workspace_shape = (prefill_workspace_size, head_size)
-            (self.prefill_bf16_workspace,) = (
+            self.q_concat_buffer, self.prefill_bf16_workspace = (
                 current_workspace_manager().get_simultaneous(
-                    (self.prefill_workspace_shape, torch.bfloat16)
+                    (q_concat_shape, torch.bfloat16),
+                    (self.prefill_workspace_shape, torch.bfloat16),
                 )
             )
+        else:
+            (self.q_concat_buffer,) = current_workspace_manager().get_simultaneous(
+                (q_concat_shape, torch.bfloat16),
+            )
 
     def _forward_bf16_kv(
         self,
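Allocation above goes through the existing workspace manager: in the `fp8_ds_mla` branch the new buffer is requested together with the prefill bf16 workspace, otherwise on its own. The real `get_simultaneous` is not shown in this commit; the sketch below is a hypothetical stand-in for the pattern it suggests (one reservation, several non-overlapping views), under the assumption that the requested buffers must coexist:

```python
import torch


def get_simultaneous_sketch(*requests, device="cuda"):
    """Hypothetical stand-in: carve one tensor per (shape, dtype) request out
    of a single byte arena, so every buffer coexists without aliasing and the
    reservation happens exactly once."""
    sizes = [
        torch.Size(shape).numel() * torch.empty((), dtype=dtype).element_size()
        for shape, dtype in requests
    ]
    arena = torch.empty(sum(sizes), dtype=torch.uint8, device=device)
    views, offset = [], 0
    for (shape, dtype), nbytes in zip(requests, sizes):
        views.append(arena[offset : offset + nbytes].view(dtype).view(shape))
        offset += nbytes
    return tuple(views)


# Mirrors the fp8_ds_mla call above:
# q_concat_buffer, prefill_bf16_workspace = get_simultaneous_sketch(
#     (q_concat_shape, torch.bfloat16),
#     (prefill_workspace_shape, torch.bfloat16),
# )
```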
@@ -828,7 +835,9 @@ class FlashMLASparseImpl(SparseMLAAttentionImpl[FlashMLASparseMetadata]):
         # Concatenate q if it's a tuple (ql_nope, q_pe)
         if isinstance(q, tuple):
-            q = torch.cat(q, dim=-1)
+            ql_nope, q_pe = q
+            q = self.q_concat_buffer[: ql_nope.shape[0]]
+            ops.concat_mla_q(ql_nope, q_pe, q)
 
         num_actual_toks = q.shape[0]
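At the call site the concat result is now a view into the persistent buffer: slice off the first `num_tokens` rows and let the kernel fill them in place. A quick equivalence check for the old and new paths, using plain tensor writes as a stand-in for `ops.concat_mla_q` (runs on CPU):

```python
import torch

num_tokens, num_heads, d_nope, d_rope = 4, 8, 512, 64  # small demo sizes
ql_nope = torch.randn(num_tokens, num_heads, d_nope, dtype=torch.bfloat16)
q_pe = torch.randn(num_tokens, num_heads, d_rope, dtype=torch.bfloat16)

# Persistent buffer sized for a larger max batch; the call site uses a prefix.
q_concat_buffer = torch.empty(16, num_heads, d_nope + d_rope, dtype=torch.bfloat16)

q = q_concat_buffer[:num_tokens]   # view into the workspace, no allocation
q[..., :d_nope] = ql_nope          # what concat_mla_q writes (stand-in)
q[..., d_nope:] = q_pe

assert torch.equal(q, torch.cat((ql_nope, q_pe), dim=-1))  # matches old path
```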