[Test] Only Run MLA model when user explicitly set for batch invariance (#37719)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
@@ -7,6 +7,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.transformers_utils.model_arch_config_convertor import (
|
||||
ModelArchConfigConvertorBase,
|
||||
)
|
||||
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
|
||||
|
||||
skip_unsupported = pytest.mark.skipif(
|
||||
@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif(
|
||||
reason="Requires CUDA and >= Ampere (SM80)",
|
||||
)
|
||||
|
||||
# Default (non-MLA) model used when the user does not override via env var.
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
# User-requested model override; falls back to the default above when unset.
TEST_MODEL = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
|
||||
|
||||
# Attention backends exercised by these tests; MLA backends may be appended
# below, and the whole list may be narrowed to MLA-only for MLA test models.
BACKENDS: list[str] = [
    "FLASH_ATTN",
    "TRITON_ATTN",
    "TRITON_MLA",
]
|
||||
|
||||
# FlashInfer temporarily disabled due to invariant CTA sizes.
|
||||
@@ -27,19 +33,13 @@ BACKENDS: list[str] = [
|
||||
# if has_flashinfer():
|
||||
# BACKENDS.append("FLASHINFER")
|
||||
|
||||
# FLASH_ATTN_MLA is only added when the installed flash-attn build supports
# MLA on the current platform (checked via the project helper).
if flash_attn_supports_mla():
    BACKENDS.append("FLASH_ATTN_MLA")
|
||||
|
||||
# Model constants: DEFAULT_MODEL restates the module default; MLA_MODEL is the
# DeepSeek MLA-architecture model substituted for MLA attention backends when
# the user has not explicitly chosen a model.
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"


def resolve_model_name(backend: str) -> str:
    """Resolve the model name for the given backend.

    When no explicit VLLM_TEST_MODEL override is set, backends whose name
    ends in "MLA" get the dedicated MLA test model instead of the default
    (presumably because the default model is not an MLA architecture).
    An explicit user override is always honored as-is.
    """
    requested = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
    # Substitute only when the user did not ask for a specific model.
    needs_mla_default = backend.endswith("MLA") and requested == DEFAULT_MODEL
    return MLA_MODEL if needs_mla_default else requested
|
||||
# Only run MLA backends when the requested test model is itself an MLA model.
# If the user explicitly set VLLM_TEST_MODEL and that model's config resolves
# to a DeepSeek-style MLA architecture, non-MLA backends are dropped entirely.
if os.getenv("VLLM_TEST_MODEL"):
    # trust_remote_code=False: refuse to execute model-repo code while probing.
    config = get_config(TEST_MODEL, trust_remote_code=False)
    if ModelArchConfigConvertorBase(config, config.get_text_config()).is_deepseek_mla():
        # MLA model requested: restrict to the MLA-capable backends only.
        BACKENDS = ["TRITON_MLA"]
        if flash_attn_supports_mla():
            BACKENDS.append("FLASH_ATTN_MLA")
|
||||
|
||||
|
||||
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
|
||||
|
||||
Reference in New Issue
Block a user