[Attention] Update tests to remove deprecated env vars (#30563)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Matthew Bonanni authored on 2025-12-17 12:49:59 -05:00; committed by GitHub
parent 9ca8cb38fd
commit 7eb6cb6c18
34 changed files with 580 additions and 447 deletions

View File

@@ -142,16 +142,17 @@ def run_tests(
     """Test consistency of combos of async scheduling, preemption,
     uni/multiproc executor with spec decoding."""
-    with monkeypatch.context() as m:
-        # avoid precision errors
-        if current_platform.is_rocm():
-            if is_testing_with_spec_decoding:
-                # Use TRITON_ATTN for spec decoding test for consistency
-                m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
-            else:
-                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
-        else:
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+    # Determine attention config based on platform
+    if current_platform.is_rocm():
+        if is_testing_with_spec_decoding:
+            # Use TRITON_ATTN for spec decoding test for consistency
+            attention_config = {"backend": "TRITON_ATTN"}
+        else:
+            attention_config = {"backend": "ROCM_AITER_FA"}
+    else:
+        attention_config = {"backend": "FLEX_ATTENTION"}
+
+    with monkeypatch.context() as m:
         # lock matmul precision to full FP32 (IEEE)
         m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
         # m.setenv("VLLM_BATCH_INVARIANT", "1")
@@ -174,6 +175,7 @@ def run_tests(
                 spec_config,
                 test_prefill_chunking=test_prefill_chunking,
                 is_testing_with_spec_decoding=is_testing_with_spec_decoding,
+                attention_config=attention_config,
             )
             outputs.append(test_results)
@@ -262,6 +264,7 @@ def run_test(
     spec_config: dict[str, Any] | None,
     test_prefill_chunking: bool,
     is_testing_with_spec_decoding: bool = False,
+    attention_config: dict[str, Any] | None = None,
 ):
     spec_decoding = spec_config is not None
     cache_arg: dict[str, Any] = (
@@ -301,6 +304,7 @@ def run_test(
         dtype=dtype,
         speculative_config=spec_config,
         disable_log_stats=False,
+        attention_config=attention_config,
         **cache_arg,
     ) as vllm_model:
         results = []
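
The hunks above are the core of the change: the deprecated VLLM_ATTENTION_BACKEND environment variable goes away, run_tests builds an attention_config dict per platform, and run_test forwards it to the model runner. A minimal before/after sketch of the same migration (the model name is illustrative; the attention_config form follows these hunks):

from vllm import LLM

# Old (deprecated): select the backend via an environment variable, e.g.
#   monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")

# New: pass the backend through the engine arguments instead.
llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",  # illustrative model choice
    attention_config={"backend": "FLEX_ATTENTION"},
)

Selecting the backend via an engine argument keeps it in the engine configuration rather than in process-level environment state.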

View File

@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
 @create_new_process_for_each_test()
 @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
-def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
+def test_cascade_attention(example_system_message, attn_backend):
     prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
     if attn_backend == "FLASHINFER":
@@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
             "needs investigation. See issue #25679."
         )
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
-        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
-        sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
-
-        # No cascade attention.
-        single_prompt = [example_system_message + prompt]
-        responses = llm.generate(single_prompt, sampling_params)
-        ref_output = responses[0].outputs[0].text
-
-        # (Probably) Use cascade attention.
-        prompts = [example_system_message + prompt] * 64
-        responses = llm.generate(prompts, sampling_params)
-        for response in responses:
-            assert response.outputs[0].text == ref_output
+    llm = LLM(
+        model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
+    )
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
+
+    # No cascade attention.
+    single_prompt = [example_system_message + prompt]
+    responses = llm.generate(single_prompt, sampling_params)
+    ref_output = responses[0].outputs[0].text
+
+    # (Probably) Use cascade attention.
+    prompts = [example_system_message + prompt] * 64
+    responses = llm.generate(prompts, sampling_params)
+    for response in responses:
+        assert response.outputs[0].text == ref_output
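
Here the monkeypatch fixture disappears entirely because setting the backend was its only job. Tests that still rely on other environment variables (as in the previous file) keep the context manager and move only the backend choice into attention_config; a hedged sketch of that mixed pattern, with a hypothetical test name and placeholder use_triton parameter:

import pytest
from vllm import LLM


@pytest.mark.parametrize("use_triton", [True, False])
def test_backend_selection(monkeypatch, use_triton):
    # None lets vLLM pick its default backend (see the eagle test below).
    attention_config = {"backend": "TRITON_ATTN"} if use_triton else None

    with monkeypatch.context() as m:
        # Env vars that are still supported stay under monkeypatch.
        m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", attention_config=attention_config)
        assert llm is not None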

View File

@@ -438,25 +438,26 @@ def test_eagle_correctness(
     should be the same when using eagle speculative decoding.
     model_setup: (method, model_name, eagle_model_name, tp_size)
     """
+    # Determine attention config
+    # Scout requires default backend selection because vision encoder has
+    # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
+    # to Flex Attn
+    if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+        if current_platform.is_rocm():
+            # TODO: Enable Flex Attn for spec_decode on ROCm
+            pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
+        attention_config = None  # Let it fall back to default
+    else:
+        attention_config = {"backend": attn_backend}
+        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
+            pytest.skip(
+                "TRITON_ATTN does not support "
+                "multi-token eagle spec decode on current platform"
+            )
+
     with monkeypatch.context() as m:
-        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
-            # Scout requires default backend selection
-            # because vision encoder has head_dim 88 being incompatible
-            # with FLASH_ATTN and needs to fall back to Flex Attn
-            # pass if not ROCm
-            if current_platform.is_rocm():
-                # TODO: Enable Flex Attn for spec_decode on ROCm
-                pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
-        else:
-            m.setenv("VLLM_MLA_DISABLE", "1")
-            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-            if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
-                pytest.skip(
-                    "TRITON_ATTN does not support "
-                    "multi-token eagle spec decode on current platform"
-                )
+        m.setenv("VLLM_MLA_DISABLE", "1")

         if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             if "deepseek" in model_setup[1].lower():
@@ -471,7 +472,10 @@ def test_eagle_correctness(
         max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
         ref_llm = LLM(
-            model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
+            model=model_name,
+            max_model_len=max_model_len,
+            tensor_parallel_size=tp_size,
+            attention_config=attention_config,
         )
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
@@ -492,6 +496,7 @@ def test_eagle_correctness(
             max_num_batched_tokens=max_num_batched_tokens,
             enable_chunked_prefill=enable_chunked_prefill,
             model_impl=model_impl,
+            attention_config=attention_config,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0
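
The Scout branch above can also be read as a small helper: pass None to let vLLM fall back to its default backend (Flex Attention, since Scout's vision encoder head_dim of 88 is incompatible with FLASH_ATTN), otherwise pin the parametrized backend. A hypothetical refactor for illustration only, not part of this commit:

from typing import Any


def pick_attention_config(model_name: str, attn_backend: str) -> dict[str, Any] | None:
    # head_dim 88 in the Scout vision encoder is incompatible with FLASH_ATTN,
    # so return None and let vLLM fall back to its default backend.
    if "Llama-4-Scout" in model_name and attn_backend == "FLASH_ATTN":
        return None
    return {"backend": attn_backend}

Either way, the same attention_config value is passed to both ref_llm and spec_llm, so the reference and speculative runs use identical attention settings.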