[BugFix] Allow qk_nope_head_dim=192 in FlashInfer MLA backend checks (#37475)
Signed-off-by: Kaihang Jiang <kaihangj@nvidia.com>
This commit is contained in:
@@ -77,17 +77,17 @@ class FlashInferMLABackend(MLACommonBackend):
|
|||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
device_capability: DeviceCapability,
|
device_capability: DeviceCapability,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
# FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128]
|
# FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128, 192]
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
if vllm_config.model_config is not None:
|
if vllm_config.model_config is not None:
|
||||||
hf_text_config = vllm_config.model_config.hf_text_config
|
hf_text_config = vllm_config.model_config.hf_text_config
|
||||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||||
if qk_nope_head_dim not in [64, 128]:
|
if qk_nope_head_dim not in [64, 128, 192]:
|
||||||
return (
|
return (
|
||||||
f"FlashInfer MLA kernel requires qk_nope_head_dim in [64, 128], "
|
"FlashInfer MLA kernel requires qk_nope_head_dim "
|
||||||
f"but got {qk_nope_head_dim}"
|
f"in [64, 128, 192], but got {qk_nope_head_dim}"
|
||||||
)
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@@ -113,17 +113,17 @@ class FlashInferMLASparseBackend(AttentionBackend):
|
|||||||
use_sparse: bool,
|
use_sparse: bool,
|
||||||
device_capability: DeviceCapability,
|
device_capability: DeviceCapability,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
# FlashInfer MLA sparse kernel requires qk_nope_head_dim == 128
|
# FlashInfer MLA sparse kernel requires qk_nope_head_dim in [128, 192]
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
if vllm_config.model_config is not None:
|
if vllm_config.model_config is not None:
|
||||||
hf_text_config = vllm_config.model_config.hf_text_config
|
hf_text_config = vllm_config.model_config.hf_text_config
|
||||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||||
if qk_nope_head_dim != 128:
|
if qk_nope_head_dim not in [128, 192]:
|
||||||
return (
|
return (
|
||||||
f"FlashInfer MLA Sparse kernel requires qk_nope_head_dim == 128, "
|
"FlashInfer MLA Sparse kernel requires qk_nope_head_dim "
|
||||||
f"but got {qk_nope_head_dim}"
|
f"in [128, 192], but got {qk_nope_head_dim}"
|
||||||
)
|
)
|
||||||
# Check for index_topk which indicates sparse model
|
# Check for index_topk which indicates sparse model
|
||||||
if not hasattr(hf_text_config, "index_topk"):
|
if not hasattr(hf_text_config, "index_topk"):
|
||||||
|
|||||||
Reference in New Issue
Block a user