[perf] Integrate flashinfer concat_mla_k (#31171)

Author: jiahanc
Date: 2026-02-05 18:23:11 +08:00
Committed by: GitHub
Parent: 8322d4e47f
Commit: 59a5cb387a
2 changed files with 64 additions and 3 deletions


@@ -1885,6 +1885,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.indexer = indexer
         self.q_pad_num_heads = q_pad_num_heads
+        # Use flashinfer's optimized concat_mla_k kernel when available.
+        # The kernel is optimized for DeepSeek V3 dimensions:
+        # num_heads=128, nope_dim=128, rope_dim=64
+        self._use_flashinfer_concat_mla_k = (
+            has_flashinfer()
+            and (self.num_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        )
         if use_trtllm_ragged_deepseek_prefill():
             logger.info_once(
                 "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
@@ -2192,9 +2202,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             dtype=k_nope.dtype,
             device=k_nope.device,
         )
-        # Direct copies with efficient broadcasting
-        k[..., : k_nope.shape[-1]] = k_nope
-        k[..., k_nope.shape[-1] :] = k_pe
+        if self._use_flashinfer_concat_mla_k:
+            torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
+        else:
+            # Fallback: Direct copies with efficient broadcasting
+            k[..., : k_nope.shape[-1]] = k_nope
+            k[..., k_nope.shape[-1] :] = k_pe
         return k

     def _compute_prefill_context(
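
For reference, below is a minimal standalone sketch of what the fallback path computes; it is not vLLM's code, and the tensor shapes are assumptions based on the DeepSeek V3 dimensions cited in the diff (k_pe shared across heads so it broadcasts along the head axis). The guarded torch.ops.vllm.flashinfer_concat_mla_k call fuses these same two copies into a single kernel.

import torch

num_tokens, num_heads = 4, 128   # assumed example sizes
nope_dim, rope_dim = 128, 64     # DeepSeek V3 head dims cited in the diff

k_nope = torch.randn(num_tokens, num_heads, nope_dim)  # per-head "nope" part
k_pe = torch.randn(num_tokens, 1, rope_dim)            # rope part shared across heads

# Pre-allocate the output, matching the torch.empty(...) allocation in the diff.
k = torch.empty(
    num_tokens, num_heads, nope_dim + rope_dim,
    dtype=k_nope.dtype, device=k_nope.device,
)

# The two copies the fallback path performs; the second one broadcasts k_pe
# over the head axis.
k[..., :nope_dim] = k_nope
k[..., nope_dim:] = k_pe

# When self._use_flashinfer_concat_mla_k is True, the diff replaces the two
# assignments above with one fused call:
#     torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)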