[Attention] Update tests to remove deprecated env vars (#30563)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
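This PR migrates the attention selector tests off the deprecated VLLM_ATTENTION_BACKEND and VLLM_MLA_DISABLE environment variables and onto config-based backend selection. A minimal sketch of the pattern the updated tests follow, condensed from the hunks below (argument values are illustrative, not an exact excerpt):

    import torch

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.attention.selector import get_attn_backend
    from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config

    # Pin the backend in the config instead of the environment.
    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
    vllm_config = VllmConfig(attention_config=attention_config)

    # The selector reads the backend from the currently active vLLM config.
    with set_current_vllm_config(vllm_config):
        backend = get_attn_backend(16, torch.float16, None, 16)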
@@ -6,7 +6,9 @@ from unittest.mock import patch
 import pytest
 import torch
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
@@ -73,18 +75,18 @@ def generate_params():
 @pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
-def test_env(
+def test_backend_selection(
     device: str,
     name: str,
     use_mla: bool,
     block_size: int,
     monkeypatch: pytest.MonkeyPatch,
 ):
     """Test attention backend selection with valid device-backend pairs."""
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", name)
-        m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
+    # Create AttentionConfig with the specified backend
+    attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
+    vllm_config = VllmConfig(attention_config=attention_config)
 
+    with set_current_vllm_config(vllm_config):
         if device == "cpu":
             with patch("vllm.platforms.current_platform", CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, None, block_size)
@@ -217,27 +219,32 @@ def test_env(
 @pytest.mark.parametrize("device", ["cpu", "cuda"])
 def test_fp32_fallback(device: str):
     """Test attention backend selection with fp32."""
-    if device == "cpu":
-        with patch("vllm.platforms.current_platform", CpuPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16)
-        assert backend.get_name() == "CPU_ATTN"
+    # Use default config (no backend specified)
+    vllm_config = VllmConfig()
 
-    elif device == "cuda":
-        with patch("vllm.platforms.current_platform", CudaPlatform()):
-            backend = get_attn_backend(16, torch.float32, None, 16)
-        assert backend.get_name() == "FLEX_ATTENTION"
+    with set_current_vllm_config(vllm_config):
+        if device == "cpu":
+            with patch("vllm.platforms.current_platform", CpuPlatform()):
+                backend = get_attn_backend(16, torch.float32, None, 16)
+            assert backend.get_name() == "CPU_ATTN"
+
+        elif device == "cuda":
+            with patch("vllm.platforms.current_platform", CudaPlatform()):
+                backend = get_attn_backend(16, torch.float32, None, 16)
+            assert backend.get_name() == "FLEX_ATTENTION"
 
 
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     """Test FlashAttn validation."""
     pytest.skip(
         "Skipping as current backend selector does not "
-        "handle fallbacks when a backend is set via env var."
+        "handle fallbacks when a backend is explicitly set."
     )
 
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
+    vllm_config = VllmConfig(attention_config=attention_config)
 
+    with set_current_vllm_config(vllm_config):
         # Unsupported CUDA arch
         monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
         backend = get_attn_backend(16, torch.float16, None, 16)
@@ -277,15 +284,10 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
         assert backend.get_name() != "FLASH_ATTN"
 
 
-def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
+def test_invalid_backend():
     """Test that invalid attention backend names raise ValueError."""
     with (
-        monkeypatch.context() as m,
         patch("vllm.platforms.current_platform", CudaPlatform()),
+        pytest.raises(ValueError),
     ):
-        m.setenv("VLLM_ATTENTION_BACKEND", "INVALID")
-
-        # Should raise ValueError for invalid backend
-        with pytest.raises(ValueError) as exc_info:
-            get_attn_backend(32, torch.float16, None, 16)
-        assert "Invalid value 'INVALID'" in str(exc_info.value)
+        # Invalid backend name should raise ValueError when creating enum
+        AttentionConfig(backend=AttentionBackendEnum["INVALID"])
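With the env var gone, the invalid-name check moves from the selector call to the enum lookup itself: the updated test expects AttentionBackendEnum["INVALID"] to raise ValueError before any backend selection runs. A condensed sketch of that assertion, mirroring the test above (the test name is illustrative, and the ValueError expectation is taken from the test, not an independent API guarantee):

    import pytest

    from vllm.attention.backends.registry import AttentionBackendEnum
    from vllm.config import AttentionConfig

    def test_unknown_backend_name_rejected():
        # Per the updated test, an unknown backend name fails at enum
        # lookup time, before a config is even constructed.
        with pytest.raises(ValueError):
            AttentionConfig(backend=AttentionBackendEnum["INVALID"])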
@@ -4,7 +4,9 @@
 import pytest
 import torch
 
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
+from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
 from vllm.platforms.rocm import RocmPlatform
 
 
@@ -16,40 +18,56 @@ def clear_cache():
 @pytest.mark.skip(reason="Skipped for now. Should be revisited.")
 def test_selector(monkeypatch: pytest.MonkeyPatch):
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_ATTN")
-        # Set the current platform to ROCm using monkeypatch
-        monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
+    # Set the current platform to ROCm using monkeypatch
+    monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
 
-        # Test standard ROCm attention
+    # Test standard ROCm attention
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_ATTN)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
         backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
         assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
 
-        # MLA test for deepseek related
-        # change the attention backend to triton MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "TRITON_MLA")
+    # MLA test for deepseek related
+    # Change the attention backend to triton MLA
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.TRITON_MLA)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
         assert backend.get_name() == "TRITON_MLA"
 
-        # If attention backend is None
-        # If use_mla is true
-        # The selected backend is triton MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "")
+    # If attention backend is None
+    # If use_mla is true
+    # The selected backend is triton MLA
+    attention_config = AttentionConfig(backend=None)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
         backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
         assert backend.get_name() == "TRITON_MLA"
 
-        # change the attention backend to AITER MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "ROCM_AITER_MLA")
+    # Change the attention backend to AITER MLA
+    attention_config = AttentionConfig(backend=AttentionBackendEnum.ROCM_AITER_MLA)
+    vllm_config = VllmConfig(attention_config=attention_config)
+
+    with set_current_vllm_config(vllm_config):
         backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
         assert backend.get_name() == "ROCM_AITER_MLA"
 
-        # If attention backend is None
-        # If use_mla is true
-        # If VLLM_ROCM_USE_AITER is enabled
-        # The selected backend is ROCM_AITER_MLA
-        m.setenv("VLLM_ATTENTION_BACKEND", "")
+    # If attention backend is None
+    # If use_mla is true
+    # If VLLM_ROCM_USE_AITER is enabled
+    # The selected backend is ROCM_AITER_MLA
+    with monkeypatch.context() as m:
         m.setenv("VLLM_ROCM_USE_AITER", "1")
-        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
-        assert backend.get_name() == "ROCM_AITER_MLA"
+
+        attention_config = AttentionConfig(backend=None)
+        vllm_config = VllmConfig(attention_config=attention_config)
+
+        with set_current_vllm_config(vllm_config):
+            backend = get_attn_backend(
+                576, torch.bfloat16, "auto", 1, False, use_mla=True
+            )
+            assert backend.get_name() == "ROCM_AITER_MLA"
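Note the fallback case exercised in test_selector: with backend=None in the config, the selector falls back to the platform default, and on ROCm with use_mla=True that is TRITON_MLA (or ROCM_AITER_MLA once VLLM_ROCM_USE_AITER=1, which remains an environment toggle). A condensed sketch of that check, with argument values copied from the test above:

    import torch

    from vllm.attention.selector import get_attn_backend
    from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config

    # No backend pinned: the platform picks the default MLA backend.
    vllm_config = VllmConfig(attention_config=AttentionConfig(backend=None))

    with set_current_vllm_config(vllm_config):
        # Head size 576 with MLA enabled; the ROCm test expects TRITON_MLA here.
        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
        assert backend.get_name() == "TRITON_MLA"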
@@ -37,7 +37,7 @@ def set_seed(seed):
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
 )
-def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+def test_flex_attention_vs_default_backend(vllm_runner):
     """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
@@ -54,35 +54,32 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     ]
 
     # Run with flex attention
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-
-        set_seed(seed)
-        with vllm_runner(
-            model_name,
-            runner="generate",
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-        ) as llm_flex:
-            output_flex = llm_flex.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs
-            )
+    set_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=True,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_flex:
+        output_flex = llm_flex.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )
 
     # Run with default backend
-    with monkeypatch.context() as m:
-        set_seed(seed)
-        with vllm_runner(
-            model_name,
-            runner="generate",
-            tensor_parallel_size=1,
-            num_gpu_blocks_override=128,
-            enforce_eager=True,
-            gpu_memory_utilization=0.85,
-        ) as llm_default:
-            output_default = llm_default.generate_greedy_logprobs(
-                prompts, max_tokens, num_logprobs
-            )
+    set_seed(seed)
+    with vllm_runner(
+        model_name,
+        runner="generate",
+        tensor_parallel_size=1,
+        num_gpu_blocks_override=128,
+        enforce_eager=True,
+        gpu_memory_utilization=0.85,
+    ) as llm_default:
+        output_default = llm_default.generate_greedy_logprobs(
+            prompts, max_tokens, num_logprobs
+        )
 
     check_logprobs_close(
         outputs_0_lst=output_flex,
@@ -96,7 +93,7 @@ def test_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
     reason="CUDA not available or PyTorch version < 2.7",
 )
-def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
+def test_encoder_flex_attention_vs_default_backend(vllm_runner):
    """Test that FlexAttention produces the same outputs as the default backend.
 
     This test compares the outputs from the FlexAttention backend with
@@ -110,30 +107,26 @@ def test_encoder_flex_attention_vs_default_backend(vllm_runner, monkeypatch):
     ]
 
     # Run with flex attention
-    with monkeypatch.context() as m:
-        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
-        with vllm_runner(
-            model_name,
-            runner="pooling",
-            dtype=torch.bfloat16,
-            tensor_parallel_size=1,
-            max_model_len=100,
-            enforce_eager=True,
-        ) as llm_flex:
-            flex_outputs = llm_flex.embed(prompts)
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype=torch.bfloat16,
+        tensor_parallel_size=1,
+        max_model_len=100,
+        enforce_eager=True,
+        attention_config={"backend": "FLEX_ATTENTION"},
+    ) as llm_flex:
+        flex_outputs = llm_flex.embed(prompts)
 
     # Run with default backend
-    with (
-        monkeypatch.context() as m,
-        vllm_runner(
-            model_name,
-            runner="pooling",
-            dtype=torch.bfloat16,
-            tensor_parallel_size=1,
-            max_model_len=100,
-            enforce_eager=True,
-        ) as llm_default,
-    ):
+    with vllm_runner(
+        model_name,
+        runner="pooling",
+        dtype=torch.bfloat16,
+        tensor_parallel_size=1,
+        max_model_len=100,
+        enforce_eager=True,
+    ) as llm_default:
         default_outputs = llm_default.embed(prompts)
 
     check_embeddings_close(
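For the end-to-end FlexAttention comparisons, the backend override now travels through the vllm_runner fixture's attention_config keyword rather than an env var, so the A/B runs differ only in that one argument. The comparison shape, condensed from the embedding test above (fixture and variable names as used there):

    # FlexAttention run: pin the backend via attention_config.
    with vllm_runner(
        model_name,
        runner="pooling",
        enforce_eager=True,
        attention_config={"backend": "FLEX_ATTENTION"},
    ) as llm_flex:
        flex_outputs = llm_flex.embed(prompts)

    # Default run: same arguments, no backend override.
    with vllm_runner(
        model_name,
        runner="pooling",
        enforce_eager=True,
    ) as llm_default:
        default_outputs = llm_default.embed(prompts)

check_embeddings_close then asserts the two runs agree within tolerance, as in the test above.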