[Attention] Update tests to remove deprecated env vars (#30563)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -142,16 +142,17 @@ def run_tests(
     """Test consistency of combos of async scheduling, preemption,
     uni/multiproc executor with spec decoding."""
 
-    with monkeypatch.context() as m:
-        # avoid precision errors
-        if current_platform.is_rocm():
-            if is_testing_with_spec_decoding:
-                # Use TRITON_ATTN for spec decoding test for consistency
-                m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
-            else:
-                m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_FA")
-        else:
-            m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-
+    # Determine attention config based on platform
+    if current_platform.is_rocm():
+        if is_testing_with_spec_decoding:
+            # Use TRITON_ATTN for spec decoding test for consistency
+            attention_config = {"backend": "TRITON_ATTN"}
+        else:
+            attention_config = {"backend": "ROCM_AITER_FA"}
+    else:
+        attention_config = {"backend": "FLEX_ATTENTION"}
+
+    with monkeypatch.context() as m:
+        # lock matmul precision to full FP32 (IEEE)
         m.setenv("VLLM_FLOAT32_MATMUL_PRECISION", "ieee")
         # m.setenv("VLLM_BATCH_INVARIANT", "1")
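
The hunk above replaces the deprecated VLLM_ATTENTION_BACKEND environment variable with an attention_config dict passed through the engine arguments. A minimal sketch of the same selection logic outside the test harness, assuming only the attention_config keyword visible in this diff (choose_attention_config is a hypothetical helper, not part of the change):

from vllm import LLM
from vllm.platforms import current_platform


def choose_attention_config(spec_decoding: bool) -> dict[str, str]:
    # Mirror the branch above: TRITON_ATTN keeps spec-decode runs comparable on
    # ROCm, ROCM_AITER_FA is the plain ROCm path, FLEX_ATTENTION elsewhere.
    if current_platform.is_rocm():
        return {"backend": "TRITON_ATTN" if spec_decoding else "ROCM_AITER_FA"}
    return {"backend": "FLEX_ATTENTION"}


llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    attention_config=choose_attention_config(spec_decoding=False),
)
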
@@ -174,6 +175,7 @@ def run_tests(
                 spec_config,
                 test_prefill_chunking=test_prefill_chunking,
                 is_testing_with_spec_decoding=is_testing_with_spec_decoding,
+                attention_config=attention_config,
             )
             outputs.append(test_results)
 
@@ -262,6 +264,7 @@ def run_test(
     spec_config: dict[str, Any] | None,
     test_prefill_chunking: bool,
     is_testing_with_spec_decoding: bool = False,
+    attention_config: dict[str, Any] | None = None,
 ):
     spec_decoding = spec_config is not None
     cache_arg: dict[str, Any] = (
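
The new run_test parameter defaults to None and is forwarded untouched; None means vLLM keeps its default backend selection. A hedged sketch of that forwarding pattern (build_llm and its argument list are illustrative, not part of this change):

from typing import Any

from vllm import LLM


def build_llm(
    model: str,
    attention_config: dict[str, Any] | None = None,
    **kwargs: Any,
) -> LLM:
    # None is passed through as-is, so vLLM falls back to its default backend,
    # matching the run_test default introduced above.
    return LLM(model=model, attention_config=attention_config, **kwargs)
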
@@ -301,6 +304,7 @@ def run_test(
         dtype=dtype,
         speculative_config=spec_config,
         disable_log_stats=False,
+        attention_config=attention_config,
         **cache_arg,
     ) as vllm_model:
         results = []
@@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test
 
 @create_new_process_for_each_test()
 @pytest.mark.parametrize("attn_backend", ["FLASH_ATTN", "FLASHINFER"])
-def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
+def test_cascade_attention(example_system_message, attn_backend):
     prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
 
     if attn_backend == "FLASHINFER":
@@ -19,19 +19,18 @@ def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
             "needs investigation. See issue #25679."
         )
 
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
-        llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
-        sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
-
-        # No cascade attention.
-        single_prompt = [example_system_message + prompt]
-        responses = llm.generate(single_prompt, sampling_params)
-        ref_output = responses[0].outputs[0].text
-
-        # (Probably) Use cascade attention.
-        prompts = [example_system_message + prompt] * 64
-        responses = llm.generate(prompts, sampling_params)
-        for response in responses:
-            assert response.outputs[0].text == ref_output
+    llm = LLM(
+        model="Qwen/Qwen2-1.5B-Instruct", attention_config={"backend": attn_backend}
+    )
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
+
+    # No cascade attention.
+    single_prompt = [example_system_message + prompt]
+    responses = llm.generate(single_prompt, sampling_params)
+    ref_output = responses[0].outputs[0].text
+
+    # (Probably) Use cascade attention.
+    prompts = [example_system_message + prompt] * 64
+    responses = llm.generate(prompts, sampling_params)
+    for response in responses:
+        assert response.outputs[0].text == ref_output
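
For contrast, the two styles this hunk swaps, assuming only the attention_config keyword shown in the diff (the snippet is illustrative, not the test itself):

from vllm import LLM

backend = "FLASH_ATTN"

# Old style (deprecated): the backend was injected via an env var and needed
# the monkeypatch fixture:
#     m.setenv("VLLM_ATTENTION_BACKEND", backend)
#     llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")

# New style: the backend travels with the engine arguments, so the fixture is
# no longer needed for backend selection.
llm = LLM(
    model="Qwen/Qwen2-1.5B-Instruct",
    attention_config={"backend": backend},
)
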
@@ -438,25 +438,26 @@ def test_eagle_correctness(
     should be the same when using eagle speculative decoding.
     model_setup: (method, model_name, eagle_model_name, tp_size)
     """
+    # Determine attention config
+    # Scout requires default backend selection because vision encoder has
+    # head_dim 88 being incompatible with FLASH_ATTN and needs to fall back
+    # to Flex Attn
+    if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
+        if current_platform.is_rocm():
+            # TODO: Enable Flex Attn for spec_decode on ROCm
+            pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
+        attention_config = None  # Let it fall back to default
+    else:
+        attention_config = {"backend": attn_backend}
+
+    if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
+        pytest.skip(
+            "TRITON_ATTN does not support "
+            "multi-token eagle spec decode on current platform"
+        )
+
     with monkeypatch.context() as m:
-        if "Llama-4-Scout" in model_setup[1] and attn_backend == "FLASH_ATTN":
-            # Scout requires default backend selection
-            # because vision encoder has head_dim 88 being incompatible
-            # with FLASH_ATTN and needs to fall back to Flex Attn
-
-            # pass if not ROCm
-            if current_platform.is_rocm():
-                # TODO: Enable Flex Attn for spec_decode on ROCm
-                pytest.skip("Flex Attn for spec_decode not supported on ROCm currently")
-        else:
-            m.setenv("VLLM_MLA_DISABLE", "1")
-            m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
-
-        if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
-            pytest.skip(
-                "TRITON_ATTN does not support "
-                "multi-token eagle spec decode on current platform"
-            )
+        m.setenv("VLLM_MLA_DISABLE", "1")
 
         if attn_backend == "ROCM_AITER_FA" and current_platform.is_rocm():
             if "deepseek" in model_setup[1].lower():
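
The Scout branch above passes attention_config=None so that vLLM's own backend selection takes over (the vision encoder's head_dim of 88 is incompatible with FLASH_ATTN). A small self-contained sketch of that fallback rule; the helper name and model IDs are illustrative:

def resolve_attention_config(
    model_name: str, attn_backend: str
) -> dict[str, str] | None:
    # None defers to vLLM's default backend selection; a dict requests one
    # explicitly.
    if "Llama-4-Scout" in model_name and attn_backend == "FLASH_ATTN":
        return None
    return {"backend": attn_backend}


scout = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
assert resolve_attention_config(scout, "FLASH_ATTN") is None
assert resolve_attention_config("meta-llama/Llama-3.1-8B-Instruct", "FLASH_ATTN") == {
    "backend": "FLASH_ATTN"
}
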
@@ -471,7 +472,10 @@ def test_eagle_correctness(
         max_num_batched_tokens = 128 if enable_chunked_prefill else max_model_len
 
         ref_llm = LLM(
-            model=model_name, max_model_len=max_model_len, tensor_parallel_size=tp_size
+            model=model_name,
+            max_model_len=max_model_len,
+            tensor_parallel_size=tp_size,
+            attention_config=attention_config,
         )
         ref_outputs = ref_llm.chat(test_prompts, sampling_config)
         del ref_llm
@@ -492,6 +496,7 @@ def test_eagle_correctness(
             max_num_batched_tokens=max_num_batched_tokens,
             enable_chunked_prefill=enable_chunked_prefill,
             model_impl=model_impl,
+            attention_config=attention_config,
         )
         spec_outputs = spec_llm.chat(test_prompts, sampling_config)
         matches = 0