[Kernel] Correctly invoke prefill & decode kernels for cross-attention (towards eventual encoder/decoder model support) (#4888)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
@@ -47,32 +47,32 @@ def test_flash_attn(monkeypatch):
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=[7, 5]):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-        assert backend.name != "FLASH_ATTN"
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported data type
     backend = which_attn_to_use(8, 16, 8, None, torch.float8_e4m3fn, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported kv cache data type
    backend = which_attn_to_use(8, 16, 8, None, torch.float16, "fp8", 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported block size
     backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 8)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported sliding window
     backend = which_attn_to_use(8, 16, 8, 1, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
     # flash-attn is not installed
     with patch.dict('sys.modules', {'vllm_flash_attn': None}):
         backend = which_attn_to_use(8, 16, 8, None, torch.float16, None, 16)
-        assert backend.name != "FLASH_ATTN"
+        assert backend.name != STR_FLASH_ATTN_VAL
 
     # Unsupported head size
     backend = which_attn_to_use(8, 17, 8, None, torch.float16, None, 16)
-    assert backend.name != "FLASH_ATTN"
+    assert backend.name != STR_FLASH_ATTN_VAL
 
 
 def test_invalid_env(monkeypatch):
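
For context, the hunk swaps every hard-coded "FLASH_ATTN" literal for the shared STR_FLASH_ATTN_VAL constant, and each which_attn_to_use call flips exactly one argument to an unsupported value (head size 17, sliding window 1, dtype float8_e4m3fn, kv-cache dtype "fp8", block size 8) to force a fallback. Below is a minimal, self-contained sketch of that selection-and-fallback pattern; pick_backend, STR_XFORMERS_VAL, the parameter names, and the constant's value are illustrative assumptions inferred from the test, not vLLM's actual implementation.

    from typing import Optional

    import torch

    # Assumed value, inferred from the literal the diff replaces.
    STR_FLASH_ATTN_VAL = "FLASH_ATTN"
    # Hypothetical name for the fallback backend used in this sketch.
    STR_XFORMERS_VAL = "XFORMERS"


    def pick_backend(num_heads: int,
                     head_size: int,
                     num_kv_heads: int,
                     sliding_window: Optional[int],
                     dtype: torch.dtype,
                     kv_cache_dtype: Optional[str],
                     block_size: int) -> str:
        """Toy stand-in for which_attn_to_use: fall back to another backend
        whenever a requirement exercised by the tests above is violated."""
        if (head_size % 8 != 0                                # e.g. head size 17
                or sliding_window is not None                 # sliding window set
                or dtype not in (torch.float16, torch.bfloat16)  # e.g. fp8 dtype
                or kv_cache_dtype is not None                 # e.g. "fp8" kv cache
                or block_size < 16):                          # e.g. block size 8
            return STR_XFORMERS_VAL
        return STR_FLASH_ATTN_VAL


    # Mirrors the asserts above: every unsupported case must avoid flash-attn.
    assert pick_backend(8, 17, 8, None, torch.float16, None, 16) != STR_FLASH_ATTN_VAL
    assert pick_backend(8, 16, 8, 1, torch.float16, None, 16) != STR_FLASH_ATTN_VAL
    assert pick_backend(8, 16, 8, None, torch.float16, "fp8", 16) != STR_FLASH_ATTN_VAL
    assert pick_backend(8, 16, 8, None, torch.float16, None, 8) != STR_FLASH_ATTN_VAL
    assert pick_backend(8, 16, 8, None, torch.float16, None, 16) == STR_FLASH_ATTN_VAL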