fix: Add glm4_moe_lite to MLA detection (#32614)
Signed-off-by: marksverdhei <marksverdhei@hotmail.com> Signed-off-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -440,18 +440,34 @@ M = TypeVar("M", bound=MLACommonMetadata)
|
||||
A = TypeVar("A", bound=AttentionMetadata)
|
||||
|
||||
|
||||
def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
|
||||
# Check if model has DeepSeek R1 compatible MLA dimensions:
|
||||
# qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128
|
||||
# which results in query/key head dim = 192.
|
||||
if vllm_config.model_config is None:
|
||||
return False
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
qk_rope_head_dim = getattr(hf_text_config, "qk_rope_head_dim", 1)
|
||||
v_head_dim = getattr(hf_text_config, "v_head_dim", 1)
|
||||
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
|
||||
|
||||
|
||||
def use_flashinfer_prefill() -> bool:
|
||||
# For blackwell default to flashinfer prefill if it's available since
|
||||
# it is faster than FA2.
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
return (
|
||||
if not (
|
||||
not vllm_config.attention_config.disable_flashinfer_prefill
|
||||
and flashinfer_available
|
||||
and not vllm_config.attention_config.use_cudnn_prefill
|
||||
and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
return is_deepseek_r1_mla_compatible(vllm_config)
|
||||
|
||||
|
||||
def use_cudnn_prefill() -> bool:
|
||||
@@ -471,11 +487,14 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
return (
|
||||
if not (
|
||||
flashinfer_available
|
||||
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
|
||||
and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
return is_deepseek_r1_mla_compatible(vllm_config)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -180,19 +180,34 @@ class CudaPlatformBase(Platform):
|
||||
use_cutlass_mla = False
|
||||
use_flashinfer_mla = False
|
||||
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
if vllm_config.attention_config.backend is None:
|
||||
# Default case
|
||||
if cls.is_device_capability_family(100) and not use_sparse:
|
||||
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2).
|
||||
hf_text_config = model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
if (
|
||||
cls.is_device_capability_family(100)
|
||||
and not use_sparse
|
||||
and qk_nope_head_dim == 128
|
||||
):
|
||||
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
|
||||
# and only if qk_nope_head_dim == 128 (kernel constraint)
|
||||
use_flashinfer_mla = True
|
||||
# Set the backend in AttentionConfig so it's used during
|
||||
# backend selection
|
||||
vllm_config.attention_config.backend = (
|
||||
AttentionBackendEnum.FLASHINFER_MLA
|
||||
)
|
||||
else:
|
||||
# Not Blackwell
|
||||
elif cls.is_device_capability_family(100) and not use_sparse:
|
||||
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
|
||||
use_cutlass_mla = True
|
||||
elif is_flashmla_dense_supported()[0]:
|
||||
# Non-Blackwell with FlashMLA support
|
||||
use_flashmla = True
|
||||
else:
|
||||
# Fallback: will use Triton MLA or other compatible backend
|
||||
pass
|
||||
else:
|
||||
# Forced case
|
||||
backend = vllm_config.attention_config.backend
|
||||
@@ -200,8 +215,6 @@ class CudaPlatformBase(Platform):
|
||||
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
|
||||
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
|
||||
|
||||
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
|
||||
if (
|
||||
use_flashmla
|
||||
and is_flashmla_dense_supported()[0]
|
||||
|
||||
@@ -189,6 +189,8 @@ class ModelArchConfigConvertorBase:
|
||||
"deepseek_v3",
|
||||
"deepseek_v32",
|
||||
"deepseek_mtp",
|
||||
"glm4_moe_lite",
|
||||
"glm4_moe_lite_mtp",
|
||||
"kimi_k2",
|
||||
"kimi_linear",
|
||||
"longcat_flash",
|
||||
|
||||
@@ -63,6 +63,32 @@ class FlashInferMLABackend(MLACommonBackend):
|
||||
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
|
||||
return capability.major == 10
|
||||
|
||||
@classmethod
|
||||
def supports_combination(
|
||||
cls,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: CacheDType | None,
|
||||
block_size: int,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> str | None:
|
||||
# FlashInfer MLA kernel requires qk_nope_head_dim == 128
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
if vllm_config.model_config is not None:
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
if qk_nope_head_dim != 128:
|
||||
return (
|
||||
f"FlashInfer MLA kernel requires qk_nope_head_dim == 128, "
|
||||
f"but got {qk_nope_head_dim}"
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
|
||||
return "HND"
|
||||
|
||||
Reference in New Issue
Block a user