[perf] Integrate flashinfer concat_mla_k (#31171)

This commit is contained in:
jiahanc
2026-02-05 18:23:11 +08:00
committed by GitHub
parent 8322d4e47f
commit 59a5cb387a
2 changed files with 64 additions and 3 deletions

View File

@@ -1885,6 +1885,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.indexer = indexer self.indexer = indexer
self.q_pad_num_heads = q_pad_num_heads self.q_pad_num_heads = q_pad_num_heads
# Use flashinfer's optimized concat_mla_k kernel when available.
# The kernel is optimized for DeepSeek V3 dimensions:
# num_heads=128, nope_dim=128, rope_dim=64
self._use_flashinfer_concat_mla_k = (
has_flashinfer()
and (self.num_heads == 128)
and (self.qk_nope_head_dim == 128)
and (self.qk_rope_head_dim == 64)
)
if use_trtllm_ragged_deepseek_prefill(): if use_trtllm_ragged_deepseek_prefill():
logger.info_once( logger.info_once(
"Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local" "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
@@ -2192,7 +2202,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
dtype=k_nope.dtype, dtype=k_nope.dtype,
device=k_nope.device, device=k_nope.device,
) )
# Direct copies with efficient broadcasting
if self._use_flashinfer_concat_mla_k:
torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
else:
# Fallback: Direct copies with efficient broadcasting
k[..., : k_nope.shape[-1]] = k_nope k[..., : k_nope.shape[-1]] = k_nope
k[..., k_nope.shape[-1] :] = k_pe k[..., k_nope.shape[-1] :] = k_pe
return k return k

View File

@@ -396,6 +396,53 @@ def use_trtllm_attention(
if has_flashinfer(): if has_flashinfer():
from vllm.utils.torch_utils import direct_register_custom_op
def _flashinfer_concat_mla_k(
    k: torch.Tensor,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
) -> None:
    """Fill ``k`` in-place with the concatenation of ``k_nope`` and ``k_pe``.

    Thin custom-op wrapper around flashinfer's ``concat_mla_k`` kernel,
    which is tuned for the DeepSeek V3 geometry
    (num_heads=128, nope_dim=128, rope_dim=64).

    Kernel-level optimizations (per flashinfer):
    - Warp-based processing with software pipelining
    - Vectorized memory access (int2 for nope, int for rope)
    - L2 prefetching of the next row while processing the current one
    - Rope values kept in registers and reused across all heads

    Args:
        k: Output tensor of shape [num_tokens, num_heads, nope_dim + rope_dim];
            mutated in-place, nothing is returned.
        k_nope: Nope portion, shape [num_tokens, num_heads, nope_dim].
        k_pe: Shared rope portion, shape [num_tokens, 1, rope_dim];
            broadcast across every head.
    """
    # Imported lazily so merely importing this module does not require
    # flashinfer's concat ops to be present.
    from flashinfer.concat_ops import concat_mla_k as _concat_k

    _concat_k(k, k_nope, k_pe)
def _flashinfer_concat_mla_k_fake(
k: torch.Tensor,
k_nope: torch.Tensor,
k_pe: torch.Tensor,
) -> None:
return
# Register flashinfer concat_mla_k custom op.
# NOTE(review): presumably this exposes the kernel as
# torch.ops.vllm.flashinfer_concat_mla_k (the name used at the call site),
# with the fake impl supplied for tracing — confirm against
# direct_register_custom_op's contract.
direct_register_custom_op(
    op_name="flashinfer_concat_mla_k",
    op_func=_flashinfer_concat_mla_k,
    mutates_args=["k"],  # k tensor is modified in-place
    fake_impl=_flashinfer_concat_mla_k_fake,
)
@torch.library.custom_op( @torch.library.custom_op(
"vllm::flashinfer_mm_fp4", "vllm::flashinfer_mm_fp4",