[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) (#4888)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
commit 543aa48573 (parent f7a8fa39d8)
Author: afeldman-nm
Date:   2024-07-08 13:12:15 -04:00
Committed by: GitHub
14 changed files with 2351 additions and 95 deletions
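For context on the commit title: in an encoder/decoder model, cross-attention during prefill runs every decoder prompt token against the encoder output at once, while during decode only the single newest token does, so the two phases must dispatch to different kernels. Below is a minimal, self-contained sketch of that distinction in plain PyTorch; every name in it is illustrative, and none of it is vLLM's actual API.

import torch


def cross_attention(q: torch.Tensor, enc_k: torch.Tensor,
                    enc_v: torch.Tensor) -> torch.Tensor:
    # q:     [num_query_tokens, head_dim]   decoder-side queries
    # enc_k: [num_encoder_tokens, head_dim] encoder keys (fixed per request)
    # enc_v: [num_encoder_tokens, head_dim] encoder values (fixed per request)
    scale = q.shape[-1] ** -0.5
    attn = torch.softmax((q @ enc_k.T) * scale, dim=-1)
    return attn @ enc_v


def cross_attn_step(decoder_states: torch.Tensor, enc_k: torch.Tensor,
                    enc_v: torch.Tensor, is_prefill: bool) -> torch.Tensor:
    if is_prefill:
        # Prefill: every decoder prompt token queries the encoder output
        # in one shot, so a batched prefill-style kernel is appropriate.
        q = decoder_states
    else:
        # Decode: only the newest token queries the encoder output, so a
        # single-query decode-style kernel is appropriate instead.
        q = decoder_states[-1:]
    return cross_attention(q, enc_k, enc_v)


enc_k, enc_v = torch.randn(12, 64), torch.randn(12, 64)
prompt = torch.randn(5, 64)
print(cross_attn_step(prompt, enc_k, enc_v, is_prefill=True).shape)   # [5, 64]
print(cross_attn_step(prompt, enc_k, enc_v, is_prefill=False).shape)  # [1, 64]

Note that unlike decoder self-attention, no causal mask is involved: the keys and values come from the already-complete encoder output.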


@@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch):
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-        assert backend.name != "FLASH_ATTN"
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
     backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
     backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
     backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported sliding window
     backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-        assert backend.name != "FLASH_ATTN"
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
     backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
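The mechanical change in this hunk replaces the string literal "FLASH_ATTN" with the shared constant STR_FLASH_ATTN_VAL, so the expected backend name is defined once rather than repeated across asserts. A plausible shape for that shared definition follows; the module layout and the override helper are assumptions for illustration, not the repository's confirmed code.

# Hypothetical shared test-utils module. Only STR_FLASH_ATTN_VAL appears in
# the hunk above; everything else here is an assumed sketch.
STR_FLASH_ATTN_VAL = "FLASH_ATTN"


def override_backend_env_variable(monkeypatch, backend_name: str) -> None:
    # Steer backend selection toward `backend_name` via the environment
    # variable vLLM consults when choosing an attention backend.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)

With the name centralized, a future backend rename touches one definition instead of the seven scattered asserts changed above.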