[Attention] Update tests to remove deprecated env vars (#30563)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from contextlib import ExitStack
|
||||
|
||||
@@ -13,26 +11,6 @@ from vllm import LLM
|
||||
from vllm.config import CompilationConfig, CompilationMode
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def temporary_environ(env_vars):
|
||||
"""
|
||||
Temporarily set environment variables and restore them afterward.
|
||||
We have to do this vs monkeypatch because monkeypatch doesn't work
|
||||
with "module" scoped fixtures.
|
||||
"""
|
||||
original_env = {k: os.environ.get(k) for k in env_vars}
|
||||
try:
|
||||
os.environ.update(env_vars)
|
||||
yield
|
||||
finally:
|
||||
for k, v in original_env.items():
|
||||
if v is None:
|
||||
os.environ.pop(k, None)
|
||||
else:
|
||||
os.environ[k] = v
|
||||
|
||||
|
||||
# test attention backend and cudagraph_mode combo
|
||||
# (backend_name, cudagraph_mode, supported)
|
||||
if current_platform.is_rocm():
|
||||
@@ -68,9 +46,9 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
|
||||
):
|
||||
pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
|
||||
|
||||
env_vars = backend_configs[backend_name].env_vars
|
||||
attention_config = backend_config.attention_config
|
||||
|
||||
with temporary_environ(env_vars), ExitStack() as stack:
|
||||
with ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(Exception))
|
||||
|
||||
@@ -80,6 +58,7 @@ def test_backend_and_cudagraph_mode_combo(backend_name, cudagraph_mode, supporte
|
||||
trust_remote_code=True,
|
||||
gpu_memory_utilization=0.45,
|
||||
max_model_len=1024,
|
||||
attention_config=attention_config,
|
||||
compilation_config=CompilationConfig(
|
||||
mode=CompilationMode.VLLM_COMPILE, cudagraph_mode=cudagraph_mode
|
||||
),
|
||||
@@ -122,9 +101,10 @@ combo_cases_2 = [
|
||||
def test_cudagraph_compilation_combo(
|
||||
backend_name, cudagraph_mode, compilation_mode, supported
|
||||
):
|
||||
env_vars = backend_configs[backend_name].env_vars
|
||||
backend_config = backend_configs[backend_name]
|
||||
attention_config = backend_config.attention_config
|
||||
|
||||
with temporary_environ(env_vars), ExitStack() as stack:
|
||||
with ExitStack() as stack:
|
||||
if not supported:
|
||||
stack.enter_context(pytest.raises(Exception))
|
||||
|
||||
@@ -134,6 +114,7 @@ def test_cudagraph_compilation_combo(
|
||||
trust_remote_code=True,
|
||||
gpu_memory_utilization=0.45,
|
||||
max_model_len=1024,
|
||||
attention_config=attention_config,
|
||||
compilation_config=CompilationConfig(
|
||||
mode=compilation_mode, cudagraph_mode=cudagraph_mode
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user