[Attention] Update tests to remove deprecated env vars (#30563)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2025-12-17 12:49:59 -05:00
committed by GitHub
parent 9ca8cb38fd
commit 7eb6cb6c18
34 changed files with 580 additions and 447 deletions

View File

@@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME
models = [MODEL_NAME]
@pytest.fixture(autouse=True)
def set_attention_backend_for_rocm(monkeypatch):
@pytest.fixture
def granite_speech_attention_config():
"""Return attention config for Granite Speech tests on ROCm."""
if current_platform.is_rocm():
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
return {"backend": "TRITON_ATTN"}
return None
def run_test(
@@ -53,6 +55,7 @@ def run_test(
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: str | None = None,
attention_config: dict | None = None,
):
"""Inference result should be the same between hf and vllm.
@@ -80,6 +83,7 @@ def run_test(
enable_lora=True,
max_lora_rank=64,
enforce_eager=True,
attention_config=attention_config,
) as vllm_model:
lora_request = LoRARequest("audio", 1, audio_lora_path)
vllm_outputs_per_case = [
@@ -131,6 +135,7 @@ def test_models(
vllm_runner,
model: str,
audio_assets: AudioTestAssets,
granite_speech_attention_config,
dtype: str,
max_model_len: int,
max_tokens: int,
@@ -157,4 +162,5 @@ def test_models(
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
attention_config=granite_speech_attention_config,
)