[perf] Integrate flashinfer concat_mla_k (#31171)
This commit is contained in:
@@ -396,6 +396,53 @@ def use_trtllm_attention(
|
||||
|
||||
|
||||
if has_flashinfer():
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
def _flashinfer_concat_mla_k(
    k: torch.Tensor,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
) -> None:
    """Write the concatenation of ``k_nope`` and ``k_pe`` into ``k`` in place.

    Thin custom-op wrapper around flashinfer's ``concat_mla_k`` kernel.
    Nothing is returned; the result lands in ``k``.

    The kernel is tuned for DeepSeek V3 dimensions
    (num_heads=128, nope_dim=128, rope_dim=64). Its main optimizations:

    * warp-based processing with software pipelining;
    * vectorized memory access (int2 for nope, int for rope);
    * L2 prefetch of the next row while the current row is processed;
    * rope values kept in registers and reused across all heads.

    Args:
        k: Destination tensor of shape
            [num_tokens, num_heads, nope_dim + rope_dim]. Mutated in place.
        k_nope: Nope component of k, shape
            [num_tokens, num_heads, nope_dim].
        k_pe: Rope component of k, shape [num_tokens, 1, rope_dim].
            Shared by, and broadcast to, every head.
    """
    # Imported at call time rather than at module import time.
    from flashinfer.concat_ops import concat_mla_k as _concat_kernel

    _concat_kernel(k, k_nope, k_pe)
|
||||
|
||||
def _flashinfer_concat_mla_k_fake(
    k: torch.Tensor,
    k_nope: torch.Tensor,
    k_pe: torch.Tensor,
) -> None:
    """No-op stand-in for ``_flashinfer_concat_mla_k``.

    Accepts the same three tensors as the real wrapper but touches none of
    them and produces no value. The real op only mutates ``k`` in place and
    returns nothing, so there is no output to compute here.

    Args:
        k: Tensor the real op would write into; left untouched here.
        k_nope: Nope component; unused.
        k_pe: Rope component; unused.
    """
    return None
|
||||
|
||||
# Expose the flashinfer concat_mla_k wrapper as a custom op, pairing the
# real implementation with its no-op fake counterpart.
direct_register_custom_op(
    op_name="flashinfer_concat_mla_k",
    fake_impl=_flashinfer_concat_mla_k_fake,
    op_func=_flashinfer_concat_mla_k,
    # The op writes its result into ``k`` and returns nothing, so ``k`` must
    # be declared as mutated.
    mutates_args=["k"],
)
|
||||
|
||||
@torch.library.custom_op(
|
||||
"vllm::flashinfer_mm_fp4",
|
||||
|
||||
Reference in New Issue
Block a user