[Attention] Update tests to remove deprecated env vars (#30563)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2025-12-17 12:49:59 -05:00
committed by GitHub
parent 9ca8cb38fd
commit 7eb6cb6c18
34 changed files with 580 additions and 447 deletions

View File

@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
)
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import (
AttentionConfig,
CacheConfig,
DeviceConfig,
ModelConfig,
@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
def _create_proposer(
method: str,
num_speculative_tokens: int,
attention_backend: str | None = None,
speculative_token_tree: list[tuple[int, ...]] | None = None,
) -> EagleProposer:
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
@@ -70,6 +72,7 @@ def _create_proposer(
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
attention_config=AttentionConfig(backend=attention_backend),
)
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
@@ -331,8 +334,6 @@ def test_load_model(
use_distinct_lm_head,
monkeypatch,
):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
@@ -394,7 +395,9 @@ def test_load_model(
assert not isinstance(target_model, SupportsMultiModal)
# Create proposer using the helper function
proposer = _create_proposer(method, num_speculative_tokens=8)
proposer = _create_proposer(
method, num_speculative_tokens=8, attention_backend=attn_backend
)
# Call the method under test
proposer.load_model(target_model)
@@ -420,8 +423,6 @@ def test_load_model(
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
pytest.skip(
"TRITON_ATTN does not support "
@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
seq_lens = [seq_len_1, seq_len_2]
# Create proposer first so we can use its actual hidden_size
proposer = _create_proposer("eagle", num_speculative_tokens)
proposer = _create_proposer(
"eagle", num_speculative_tokens, attention_backend=attn_backend
)
# Get the hidden_size from the proposer to ensure consistency
hidden_size = proposer.hidden_size
@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
# Create proposer first so we can use its actual hidden_size.
proposer = _create_proposer(
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
"eagle",
num_speculative_tokens,
speculative_token_tree=spec_token_tree,
)
# Get the hidden_size from the proposer to ensure consistency.
hidden_size = proposer.hidden_size