[perf] Integrate flashinfer concat_mla_k (#31171)
@@ -1885,6 +1885,16 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.indexer = indexer
         self.q_pad_num_heads = q_pad_num_heads
 
+        # Use flashinfer's optimized concat_mla_k kernel when available.
+        # The kernel is optimized for DeepSeek V3 dimensions:
+        # num_heads=128, nope_dim=128, rope_dim=64
+        self._use_flashinfer_concat_mla_k = (
+            has_flashinfer()
+            and (self.num_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        )
+
         if use_trtllm_ragged_deepseek_prefill():
             logger.info_once(
                 "Using TRT-LLM ragged DeepSeek prefill for MLA", scope="local"
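For context, here is a minimal PyTorch sketch of the semantics the fused kernel must match, assuming the DeepSeek V3 dimensions named in the gate above and assuming `k_pe` carries a single rope head that is shared (broadcast) across all attention heads, as the fallback path in the second hunk suggests:

```python
import torch

# Illustrative shapes only; num_tokens is arbitrary, the rest follow
# the DeepSeek V3 dimensions checked by the gate above.
num_tokens, num_heads, nope_dim, rope_dim = 4, 128, 128, 64

k_nope = torch.randn(num_tokens, num_heads, nope_dim)
k_pe = torch.randn(num_tokens, 1, rope_dim)  # one shared rope head (assumed)

# Reference semantics: broadcast k_pe across heads, then concatenate
# along the last dim, yielding k of shape [num_tokens, num_heads, 192].
k_ref = torch.cat([k_nope, k_pe.expand(-1, num_heads, -1)], dim=-1)
assert k_ref.shape == (num_tokens, num_heads, nope_dim + rope_dim)
```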
@@ -2192,9 +2202,13 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             dtype=k_nope.dtype,
             device=k_nope.device,
         )
-        # Direct copies with efficient broadcasting
-        k[..., : k_nope.shape[-1]] = k_nope
-        k[..., k_nope.shape[-1] :] = k_pe
+
+        if self._use_flashinfer_concat_mla_k:
+            torch.ops.vllm.flashinfer_concat_mla_k(k, k_nope, k_pe)
+        else:
+            # Fallback: Direct copies with efficient broadcasting
+            k[..., : k_nope.shape[-1]] = k_nope
+            k[..., k_nope.shape[-1] :] = k_pe
         return k
 
     def _compute_prefill_context(
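As a sanity check, here is a hedged sketch of a parity test between the fallback copies and the fused op. The call signature (the op writing its result into the preallocated `k`) is taken from the diff's call site; everything else (shapes, dtype, device) is an illustrative assumption, and the fused branch only runs in an environment with vLLM, flashinfer, and a CUDA device available:

```python
import torch

def concat_fallback(k_nope: torch.Tensor, k_pe: torch.Tensor) -> torch.Tensor:
    # Mirrors the diff's fallback: slice assignments, relying on
    # broadcasting over the head dimension for the shared rope part.
    num_tokens, num_heads, nope_dim = k_nope.shape
    k = k_nope.new_empty(num_tokens, num_heads, nope_dim + k_pe.shape[-1])
    k[..., :nope_dim] = k_nope
    k[..., nope_dim:] = k_pe  # k_pe broadcasts across all heads
    return k

k_nope = torch.randn(4, 128, 128, device="cuda", dtype=torch.bfloat16)
k_pe = torch.randn(4, 1, 64, device="cuda", dtype=torch.bfloat16)

k_ref = concat_fallback(k_nope, k_pe)
k_fused = torch.empty_like(k_ref)
torch.ops.vllm.flashinfer_concat_mla_k(k_fused, k_nope, k_pe)  # as in the diff
torch.testing.assert_close(k_fused, k_ref)
```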