[BugFix] Allow qk_nope_head_dim=192 in FlashInfer MLA backend checks (#37475)

Signed-off-by: Kaihang Jiang <kaihangj@nvidia.com>
This commit is contained in:
Kaihang Jiang
2026-03-20 18:14:55 -04:00
committed by GitHub
parent b3d0b37908
commit e5ed6c6c13
2 changed files with 8 additions and 8 deletions

View File

@@ -77,17 +77,17 @@ class FlashInferMLABackend(MLACommonBackend):
         use_sparse: bool,
         device_capability: DeviceCapability,
     ) -> str | None:
-        # FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128]
+        # FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128, 192]
         from vllm.config import get_current_vllm_config
         vllm_config = get_current_vllm_config()
         if vllm_config.model_config is not None:
             hf_text_config = vllm_config.model_config.hf_text_config
             qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
-            if qk_nope_head_dim not in [64, 128]:
+            if qk_nope_head_dim not in [64, 128, 192]:
                 return (
-                    f"FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128], "
-                    f"but got {qk_nope_head_dim}"
+                    "FlashInfer MLA kernel requires qk_nope_head_dim "
+                    f"in [64, 128, 192], but got {qk_nope_head_dim}"
                 )
         return None

View File

@@ -113,17 +113,17 @@ class FlashInferMLASparseBackend(AttentionBackend):
         use_sparse: bool,
         device_capability: DeviceCapability,
     ) -> str | None:
-        # FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128
+        # FlashInfer MLA sparse kernel requires qk_nope_head_dim in [128, 192]
         from vllm.config import get_current_vllm_config
         vllm_config = get_current_vllm_config()
         if vllm_config.model_config is not None:
             hf_text_config = vllm_config.model_config.hf_text_config
             qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
-            if qk_nope_head_dim != 128:
+            if qk_nope_head_dim not in [128, 192]:
                 return (
-                    f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, "
-                    f"but got {qk_nope_head_dim}"
+                    "FlashInfer MLA Sparse kernel requires qk_nope_head_dim "
+                    f"in [128, 192], but got {qk_nope_head_dim}"
                 )
         # Check for index_topk which indicates sparse model
         if not hasattr(hf_text_config, "index_topk"):