[Attention] Update tests to remove deprecated env vars (#30563)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -35,10 +35,12 @@ audio_lora_path = MODEL_NAME
|
||||
models = [MODEL_NAME]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_attention_backend_for_rocm(monkeypatch):
|
||||
@pytest.fixture
|
||||
def granite_speech_attention_config():
|
||||
"""Return attention config for Granite Speech tests on ROCm."""
|
||||
if current_platform.is_rocm():
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "TRITON_ATTN")
|
||||
return {"backend": "TRITON_ATTN"}
|
||||
return None
|
||||
|
||||
|
||||
def run_test(
|
||||
@@ -53,6 +55,7 @@ def run_test(
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: str | None = None,
|
||||
attention_config: dict | None = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
@@ -80,6 +83,7 @@ def run_test(
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
enforce_eager=True,
|
||||
attention_config=attention_config,
|
||||
) as vllm_model:
|
||||
lora_request = LoRARequest("audio", 1, audio_lora_path)
|
||||
vllm_outputs_per_case = [
|
||||
@@ -131,6 +135,7 @@ def test_models(
|
||||
vllm_runner,
|
||||
model: str,
|
||||
audio_assets: AudioTestAssets,
|
||||
granite_speech_attention_config,
|
||||
dtype: str,
|
||||
max_model_len: int,
|
||||
max_tokens: int,
|
||||
@@ -157,4 +162,5 @@ def test_models(
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
attention_config=granite_speech_attention_config,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user