Reapply [Attention] Refactor check_and_update_config (#35122)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2026-03-09 10:17:14 -04:00
Committed by: GitHub
Parent: 5578f2a4d3
Commit: 77a73458e3
32 changed files with 311 additions and 279 deletions
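
The functional change in the test diff below is a calling-convention change: get_attn_backend no longer takes a block_size argument, and the tests instead carry the block size on the CacheConfig of the VllmConfig that is set as current. A minimal sketch of the before/after pattern, reusing names from the diff; the import path for get_attn_backend and the default-constructibility of the other VllmConfig fields are assumptions, since neither is visible in these hunks:

import torch

from vllm.attention.selector import get_attn_backend  # import path assumed
from vllm.config import CacheConfig, VllmConfig, set_current_vllm_config

# Old pattern (removed by this commit): block_size passed on every call.
#     backend = get_attn_backend(16, torch.float16, None, 16)

# New pattern: block_size travels with the active config.
vllm_config = VllmConfig(cache_config=CacheConfig(block_size=16))
with set_current_vllm_config(vllm_config):
    # positional args: head_size, dtype, kv_cache_dtype
    backend = get_attn_backend(16, torch.float16, None)
    print(backend.get_name())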


@@ -6,7 +6,12 @@ from unittest.mock import patch
 import pytest
 import torch

-from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    AttentionConfig,
+    CacheConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
@@ -84,12 +89,15 @@ def test_backend_selection(
     """Test attention backend selection with valid device-backend pairs."""
     # Create AttentionConfig with the specified backend
     attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
-    vllm_config = VllmConfig(attention_config=attention_config)
+    cache_config = CacheConfig(block_size=block_size)
+    vllm_config = VllmConfig(
+        attention_config=attention_config, cache_config=cache_config
+    )

     with set_current_vllm_config(vllm_config):
         if device == "cpu":
             with patch("vllm.platforms.current_platform", CpuPlatform()):
-                backend = get_attn_backend(16, torch.float16, None, block_size)
+                backend = get_attn_backend(16, torch.float16, None)
                 assert backend.get_name() == "CPU_ATTN"

         elif device == "hip":
@@ -104,20 +112,16 @@ def test_backend_selection(
                 if name == "TRITON_MLA" and block_size == 1:
                     # TRITON_MLA doesn't support block_size == 1
                     with pytest.raises(ValueError):
-                        get_attn_backend(
-                            576, torch.float16, None, block_size, use_mla=use_mla
-                        )
+                        get_attn_backend(576, torch.float16, None, use_mla=use_mla)
                 else:
                     # Valid backend-block_size combination
                     backend = get_attn_backend(
-                        576, torch.float16, None, block_size, use_mla=use_mla
+                        576, torch.float16, None, use_mla=use_mla
                     )
                     expected = name
                     assert backend.get_name() == expected
             else:
-                backend = get_attn_backend(
-                    32, torch.float16, None, block_size, use_mla=use_mla
-                )
+                backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
                 expected = "ROCM_ATTN"
                 assert backend.get_name() == expected

@@ -141,7 +145,7 @@ def test_backend_selection(
                 if capability[0] != 10:
                     pytest.skip("CUTLASS MLA is not supported on this platform")
                 backend = get_attn_backend(
-                    576, torch.float16, None, block_size, use_mla=use_mla
+                    576, torch.float16, None, use_mla=use_mla
                 )
                 expected = "CUTLASS_MLA"
                 assert backend.get_name() == expected
@@ -156,7 +160,7 @@ def test_backend_selection(
                         "FlashInfer MLA only supports block_size 32 or 64"
                     )
                 backend = get_attn_backend(
-                    576, torch.float16, None, block_size, use_mla=use_mla
+                    576, torch.float16, None, use_mla=use_mla
                 )
                 expected = "FLASHINFER_MLA"
                 assert backend.get_name() == expected
@@ -175,7 +179,6 @@ def test_backend_selection(
                     576,
                     torch.float16,
                     None,
-                    block_size,
                     use_mla=use_mla,
                 )
                 expected = name
@@ -190,27 +193,23 @@ def test_backend_selection(
                         "FlashAttention MLA not supported on this platform"
                     )
                 backend = get_attn_backend(
-                    576, torch.float16, None, block_size, use_mla=use_mla
+                    576, torch.float16, None, use_mla=use_mla
                 )
                 expected = "FLASH_ATTN_MLA"
                 assert backend.get_name() == expected
             else:
                 # TRITON_MLA or other fallback
                 backend = get_attn_backend(
-                    576, torch.float16, None, block_size, use_mla=use_mla
+                    576, torch.float16, None, use_mla=use_mla
                 )
                 expected = "TRITON_MLA"
                 assert backend.get_name() == expected
         elif name == "FLASHINFER":
-            backend = get_attn_backend(
-                64, torch.float16, None, block_size, use_mla=use_mla
-            )
+            backend = get_attn_backend(64, torch.float16, None, use_mla=use_mla)
             expected = "FLASHINFER"
             assert backend.get_name() == expected
         elif name == "FLASH_ATTN":
-            backend = get_attn_backend(
-                32, torch.float16, None, block_size, use_mla=use_mla
-            )
+            backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
             expected = "FLASH_ATTN"
             assert backend.get_name() == expected

@@ -224,12 +223,12 @@ def test_fp32_fallback(device: str):
     with set_current_vllm_config(vllm_config):
         if device == "cpu":
             with patch("vllm.platforms.current_platform", CpuPlatform()):
-                backend = get_attn_backend(16, torch.float32, None, 16)
+                backend = get_attn_backend(16, torch.float32, None)
                 assert backend.get_name() == "CPU_ATTN"

         elif device == "cuda":
             with patch("vllm.platforms.current_platform", CudaPlatform()):
-                backend = get_attn_backend(16, torch.float32, None, 16)
+                backend = get_attn_backend(16, torch.float32, None)
                 assert backend.get_name() == "FLEX_ATTENTION"
@@ -241,35 +240,40 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
        )

     attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
-    vllm_config = VllmConfig(attention_config=attention_config)
+    cache_config = CacheConfig(block_size=16)
+    vllm_config = VllmConfig(
+        attention_config=attention_config, cache_config=cache_config
+    )

     with set_current_vllm_config(vllm_config):
         # Unsupported CUDA arch
         monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
-        backend = get_attn_backend(16, torch.float16, None, 16)
+        backend = get_attn_backend(16, torch.float16, None)
         assert backend.get_name() != "FLASH_ATTN"

         # Reset the monkeypatch for subsequent tests
         monkeypatch.undo()

         # Unsupported data type
-        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
+        backend = get_attn_backend(16, torch.float8_e4m3fn, None)
         assert backend.get_name() != "FLASH_ATTN"

         # Unsupported kv cache data type
-        backend = get_attn_backend(16, torch.float16, "fp8", 16)
+        backend = get_attn_backend(16, torch.float16, "fp8")
         assert backend.get_name() != "FLASH_ATTN"

         # Unsupported block size
-        backend = get_attn_backend(16, torch.float16, None, 8)
+        vllm_config.cache_config.block_size = 8
+        backend = get_attn_backend(16, torch.float16, None)
         assert backend.get_name() != "FLASH_ATTN"

         # flash-attn is not installed
         import sys

+        vllm_config.cache_config.block_size = 16
         original_module = sys.modules.get("vllm_flash_attn")
         monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
-        backend = get_attn_backend(16, torch.float16, None, 16)
+        backend = get_attn_backend(16, torch.float16, None)
         assert backend.get_name() != "FLASH_ATTN"

         # Restore the original module if it existed
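
A side effect visible in test_flash_attn above: with block_size gone from the call, the "unsupported block size" case is exercised by mutating vllm_config.cache_config.block_size in place and restoring it afterwards. A condensed sketch of that pattern, reusing the config and imports from the test; clearing _cached_get_attn_backend between probes is a defensive addition not present in the diff (the test file clears it elsewhere, which suggests selection results are cached), and its import path is an assumption:

# Probe an unsupported block size by mutating the active config, then restore it.
from vllm.attention.selector import _cached_get_attn_backend  # import path assumed

vllm_config.cache_config.block_size = 8
_cached_get_attn_backend.cache_clear()  # drop any previously cached selection
assert get_attn_backend(16, torch.float16, None).get_name() != "FLASH_ATTN"
vllm_config.cache_config.block_size = 16  # restore for the checks that follow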
@@ -279,7 +283,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
             monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)

         # Unsupported head size
-        backend = get_attn_backend(17, torch.float16, None, 16)
+        backend = get_attn_backend(17, torch.float16, None)
         assert backend.get_name() != "FLASH_ATTN"
@@ -320,7 +324,7 @@ def test_auto_backend_selection_behavior():
         set_current_vllm_config(vllm_config_auto),
         patch("vllm.platforms.current_platform", CpuPlatform()),
     ):
-        backend_auto = get_attn_backend(16, torch.float16, None, 16)
+        backend_auto = get_attn_backend(16, torch.float16, None)

     _cached_get_attn_backend.cache_clear()
@@ -328,7 +332,7 @@ def test_auto_backend_selection_behavior():
         set_current_vllm_config(vllm_config_none),
         patch("vllm.platforms.current_platform", CpuPlatform()),
     ):
-        backend_none = get_attn_backend(16, torch.float16, None, 16)
+        backend_none = get_attn_backend(16, torch.float16, None)

     # Both should select the same backend
     assert backend_auto.get_name() == backend_none.get_name()
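
test_auto_backend_selection_behavior makes that caching explicit: the two lookups run under different VllmConfig objects, with _cached_get_attn_backend.cache_clear() between them so the second probe is not served from the first one's cache. A compressed sketch of the same flow under the new signature; vllm_config_auto and vllm_config_none are built earlier in the file, outside this diff, and the other names reuse the test file's imports:

results = []
for cfg in (vllm_config_auto, vllm_config_none):
    _cached_get_attn_backend.cache_clear()  # avoid reusing the previous selection
    with (
        set_current_vllm_config(cfg),
        patch("vllm.platforms.current_platform", CpuPlatform()),
    ):
        results.append(get_attn_backend(16, torch.float16, None).get_name())

# AUTO and an unset backend should resolve to the same implementation.
assert results[0] == results[1]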
@@ -358,7 +362,10 @@ def test_per_head_quant_scales_backend_selection(
         backend=AttentionBackendEnum[backend_name],
         flash_attn_version=flash_attn_version,
     )
-    vllm_config = VllmConfig(attention_config=attention_config)
+    cache_config = CacheConfig(block_size=64)
+    vllm_config = VllmConfig(
+        attention_config=attention_config, cache_config=cache_config
+    )

     with (
         set_current_vllm_config(vllm_config),
@@ -376,7 +383,6 @@ def test_per_head_quant_scales_backend_selection(
             head_size=128,
             dtype=torch.float16,
             kv_cache_dtype="fp8",
-            block_size=64,
             use_per_head_quant_scales=True,
         )
         assert backend.get_name() == backend_name
@@ -386,7 +392,6 @@ def test_per_head_quant_scales_backend_selection(
                 head_size=128,
                 dtype=torch.float16,
                 kv_cache_dtype="fp8",
-                block_size=64,
                 use_per_head_quant_scales=True,
             )
         assert backend_name in str(exc_info.value)
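
The per-head quant scale cases follow the same pattern in keyword form: block_size=64 is dropped from the call and supplied via CacheConfig(block_size=64) instead. A sketch of the failing path; the pytest.raises block that binds exc_info sits outside the visible hunk, so the concrete exception type below is an assumption and only the message check is taken from the diff:

import pytest

with pytest.raises(Exception) as exc_info:  # concrete exception type not shown in the diff
    get_attn_backend(
        head_size=128,
        dtype=torch.float16,
        kv_cache_dtype="fp8",
        use_per_head_quant_scales=True,
    )
assert backend_name in str(exc_info.value)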