[Attention] Update tests to remove deprecated env vars (#30563)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -15,6 +15,7 @@ from tests.v1.attention.utils import (
|
||||
)
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import (
|
||||
AttentionConfig,
|
||||
CacheConfig,
|
||||
DeviceConfig,
|
||||
ModelConfig,
|
||||
@@ -38,6 +39,7 @@ eagle3_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
def _create_proposer(
|
||||
method: str,
|
||||
num_speculative_tokens: int,
|
||||
attention_backend: str | None = None,
|
||||
speculative_token_tree: list[tuple[int, ...]] | None = None,
|
||||
) -> EagleProposer:
|
||||
model_config = ModelConfig(model=model_dir, runner="generate", max_model_len=100)
|
||||
@@ -70,6 +72,7 @@ def _create_proposer(
|
||||
max_model_len=model_config.max_model_len,
|
||||
is_encoder_decoder=model_config.is_encoder_decoder,
|
||||
),
|
||||
attention_config=AttentionConfig(backend=attention_backend),
|
||||
)
|
||||
|
||||
return EagleProposer(vllm_config=vllm_config, device=current_platform.device_type)
|
||||
@@ -331,8 +334,6 @@ def test_load_model(
|
||||
use_distinct_lm_head,
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"TRITON_ATTN does not support "
|
||||
@@ -394,7 +395,9 @@ def test_load_model(
|
||||
assert not isinstance(target_model, SupportsMultiModal)
|
||||
|
||||
# Create proposer using the helper function
|
||||
proposer = _create_proposer(method, num_speculative_tokens=8)
|
||||
proposer = _create_proposer(
|
||||
method, num_speculative_tokens=8, attention_backend=attn_backend
|
||||
)
|
||||
|
||||
# Call the method under test
|
||||
proposer.load_model(target_model)
|
||||
@@ -420,8 +423,6 @@ def test_load_model(
|
||||
@pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
|
||||
@pytest.mark.parametrize("num_speculative_tokens", [1, 3, 8])
|
||||
def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
|
||||
|
||||
if attn_backend == "TRITON_ATTN" and not current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"TRITON_ATTN does not support "
|
||||
@@ -449,7 +450,9 @@ def test_propose(method, attn_backend, num_speculative_tokens, monkeypatch):
|
||||
seq_lens = [seq_len_1, seq_len_2]
|
||||
|
||||
# Create proposer first so we can use its actual hidden_size
|
||||
proposer = _create_proposer("eagle", num_speculative_tokens)
|
||||
proposer = _create_proposer(
|
||||
"eagle", num_speculative_tokens, attention_backend=attn_backend
|
||||
)
|
||||
# Get the hidden_size from the proposer to ensure consistency
|
||||
hidden_size = proposer.hidden_size
|
||||
|
||||
@@ -622,7 +625,9 @@ def test_propose_tree(spec_token_tree):
|
||||
|
||||
# Create proposer first so we can use its actual hidden_size.
|
||||
proposer = _create_proposer(
|
||||
"eagle", num_speculative_tokens, speculative_token_tree=spec_token_tree
|
||||
"eagle",
|
||||
num_speculative_tokens,
|
||||
speculative_token_tree=spec_token_tree,
|
||||
)
|
||||
# Get the hidden_size from the proposer to ensure consistency.
|
||||
hidden_size = proposer.hidden_size
|
||||
|
||||
Reference in New Issue
Block a user