[Test] Only Run MLA model when user explicitly set for batch invariance (#37719)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Wentao Ye
2026-03-22 09:09:12 -04:00
committed by GitHub
parent 77d24c4bfe
commit eaf4978621
3 changed files with 23 additions and 27 deletions

View File

@@ -8,10 +8,10 @@ import pytest
import torch
from utils import (
BACKENDS,
TEST_MODEL,
_extract_step_logprobs,
_random_prompt,
is_device_capability_below_90,
resolve_model_name,
skip_unsupported,
)
@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
attention_config = {"backend": backend}
# Allow overrides from environment (useful for CI tuning)
# "facebook/opt-125m" is too small, doesn't reliably test determinism
model = resolve_model_name(backend)
model = TEST_MODEL
num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
):
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
# For batch invariance, disable custom all-reduce to ensure deterministic
@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
model=TEST_MODEL,
tensor_parallel_size=tp_size,
max_num_seqs=128,
max_model_len=8192,
@@ -395,7 +394,7 @@ def test_simple_generation(backend):
Simple test that runs the model with a basic prompt and prints the output.
Useful for quick smoke testing and debugging.
"""
model = resolve_model_name(backend)
model = TEST_MODEL
llm = LLM(
model=model,
@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail(
monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
print(f"\n{'=' * 80}")
@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
model=TEST_MODEL,
tensor_parallel_size=tp_size,
max_num_seqs=32,
max_model_len=8192,
@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs(
"""
seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
random.seed(seed)
model_name = resolve_model_name(backend)
tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
from vllm.model_executor.layers.batch_invariant import (
@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
print(f"{'=' * 80}\n")
llm = LLM(
model=model_name,
model=TEST_MODEL,
tensor_parallel_size=tp_size,
max_num_seqs=32,
max_model_len=8192,