diff --git a/tests/v1/determinism/test_batch_invariance.py b/tests/v1/determinism/test_batch_invariance.py
index d3d0a4a48..11550c190 100644
--- a/tests/v1/determinism/test_batch_invariance.py
+++ b/tests/v1/determinism/test_batch_invariance.py
@@ -8,10 +8,10 @@
 import pytest
 import torch
 from utils import (
     BACKENDS,
+    TEST_MODEL,
     _extract_step_logprobs,
     _random_prompt,
     is_device_capability_below_90,
-    resolve_model_name,
     skip_unsupported,
 )
@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     attention_config = {"backend": backend}
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
-    model = resolve_model_name(backend)
+    model = TEST_MODEL
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
     max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
     min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 ):
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     # For batch invariance, disable custom all-reduce to ensure deterministic
@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     print(f"{'=' * 80}\n")
 
     llm = LLM(
-        model=model_name,
+        model=TEST_MODEL,
         tensor_parallel_size=tp_size,
         max_num_seqs=128,
         max_model_len=8192,
@@ -395,7 +394,7 @@ def test_simple_generation(backend):
     Simple test that runs the model with a basic prompt and prints the output.
     Useful for quick smoke testing and debugging.
     """
-    model = resolve_model_name(backend)
+    model = TEST_MODEL
 
     llm = LLM(
         model=model,
@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail(
     monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     print(f"\n{'=' * 80}")
@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
     print(f"{'=' * 80}\n")
 
     llm = LLM(
-        model=model_name,
+        model=TEST_MODEL,
         tensor_parallel_size=tp_size,
         max_num_seqs=32,
         max_model_len=8192,
@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs(
     """
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
     print(f"{'=' * 80}\n")
 
     llm = LLM(
-        model=model_name,
+        model=TEST_MODEL,
         tensor_parallel_size=tp_size,
         max_num_seqs=32,
         max_model_len=8192,
diff --git a/tests/v1/determinism/test_online_batch_invariance.py b/tests/v1/determinism/test_online_batch_invariance.py
index 52c8103b2..2bebb2dca 100644
--- a/tests/v1/determinism/test_online_batch_invariance.py
+++ b/tests/v1/determinism/test_online_batch_invariance.py
@@ -17,7 +17,7 @@
 from typing import Any
 
 import openai
 import pytest
-from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
+from utils import BACKENDS, TEST_MODEL, _random_prompt, skip_unsupported
 
 from tests.utils import RemoteOpenAIServer
@@ -139,7 +139,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     backend: str,
 ) -> None:
     random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
-    model_name = resolve_model_name(backend)
     prompts_all = [_random_prompt(10, 50) for _ in range(32)]
 
     sp_kwargs: dict[str, Any] = {
@@ -159,11 +158,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     if tp_size:
         server_args += ["-tp", tp_size]
 
-    with RemoteOpenAIServer(model_name, server_args) as server:
+    with RemoteOpenAIServer(TEST_MODEL, server_args) as server:
         client = server.get_client()
         _compare_bs1_vs_bsn_single_process(
             prompts=prompts_all,
             sp_kwargs=sp_kwargs,
             client=client,
-            model_name=model_name,
+            model_name=TEST_MODEL,
         )
diff --git a/tests/v1/determinism/utils.py b/tests/v1/determinism/utils.py
index ca3ccab5e..f9bebec98 100644
--- a/tests/v1/determinism/utils.py
+++ b/tests/v1/determinism/utils.py
@@ -7,6 +7,10 @@
 import pytest
 import torch
 from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_config
+from vllm.transformers_utils.model_arch_config_convertor import (
+    ModelArchConfigConvertorBase,
+)
 from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
 
 skip_unsupported = pytest.mark.skipif(
@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif(
     reason="Requires CUDA and >= Ampere (SM80)",
 )
 
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+TEST_MODEL = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+
 BACKENDS: list[str] = [
     "FLASH_ATTN",
     "TRITON_ATTN",
-    "TRITON_MLA",
 ]
 
 # FlashInfer temporarily disabled due to invariant CTA sizes.
@@ -27,19 +33,13 @@ BACKENDS: list[str] = [
 # if has_flashinfer():
 #     BACKENDS.append("FLASHINFER")
 
-if flash_attn_supports_mla():
-    BACKENDS.append("FLASH_ATTN_MLA")
-
-DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
-MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
-
-
-def resolve_model_name(backend: str) -> str:
-    """Resolve the model name for the given backend."""
-    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
-    if backend.endswith("MLA") and model == DEFAULT_MODEL:
-        return MLA_MODEL
-    return model
+# only run MLA backends when the requested test model is itself an MLA model.
+if os.getenv("VLLM_TEST_MODEL"):
+    config = get_config(TEST_MODEL, trust_remote_code=False)
+    if ModelArchConfigConvertorBase(config, config.get_text_config()).is_deepseek_mla():
+        BACKENDS = ["TRITON_MLA"]
+        if flash_attn_supports_mla():
+            BACKENDS.append("FLASH_ATTN_MLA")
 
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
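
Usage sketch (illustrative only; the MLA model name below reuses the MLA_MODEL constant removed in this patch, and the pytest invocation path is an assumption, not part of the change):

    # Default: VLLM_TEST_MODEL unset, so TEST_MODEL falls back to Qwen/Qwen3-1.7B
    # and BACKENDS stays ["FLASH_ATTN", "TRITON_ATTN"].
    pytest tests/v1/determinism/test_batch_invariance.py

    # MLA path: pointing VLLM_TEST_MODEL at an MLA model makes the config check in
    # utils.py switch BACKENDS to ["TRITON_MLA"], plus "FLASH_ATTN_MLA" when
    # flash_attn_supports_mla() returns True.
    VLLM_TEST_MODEL=deepseek-ai/DeepSeek-V2-Lite-Chat \
        pytest tests/v1/determinism/test_batch_invariance.py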