[v1] Add fp32 support to v1 engine through flex attn (#19319)
Signed-off-by: Isotr0py <2037008807@qq.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -183,6 +183,34 @@ def test_env(
|
||||
assert backend.get_name() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
def test_fp32_fallback(
|
||||
device: str,
|
||||
use_v1: bool,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
"""Test attention backend selection with fp32."""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, torch.float32,
|
||||
16, False)
|
||||
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
if use_v1 else "TORCH_SDPA")
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, torch.float32,
|
||||
16, False)
|
||||
assert (backend.get_name() == "FLEX_ATTENTION"
|
||||
if use_v1 else "XFORMERS")
|
||||
|
||||
|
||||
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FlashAttn validation."""
|
||||
# TODO: When testing for v1, pipe in `use_v1` as an argument to
|
||||
|
||||
Reference in New Issue
Block a user