[Test] Only Run MLA model when user explicitly set for batch invariance (#37719)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-22 09:09:12 -04:00
committed by GitHub
parent 77d24c4bfe
commit eaf4978621
3 changed files with 23 additions and 27 deletions

View File

@@ -8,10 +8,10 @@ import pytest
import torch
from utils import (
BACKENDS,
TEST_MODEL,
_extract_step_logprobs,
_random_prompt,
is_device_capability_below_90,
resolve_model_name,
skip_unsupported,
)
@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
attention_config = {"backend": backend}
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model = resolve_model_name(backend)
model = TEST_MODEL
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
):
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic
@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
model=TEST_MODEL,
tensor_parallel_size=tp_size,
max_num_seqs=128,
max_model_len=8192,
@@ -395,7 +394,7 @@ def test_simple_generation(backend):
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
model = resolve_model_name(backend)
model = TEST_MODEL
llm = LLM(
model=model,
@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail(
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}")
@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
model=TEST_MODEL,
tensor_parallel_size=tp_size,
max_num_seqs=32,
max_model_len=8192,
@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs(
"""
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (
@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
model=TEST_MODEL,
tensor_parallel_size=tp_size,
max_num_seqs=32,
max_model_len=8192,

View File

@@ -17,7 +17,7 @@ from typing import Any
import openai
import pytest
from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
from utils import BACKENDS, TEST_MODEL, _random_prompt, skip_unsupported
from tests.utils import RemoteOpenAIServer
@@ -139,7 +139,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
backend: str,
) -> None:
random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
model_name = resolve_model_name(backend)
prompts_all = [_random_prompt(10, 50) for _ in range(32)]
sp_kwargs: dict[str, Any] = {
@@ -159,11 +158,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
if tp_size:
server_args += ["-tp", tp_size]
with RemoteOpenAIServer(model_name, server_args) as server:
with RemoteOpenAIServer(TEST_MODEL, server_args) as server:
client = server.get_client()
_compare_bs1_vs_bsn_single_process(
prompts=prompts_all,
sp_kwargs=sp_kwargs,
client=client,
model_name=model_name,
model_name=TEST_MODEL,
)

View File

@@ -7,6 +7,10 @@ import pytest
import torch
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.transformers_utils.model_arch_config_convertor import (
ModelArchConfigConvertorBase,
)
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
skip_unsupported = pytest.mark.skipif(
@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif(
reason="Requires CUDA and >= Ampere (SM80)",
)
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
TEST_MODEL = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
BACKENDS: list[str] = [
"FLASH_ATTN",
"TRITON_ATTN",
"TRITON_MLA",
]
# FlashInfer temporarily disabled due to invariant CTA sizes.
@@ -27,19 +33,13 @@ BACKENDS: list[str] = [
# if has_flashinfer():
# BACKENDS.append("FLASHINFER")
if flash_attn_supports_mla():
BACKENDS.append("FLASH_ATTN_MLA")
DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
def resolve_model_name(backend: str) -> str:
"""Resolve the model name for the given backend."""
model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
if backend.endswith("MLA") and model == DEFAULT_MODEL:
return MLA_MODEL
return model
# only run MLA backends when the requested test model is itself an MLA model.
if os.getenv("VLLM_TEST_MODEL"):
config = get_config(TEST_MODEL, trust_remote_code=False)
if ModelArchConfigConvertorBase(config, config.get_text_config()).is_deepseek_mla():
BACKENDS = ["TRITON_MLA"]
if flash_attn_supports_mla():
BACKENDS.append("FLASH_ATTN_MLA")
def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str: