Reapply [Attention] Refactor check_and_update_config (#35122)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
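In short: get_attn_backend no longer takes block_size at the call site (it is dropped from the positional arguments everywhere below); the selector now reads the block size from CacheConfig on the active VllmConfig. A minimal sketch of the new call pattern, assuming get_attn_backend comes from vllm.attention.selector (its import is outside this diff):

    import torch

    # Assumed import path; the diff below only shows call sites.
    from vllm.attention.selector import get_attn_backend
    from vllm.config import CacheConfig, VllmConfig, set_current_vllm_config

    # block_size now travels through CacheConfig instead of the call site.
    vllm_config = VllmConfig(cache_config=CacheConfig(block_size=16))

    with set_current_vllm_config(vllm_config):
        # Previously: get_attn_backend(16, torch.float16, None, 16)
        backend = get_attn_backend(16, torch.float16, None)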
@@ -6,7 +6,12 @@ from unittest.mock import patch
 import pytest
 import torch

-from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
+from vllm.config import (
+    AttentionConfig,
+    CacheConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.platforms import current_platform
 from vllm.platforms.cpu import CpuPlatform
 from vllm.platforms.cuda import CudaPlatform
@@ -84,12 +89,15 @@ def test_backend_selection(
     """Test attention backend selection with valid device-backend pairs."""
     # Create AttentionConfig with the specified backend
     attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
-    vllm_config = VllmConfig(attention_config=attention_config)
+    cache_config = CacheConfig(block_size=block_size)
+    vllm_config = VllmConfig(
+        attention_config=attention_config, cache_config=cache_config
+    )

     with set_current_vllm_config(vllm_config):
         if device == "cpu":
             with patch("vllm.platforms.current_platform", CpuPlatform()):
-                backend = get_attn_backend(16, torch.float16, None, block_size)
+                backend = get_attn_backend(16, torch.float16, None)
                 assert backend.get_name() == "CPU_ATTN"

         elif device == "hip":
@@ -104,20 +112,16 @@ def test_backend_selection(
                 if name == "TRITON_MLA" and block_size == 1:
                     # TRITON_MLA doesn't support block_size == 1
                     with pytest.raises(ValueError):
-                        get_attn_backend(
-                            576, torch.float16, None, block_size, use_mla=use_mla
-                        )
+                        get_attn_backend(576, torch.float16, None, use_mla=use_mla)
                 else:
                     # Valid backend-block_size combination
                     backend = get_attn_backend(
-                        576, torch.float16, None, block_size, use_mla=use_mla
+                        576, torch.float16, None, use_mla=use_mla
                     )
                     expected = name
                     assert backend.get_name() == expected
             else:
-                backend = get_attn_backend(
-                    32, torch.float16, None, block_size, use_mla=use_mla
-                )
+                backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
                 expected = "ROCM_ATTN"
                 assert backend.get_name() == expected

@@ -141,7 +145,7 @@ def test_backend_selection(
                 if capability[0] != 10:
                     pytest.skip("CUTLASS MLA is not supported on this platform")
                 backend = get_attn_backend(
-                    576, torch.float16, None, block_size, use_mla=use_mla
+                    576, torch.float16, None, use_mla=use_mla
                 )
                 expected = "CUTLASS_MLA"
                 assert backend.get_name() == expected
@@ -156,7 +160,7 @@ def test_backend_selection(
                         "FlashInfer MLA only supports block_size 32 or 64"
                     )
                 backend = get_attn_backend(
-                    576, torch.float16, None, block_size, use_mla=use_mla
+                    576, torch.float16, None, use_mla=use_mla
                 )
                 expected = "FLASHINFER_MLA"
                 assert backend.get_name() == expected
@@ -175,7 +179,6 @@ def test_backend_selection(
                     576,
                     torch.float16,
                     None,
-                    block_size,
                     use_mla=use_mla,
                 )
                 expected = name
@@ -190,27 +193,23 @@ def test_backend_selection(
                         "FlashAttention MLA not supported on this platform"
                     )
                     backend = get_attn_backend(
-                        576, torch.float16, None, block_size, use_mla=use_mla
+                        576, torch.float16, None, use_mla=use_mla
                     )
                     expected = "FLASH_ATTN_MLA"
                     assert backend.get_name() == expected
                 else:
                     # TRITON_MLA or other fallback
                     backend = get_attn_backend(
-                        576, torch.float16, None, block_size, use_mla=use_mla
+                        576, torch.float16, None, use_mla=use_mla
                     )
                     expected = "TRITON_MLA"
                     assert backend.get_name() == expected
             elif name == "FLASHINFER":
-                backend = get_attn_backend(
-                    64, torch.float16, None, block_size, use_mla=use_mla
-                )
+                backend = get_attn_backend(64, torch.float16, None, use_mla=use_mla)
                 expected = "FLASHINFER"
                 assert backend.get_name() == expected
             elif name == "FLASH_ATTN":
-                backend = get_attn_backend(
-                    32, torch.float16, None, block_size, use_mla=use_mla
-                )
+                backend = get_attn_backend(32, torch.float16, None, use_mla=use_mla)
                 expected = "FLASH_ATTN"
                 assert backend.get_name() == expected

@@ -224,12 +223,12 @@ def test_fp32_fallback(device: str):
     with set_current_vllm_config(vllm_config):
         if device == "cpu":
             with patch("vllm.platforms.current_platform", CpuPlatform()):
-                backend = get_attn_backend(16, torch.float32, None, 16)
+                backend = get_attn_backend(16, torch.float32, None)
                 assert backend.get_name() == "CPU_ATTN"

         elif device == "cuda":
             with patch("vllm.platforms.current_platform", CudaPlatform()):
-                backend = get_attn_backend(16, torch.float32, None, 16)
+                backend = get_attn_backend(16, torch.float32, None)
                 assert backend.get_name() == "FLEX_ATTENTION"

@@ -241,35 +240,40 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
     )

     attention_config = AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
-    vllm_config = VllmConfig(attention_config=attention_config)
+    cache_config = CacheConfig(block_size=16)
+    vllm_config = VllmConfig(
+        attention_config=attention_config, cache_config=cache_config
+    )

     with set_current_vllm_config(vllm_config):
         # Unsupported CUDA arch
         monkeypatch.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
-        backend = get_attn_backend(16, torch.float16, None, 16)
+        backend = get_attn_backend(16, torch.float16, None)
        assert backend.get_name() != "FLASH_ATTN"

         # Reset the monkeypatch for subsequent tests
         monkeypatch.undo()

         # Unsupported data type
-        backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16)
+        backend = get_attn_backend(16, torch.float8_e4m3fn, None)
         assert backend.get_name() != "FLASH_ATTN"

         # Unsupported kv cache data type
-        backend = get_attn_backend(16, torch.float16, "fp8", 16)
+        backend = get_attn_backend(16, torch.float16, "fp8")
         assert backend.get_name() != "FLASH_ATTN"

         # Unsupported block size
-        backend = get_attn_backend(16, torch.float16, None, 8)
+        vllm_config.cache_config.block_size = 8
+        backend = get_attn_backend(16, torch.float16, None)
         assert backend.get_name() != "FLASH_ATTN"

         # flash-attn is not installed
         import sys

+        vllm_config.cache_config.block_size = 16
         original_module = sys.modules.get("vllm_flash_attn")
         monkeypatch.setitem(sys.modules, "vllm_flash_attn", None)
-        backend = get_attn_backend(16, torch.float16, None, 16)
+        backend = get_attn_backend(16, torch.float16, None)
         assert backend.get_name() != "FLASH_ATTN"

         # Restore the original module if it existed
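Note the pattern the hunk above relies on: with block_size gone from the signature, the unsupported-block-size case is exercised by mutating the active config between calls. A hedged sketch, reusing the names from the test above:

    # Point the live config at an unsupported block size, call, then restore.
    vllm_config.cache_config.block_size = 8
    backend = get_attn_backend(16, torch.float16, None)  # selector reads block_size=8
    vllm_config.cache_config.block_size = 16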
@@ -279,7 +283,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
             monkeypatch.delitem(sys.modules, "vllm_flash_attn", raising=False)

         # Unsupported head size
-        backend = get_attn_backend(17, torch.float16, None, 16)
+        backend = get_attn_backend(17, torch.float16, None)
         assert backend.get_name() != "FLASH_ATTN"

@@ -320,7 +324,7 @@ def test_auto_backend_selection_behavior():
         set_current_vllm_config(vllm_config_auto),
         patch("vllm.platforms.current_platform", CpuPlatform()),
     ):
-        backend_auto = get_attn_backend(16, torch.float16, None, 16)
+        backend_auto = get_attn_backend(16, torch.float16, None)

     _cached_get_attn_backend.cache_clear()

@@ -328,7 +332,7 @@ def test_auto_backend_selection_behavior():
         set_current_vllm_config(vllm_config_none),
         patch("vllm.platforms.current_platform", CpuPlatform()),
     ):
-        backend_none = get_attn_backend(16, torch.float16, None, 16)
+        backend_none = get_attn_backend(16, torch.float16, None)

     # Both should select the same backend
     assert backend_auto.get_name() == backend_none.get_name()
@@ -358,7 +362,10 @@ def test_per_head_quant_scales_backend_selection(
         backend=AttentionBackendEnum[backend_name],
         flash_attn_version=flash_attn_version,
     )
-    vllm_config = VllmConfig(attention_config=attention_config)
+    cache_config = CacheConfig(block_size=64)
+    vllm_config = VllmConfig(
+        attention_config=attention_config, cache_config=cache_config
+    )

     with (
         set_current_vllm_config(vllm_config),
@@ -376,7 +383,6 @@ def test_per_head_quant_scales_backend_selection(
             head_size=128,
             dtype=torch.float16,
             kv_cache_dtype="fp8",
-            block_size=64,
             use_per_head_quant_scales=True,
         )
         assert backend.get_name() == backend_name
@@ -386,7 +392,6 @@ def test_per_head_quant_scales_backend_selection(
                 head_size=128,
                 dtype=torch.float16,
                 kv_cache_dtype="fp8",
-                block_size=64,
                 use_per_head_quant_scales=True,
             )
         assert backend_name in str(exc_info.value)