[Feat][Bugfix] Enable additional dimension for Flashinfer MLA and fix routing dtype (#36931)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
This commit is contained in:
Dimitrios Bariamis
2026-03-14 00:41:16 +01:00
committed by GitHub
parent 6d53efd2a5
commit 367cf5cd3e
2 changed files with 18 additions and 5 deletions

View File

@@ -75,16 +75,16 @@ class FlashInferMLABackend(MLACommonBackend):
use_sparse: bool,
device_capability: DeviceCapability,
) -> str | None:
# FlashInfer MLA kernel requires qk_nope_head_dim == 128
# FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128]
from vllm.config import get_current_vllm_config
vllm_config = get_current_vllm_config()
if vllm_config.model_config is not None:
hf_text_config = vllm_config.model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if qk_nope_head_dim != 128:
if qk_nope_head_dim not in [64, 128]:
return (
f"FlashInfer MLA kernel requires qk_nope_head_dim == 128, "
f"FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128], "
f"but got {qk_nope_head_dim}"
)
return None