diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py
index febad3821..862f84939 100644
--- a/vllm/model_executor/layers/attention/mla_attention.py
+++ b/vllm/model_executor/layers/attention/mla_attention.py
@@ -1885,6 +1885,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.indexer = indexer
         self.q_pad_num_heads = q_pad_num_heads
+        # Use flashinfer's optimized concat_mla_k kernel when available.
+        # The kernel is optimized for DeepSeek V3 dimensions:
+        # num_heads=128, nope_dim=128, rope_dim=64
+        self._use_flashinfer_concat_mla_k = (
+            has_flashinfer()
+            and (self.num_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        )
+
         if use_trtllm_ragged_deepseek_prefill():
             logger.info_once(
                 "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
             )
@@ -2192,9 +2202,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             dtype=k_nope.dtype,
             device=k_nope.device,
         )
-        # Direct copies with efficient broadcasting
-        k[..., : k_nope.shape[-1]] = k_nope
-        k[..., k_nope.shape[-1] :] = k_pe
+
+        if self._use_flashinfer_concat_mla_k:
+            torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
+        else:
+            # Fallback: Direct copies with efficient broadcasting
+            k[..., : k_nope.shape[-1]] = k_nope
+            k[..., k_nope.shape[-1] :] = k_pe
         return k
 
     def _compute_prefill_context(
diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py
index f8cb1e14e..88e31718a 100644
--- a/vllm/utils/flashinfer.py
+++ b/vllm/utils/flashinfer.py
@@ -396,6 +396,53 @@ def use_trtllm_attention(
 
 if has_flashinfer():
+    from vllm.utils.torch_utils import direct_register_custom_op
+
+    def _flashinfer_concat_mla_k(
+        k: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> None:
+        """Custom op wrapper for flashinfer's concat_mla_k.
+
+        This is an in-place operation that concatenates k_nope and k_pe into k.
+
+        The kernel is optimized for DeepSeek V3 dimensions:
+        - num_heads=128
+        - nope_dim=128
+        - rope_dim=64
+
+        Key optimizations:
+        - Warp-based processing with software pipelining
+        - Vectorized memory access (int2 for nope, int for rope)
+        - L2 prefetching for next row while processing current
+        - Register reuse for rope values across all heads
+
+        Args:
+            k: Output tensor, shape [num_tokens, num_heads, nope_dim + rope_dim].
+                Modified in-place.
+            k_nope: The nope part of k, shape [num_tokens, num_heads, nope_dim].
+            k_pe: The rope part of k (shared), shape [num_tokens, 1, rope_dim].
+                This is broadcast to all heads.
+        """
+        from flashinfer.concat_ops import concat_mla_k
+
+        concat_mla_k(k, k_nope, k_pe)
+
+    def _flashinfer_concat_mla_k_fake(
+        k: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> None:
+        return
+
+    # Register flashinfer concat_mla_k custom op
+    direct_register_custom_op(
+        op_name="flashinfer_concat_mla_k",
+        op_func=_flashinfer_concat_mla_k,
+        mutates_args=["k"],  # k tensor is modified in-place
+        fake_impl=_flashinfer_concat_mla_k_fake,
+    )
 
     @torch.library.custom_op(
         "vllm::flashinfer_mm_fp4",
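
Reviewer note: below is a minimal CPU-only sketch of the semantics the new op must preserve, for checking the fast path against the fallback. The shapes follow the docstring above (DeepSeek V3: num_heads=128, nope_dim=128, rope_dim=64); the tensor names and sizes are illustrative only, and exercising the real kernel additionally requires flashinfer and a CUDA device.

import torch

# Illustrative DeepSeek V3 shapes, matching the docstring in the diff.
num_tokens, num_heads, nope_dim, rope_dim = 4, 128, 128, 64
k_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=torch.bfloat16)
k_pe = torch.randn(num_tokens, 1, rope_dim, dtype=torch.bfloat16)

# Fallback path from the first hunk: direct copies with broadcasting.
k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, dtype=k_nope.dtype)
k[..., :nope_dim] = k_nope
k[..., nope_dim:] = k_pe  # k_pe broadcasts across all 128 heads

# Equivalent formulation: expand the shared rope part, then concatenate.
k_ref = torch.cat([k_nope, k_pe.expand(-1, num_heads, -1)], dim=-1)
assert torch.equal(k, k_ref)

# On a CUDA device with flashinfer installed, the kernel path should match
# (assumes the custom op registered in this diff is available):
#   torch.ops.vllm.flashinfer_concat_mla_k(k_gpu, k_nope.cuda(), k_pe.cuda())

One note on the registration: the fake implementation can simply return, because the op has no return value and its only observable effect is the in-place write to k; declaring mutates_args=["k"] is what lets torch.compile trace that mutation correctly instead of treating the op as pure.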