diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla.py b/vllm/v1/attention/backends/mla/flashinfer_mla.py
index 3de0dcdd8..16d01bd33 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla.py
@@ -77,17 +77,17 @@ class FlashInferMLABackend(MLACommonBackend):
         use_sparse: bool,
         device_capability: DeviceCapability,
     ) -> str | None:
-        # FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128]
+        # FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128, 192]
         from vllm.config import get_current_vllm_config
 
         vllm_config = get_current_vllm_config()
         if vllm_config.model_config is not None:
             hf_text_config = vllm_config.model_config.hf_text_config
             qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
-            if qk_nope_head_dim not in [64, 128]:
+            if qk_nope_head_dim not in [64, 128, 192]:
                 return (
-                    f"FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128], "
-                    f"but got {qk_nope_head_dim}"
+                    "FlashInfer MLA kernel requires qk_nope_head_dim "
+                    f"in [64, 128, 192], but got {qk_nope_head_dim}"
                 )
         return None
 
diff --git a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
index 9554457b4..7b5ec0d49 100644
--- a/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
@@ -113,17 +113,17 @@ class FlashInferMLASparseBackend(AttentionBackend):
         use_sparse: bool,
         device_capability: DeviceCapability,
     ) -> str | None:
-        # FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128
+        # FlashInfer MLA sparse kernel requires qk_nope_head_dim in [128, 192]
         from vllm.config import get_current_vllm_config
 
         vllm_config = get_current_vllm_config()
         if vllm_config.model_config is not None:
             hf_text_config = vllm_config.model_config.hf_text_config
             qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
-            if qk_nope_head_dim != 128:
+            if qk_nope_head_dim not in [128, 192]:
                 return (
-                    f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, "
-                    f"but got {qk_nope_head_dim}"
+                    "FlashInfer MLA Sparse kernel requires qk_nope_head_dim "
+                    f"in [128, 192], but got {qk_nope_head_dim}"
                 )
         # Check for index_topk which indicates sparse model
         if not hasattr(hf_text_config, "index_topk"):
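
Not part of the patch: a minimal standalone sketch of the membership check the diff introduces, runnable in isolation. The helper name check_qk_nope_head_dim and the SimpleNamespace stand-in for hf_text_config are illustrative only, not vLLM API.

from types import SimpleNamespace

# Dims accepted by the dense MLA kernel after this change; the sparse
# kernel's list is [128, 192].
SUPPORTED_DIMS = [64, 128, 192]

def check_qk_nope_head_dim(hf_text_config) -> str | None:
    # Mirrors the diff: the default of 1 fails the membership test, so
    # configs lacking the attribute are rejected with a reason string.
    qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
    if qk_nope_head_dim not in SUPPORTED_DIMS:
        return (
            "FlashInfer MLA kernel requires qk_nope_head_dim "
            f"in {SUPPORTED_DIMS}, but got {qk_nope_head_dim}"
        )
    return None

assert check_qk_nope_head_dim(SimpleNamespace(qk_nope_head_dim=192)) is None
assert check_qk_nope_head_dim(SimpleNamespace(qk_nope_head_dim=96)) is not None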