[perf] Integrate flashinfer concat_mla_k (#31171)
@@ -1885,6 +1885,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.indexer = indexer
         self.q_pad_num_heads = q_pad_num_heads
 
+        # Use flashinfer's optimized concat_mla_k kernel when available.
+        # The kernel is optimized for DeepSeek V3 dimensions:
+        # num_heads=128, nope_dim=128, rope_dim=64
+        self._use_flashinfer_concat_mla_k = (
+            has_flashinfer()
+            and (self.num_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        )
+
         if use_trtllm_ragged_deepseek_prefill():
             logger.info_once(
                 "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
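For orientation, here is a minimal sketch of the tensor shapes this flag gates on, using the DeepSeek V3 dimensions named in the comment above (`num_tokens` is an arbitrary illustrative value, not from the diff):

```python
import torch

# DeepSeek V3 dimensions from the gating condition above.
num_tokens = 8  # illustrative only
num_heads, nope_dim, rope_dim = 128, 128, 64

k_nope = torch.randn(num_tokens, num_heads, nope_dim)  # per-head "nope" part
k_pe = torch.randn(num_tokens, 1, rope_dim)  # rope part, shared by all heads
k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim)  # concat target, 192 per head
```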
@@ -2192,7 +2202,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             dtype=k_nope.dtype,
             device=k_nope.device,
         )
-        # Direct copies with efficient broadcasting
-        k[..., : k_nope.shape[-1]] = k_nope
-        k[..., k_nope.shape[-1] :] = k_pe
+
+        if self._use_flashinfer_concat_mla_k:
+            torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
+        else:
+            # Fallback: Direct copies with efficient broadcasting
+            k[..., : k_nope.shape[-1]] = k_nope
+            k[..., k_nope.shape[-1] :] = k_pe
         return k
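The two branches are intended to produce identical results. A quick self-contained check of the fallback semantics against a plain `torch.cat` (the helper name `concat_mla_k_reference` is ours, not part of the diff):

```python
import torch

def concat_mla_k_reference(k, k_nope, k_pe):
    # Same semantics as the fallback branch above: copy the nope part,
    # then let broadcasting replicate the shared rope part across all heads.
    k[..., : k_nope.shape[-1]] = k_nope
    k[..., k_nope.shape[-1] :] = k_pe
    return k

num_tokens, num_heads, nope_dim, rope_dim = 4, 128, 128, 64
k_nope = torch.randn(num_tokens, num_heads, nope_dim)
k_pe = torch.randn(num_tokens, 1, rope_dim)

k = concat_mla_k_reference(
    torch.empty(num_tokens, num_heads, nope_dim + rope_dim), k_nope, k_pe
)
expected = torch.cat([k_nope, k_pe.expand(-1, num_heads, -1)], dim=-1)
assert torch.equal(k, expected)
```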
@@ -396,6 +396,53 @@ def use_trtllm_attention(
 
 if has_flashinfer():
+    from vllm.utils.torch_utils import direct_register_custom_op
+
+    def _flashinfer_concat_mla_k(
+        k: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> None:
+        """Custom op wrapper for flashinfer's concat_mla_k.
+
+        This is an in-place operation that concatenates k_nope and k_pe into k.
+
+        The kernel is optimized for DeepSeek V3 dimensions:
+        - num_heads=128
+        - nope_dim=128
+        - rope_dim=64
+
+        Key optimizations:
+        - Warp-based processing with software pipelining
+        - Vectorized memory access (int2 for nope, int for rope)
+        - L2 prefetching for next row while processing current
+        - Register reuse for rope values across all heads
+
+        Args:
+            k: Output tensor, shape [num_tokens, num_heads, nope_dim + rope_dim].
+                Modified in-place.
+            k_nope: The nope part of k, shape [num_tokens, num_heads, nope_dim].
+            k_pe: The rope part of k (shared), shape [num_tokens, 1, rope_dim].
+                This is broadcast to all heads.
+        """
+        from flashinfer.concat_ops import concat_mla_k
+
+        concat_mla_k(k, k_nope, k_pe)
+
+    def _flashinfer_concat_mla_k_fake(
+        k: torch.Tensor,
+        k_nope: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> None:
+        return
+
+    # Register flashinfer concat_mla_k custom op
+    direct_register_custom_op(
+        op_name="flashinfer_concat_mla_k",
+        op_func=_flashinfer_concat_mla_k,
+        mutates_args=["k"],  # k tensor is modified in-place
+        fake_impl=_flashinfer_concat_mla_k_fake,
+    )
 
 
 @torch.library.custom_op(
     "vllm::flashinfer_mm_fp4",
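A hedged usage sketch: once the registration above has run (flashinfer installed, so `has_flashinfer()` is true at import time), the op is reachable through the `torch.ops.vllm` namespace, which is exactly how the MLA code path above invokes it; the no-op fake impl lets `torch.compile` trace the call as an in-place mutation of `k` with no return value. The importing module path here is an assumption based on the `use_trtllm_attention` hunk context:

```python
import torch
import vllm.utils.flashinfer  # noqa: F401 - triggers the registration block (assumed module path)

num_tokens, num_heads, nope_dim, rope_dim = 4, 128, 128, 64
k = torch.empty(num_tokens, num_heads, nope_dim + rope_dim,
                dtype=torch.bfloat16, device="cuda")
k_nope = torch.randn(num_tokens, num_heads, nope_dim,
                     dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn(num_tokens, 1, rope_dim,
                   dtype=torch.bfloat16, device="cuda")

# In-place: writes k_nope plus the broadcast k_pe into k; returns None.
torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
```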