fix: Add glm4_moe_lite to MLA detection (#32614)
Signed-off-by: marksverdhei <marksverdhei@hotmail.com> Signed-off-by: Markus / Mark <46672778+marksverdhei@users.noreply.github.com> Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com> Co-authored-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -440,18 +440,34 @@ M = TypeVar("M", bound=MLACommonMetadata)
|
||||
A = TypeVar("A", bound=AttentionMetadata)
|
||||
|
||||
|
||||
def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool:
|
||||
# Check if model has DeepSeek R1 compatible MLA dimensions:
|
||||
# qk_nope_head_dim = 128, qk_rope_head_dim = 64, v_head_dim = 128
|
||||
# which results in query/key head dim = 192.
|
||||
if vllm_config.model_config is None:
|
||||
return False
|
||||
hf_text_config = vllm_config.model_config.hf_text_config
|
||||
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
|
||||
qk_rope_head_dim = getattr(hf_text_config, "qk_rope_head_dim", 1)
|
||||
v_head_dim = getattr(hf_text_config, "v_head_dim", 1)
|
||||
return qk_nope_head_dim == 128 and qk_rope_head_dim == 64 and v_head_dim == 128
|
||||
|
||||
|
||||
def use_flashinfer_prefill() -> bool:
|
||||
# For blackwell default to flashinfer prefill if it's available since
|
||||
# it is faster than FA2.
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
return (
|
||||
if not (
|
||||
not vllm_config.attention_config.disable_flashinfer_prefill
|
||||
and flashinfer_available
|
||||
and not vllm_config.attention_config.use_cudnn_prefill
|
||||
and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
return is_deepseek_r1_mla_compatible(vllm_config)
|
||||
|
||||
|
||||
def use_cudnn_prefill() -> bool:
|
||||
@@ -471,11 +487,14 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
return (
|
||||
if not (
|
||||
flashinfer_available
|
||||
and vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill
|
||||
and current_platform.is_device_capability_family(100)
|
||||
)
|
||||
):
|
||||
return False
|
||||
|
||||
return is_deepseek_r1_mla_compatible(vllm_config)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user