[Attention] Update tests to remove deprecated env vars (#30563)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2025-12-17 12:49:59 -05:00
committed by GitHub
parent 9ca8cb38fd
commit 7eb6cb6c18
34 changed files with 580 additions and 447 deletions

View File

@@ -438,25 +438,26 @@ def test_eagle_correctness(
should be the same when using eagle speculative decoding.
model_setup: (method, model_name, eagle_model_name, tp_size)
"""
# Determine attention config
# Scout requires default backend selection because vision encoder has
# head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
# to Flex Attn
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
attention_config = None # Let it fall back to default
else:
attention_config = {"backend": attn_backend}
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
with monkeypatch.context() as m:
if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
# Scout requires default backend selection
# because vision encoder has head_dim 88 being incompatible
# with FLASH_ATTN and needs to fall back to Flex Attn
# pass if not ROCm
if current_platform.is_rocm():
# TODO: Enable Flex Attn for spec_decode on ROCm
pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
else:
m.setenv("VLLM_MLA_DISABLE", "1")
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
"multi-token eagle spec decode on current platform"
)
m.setenv("VLLM_MLA_DISABLE", "1")
if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
if "deepseek" in model_setup[1].lower():
@@ -471,7 +472,10 @@ def test_eagle_correctness(
max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
ref_llm = LLM(
model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
model=model_name,
max_model_len=max_model_len,
tensor_parallel_size=tp_size,
attention_config=attention_config,
)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
@@ -492,6 +496,7 @@ def test_eagle_correctness(
max_num_batched_tokens=max_num_batched_tokens,
enable_chunked_prefill=enable_chunked_prefill,
model_impl=model_impl,
attention_config=attention_config,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0