[Test] Only run the MLA model when the user explicitly sets it for batch invariance (#37719)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
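In short: the batch-invariance tests no longer swap in an MLA model behind the user's back. Every test now reads a single `TEST_MODEL` (the `VLLM_TEST_MODEL` environment variable, defaulting to `Qwen/Qwen3-1.7B`), and the MLA attention backends are only exercised when that model is itself a DeepSeek-MLA model. A hedged before/after sketch, condensed from the hunks below (the old helper is exactly the code this commit deletes):

```python
import os

DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"


# Before this commit: MLA backends silently substituted an MLA model
# even when the user never asked for one.
def resolve_model_name(backend: str) -> str:
    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
    if backend.endswith("MLA") and model == DEFAULT_MODEL:
        return MLA_MODEL
    return model


# After this commit: one model for every test, no implicit substitution.
TEST_MODEL = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
```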
@@ -8,10 +8,10 @@ import pytest
 import torch
 from utils import (
     BACKENDS,
+    TEST_MODEL,
     _extract_step_logprobs,
     _random_prompt,
     is_device_capability_below_90,
-    resolve_model_name,
     skip_unsupported,
 )
 
@@ -57,7 +57,7 @@ def test_v1_generation_is_deterministic_across_batch_sizes_with_needle(
     attention_config = {"backend": backend}
     # Allow overrides from environment (useful for CI tuning)
     # "facebook/opt-125m" is too small, doesn't reliably test determinism
-    model = resolve_model_name(backend)
+    model = TEST_MODEL
     num_trials = int(os.getenv("VLLM_NEEDLE_TRIALS", "5"))
     max_batch_size = int(os.getenv("VLLM_NEEDLE_BATCH_SIZE", "128"))
     min_random_prompt = int(os.getenv("VLLM_MIN_PROMPT", "1024"))
@@ -169,7 +169,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
 ):
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     # For batch invariance, disable custom all-reduce to ensure deterministic
@@ -186,7 +185,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     print(f"{'=' * 80}\n")
 
     llm = LLM(
-        model=model_name,
+        model=TEST_MODEL,
         tensor_parallel_size=tp_size,
         max_num_seqs=128,
         max_model_len=8192,
@@ -395,7 +394,7 @@ def test_simple_generation(backend):
     Simple test that runs the model with a basic prompt and prints the output.
     Useful for quick smoke testing and debugging.
     """
-    model = resolve_model_name(backend)
+    model = TEST_MODEL
 
     llm = LLM(
         model=model,
@@ -458,7 +457,6 @@ def test_logprobs_without_batch_invariance_should_fail(
     monkeypatch.setattr(batch_invariant, "VLLM_BATCH_INVARIANT", False)
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     print(f"\n{'=' * 80}")
@@ -466,7 +464,7 @@ def test_logprobs_without_batch_invariance_should_fail(
     print(f"{'=' * 80}\n")
 
     llm = LLM(
-        model=model_name,
+        model=TEST_MODEL,
         tensor_parallel_size=tp_size,
         max_num_seqs=32,
         max_model_len=8192,
@@ -674,7 +672,6 @@ def test_decode_logprobs_match_prefill_logprobs(
     """
     seed = int(os.getenv("VLLM_TEST_SEED", "12345"))
     random.seed(seed)
-    model_name = resolve_model_name(backend)
     tp_size = int(os.getenv("VLLM_TEST_TP_SIZE", "1"))
 
     from vllm.model_executor.layers.batch_invariant import (
@@ -689,7 +686,7 @@ def test_decode_logprobs_match_prefill_logprobs(
     print(f"{'=' * 80}\n")
 
     llm = LLM(
-        model=model_name,
+        model=TEST_MODEL,
         tensor_parallel_size=tp_size,
         max_num_seqs=32,
         max_model_len=8192,
@@ -17,7 +17,7 @@ from typing import Any
 
 import openai
 import pytest
-from utils import BACKENDS, _random_prompt, resolve_model_name, skip_unsupported
+from utils import BACKENDS, TEST_MODEL, _random_prompt, skip_unsupported
 
 from tests.utils import RemoteOpenAIServer
 
@@ -139,7 +139,6 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     backend: str,
 ) -> None:
     random.seed(int(os.getenv("VLLM_TEST_SEED", "12345")))
-    model_name = resolve_model_name(backend)
     prompts_all = [_random_prompt(10, 50) for _ in range(32)]
 
     sp_kwargs: dict[str, Any] = {
@@ -159,11 +158,11 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(
     if tp_size:
         server_args += ["-tp", tp_size]
 
-    with RemoteOpenAIServer(model_name, server_args) as server:
+    with RemoteOpenAIServer(TEST_MODEL, server_args) as server:
         client = server.get_client()
         _compare_bs1_vs_bsn_single_process(
             prompts=prompts_all,
             sp_kwargs=sp_kwargs,
             client=client,
-            model_name=model_name,
+            model_name=TEST_MODEL,
         )
@@ -7,6 +7,10 @@ import pytest
 import torch
 
 from vllm.platforms import current_platform
+from vllm.transformers_utils.config import get_config
+from vllm.transformers_utils.model_arch_config_convertor import (
+    ModelArchConfigConvertorBase,
+)
 from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
 
 skip_unsupported = pytest.mark.skipif(
@@ -16,10 +20,12 @@ skip_unsupported = pytest.mark.skipif(
     reason="Requires CUDA and >= Ampere (SM80)",
 )
 
+DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
+TEST_MODEL = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
+
 BACKENDS: list[str] = [
     "FLASH_ATTN",
     "TRITON_ATTN",
-    "TRITON_MLA",
 ]
 
 # FlashInfer temporarily disabled due to invariant CTA sizes.
@@ -27,19 +33,13 @@ BACKENDS: list[str] = [
 # if has_flashinfer():
 #     BACKENDS.append("FLASHINFER")
 
-if flash_attn_supports_mla():
-    BACKENDS.append("FLASH_ATTN_MLA")
-
-DEFAULT_MODEL = "Qwen/Qwen3-1.7B"
-MLA_MODEL = "deepseek-ai/DeepSeek-V2-Lite-Chat"
-
-
-def resolve_model_name(backend: str) -> str:
-    """Resolve the model name for the given backend."""
-    model = os.getenv("VLLM_TEST_MODEL", DEFAULT_MODEL)
-    if backend.endswith("MLA") and model == DEFAULT_MODEL:
-        return MLA_MODEL
-    return model
+# only run MLA backends when the requested test model is itself an MLA model.
+if os.getenv("VLLM_TEST_MODEL"):
+    config = get_config(TEST_MODEL, trust_remote_code=False)
+    if ModelArchConfigConvertorBase(config, config.get_text_config()).is_deepseek_mla():
+        BACKENDS = ["TRITON_MLA"]
+        if flash_attn_supports_mla():
+            BACKENDS.append("FLASH_ATTN_MLA")
 
 
 def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
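For reference, a standalone sketch of the backend-selection rule the new `utils.py` code implements. The parameters `is_mla_model` and `flash_attn_mla_supported` are placeholder inputs standing in for the `get_config` / `ModelArchConfigConvertorBase(...).is_deepseek_mla()` and `flash_attn_supports_mla()` checks used in the real file:

```python
import os


def select_backends(is_mla_model: bool, flash_attn_mla_supported: bool) -> list[str]:
    """Mirror of the BACKENDS gating above, with the vLLM-internal checks abstracted away."""
    backends = ["FLASH_ATTN", "TRITON_ATTN"]
    # MLA backends run only when the user explicitly requested an MLA model
    # via VLLM_TEST_MODEL; otherwise the default (non-MLA) model and backends apply.
    if os.getenv("VLLM_TEST_MODEL") and is_mla_model:
        backends = ["TRITON_MLA"]
        if flash_attn_mla_supported:
            backends.append("FLASH_ATTN_MLA")
    return backends
```

Net effect: running the suite with `VLLM_TEST_MODEL=deepseek-ai/DeepSeek-V2-Lite-Chat` exercises only the MLA backends (`TRITON_MLA`, plus `FLASH_ATTN_MLA` where supported), while leaving the variable unset keeps the Qwen default and skips the MLA backends entirely.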