[Bugfix] Fix chunked prefill with model dtype float32 on Turing Devices (#9850)

Signed-off-by: Wallas Santos <wallashss@ibm.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
Wallas Henrique
2024-11-25 14:23:32 -03:00
committed by GitHub
parent d04b13a380
commit c27df94e1f
6 changed files with 122 additions and 13 deletions

View File

@@ -40,6 +40,13 @@ def test_contexted_kv_attention(
kv_cache_dtype: str,
device: str,
) -> None:
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89):
pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89')
current_platform.seed_everything(0)
torch.set_default_device(device)
@@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi(
kv_cache_dtype: str,
device: str,
) -> None:
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
89):
pytest.skip(
'Triton limitation: fp8e4nv data type is not supported on CUDA'
' arch < 89')
current_platform.seed_everything(0)
torch.set_default_device(device)
@@ -462,3 +476,52 @@ def test_contexted_kv_attention_alibi(
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
# These tests are optional to only run when explicitly invoked
#
# pytest -v -s --optional \
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
#
# These tests are useful for testing model dtype float32 on Turing devices.
# We skip them in CI to avoid increasing the test-suite run time.
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
@torch.inference_mode()
def test_contexted_kv_attention_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    sliding_window: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Exercise test_contexted_kv_attention with dtype pinned to float32.

    Marked ``optional`` so it only runs when explicitly requested
    (``pytest --optional``); it is a thin wrapper that forwards every
    parameter unchanged to the main prefix-prefill test.
    """
    # Delegate with explicit keywords so the parameter mapping is obvious.
    test_contexted_kv_attention(num_heads=num_heads,
                                num_queries_per_kv=num_queries_per_kv,
                                head_size=head_size,
                                sliding_window=sliding_window,
                                dtype=dtype,
                                kv_cache_dtype=kv_cache_dtype,
                                device=device)
@pytest.mark.optional
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention_alibi_f32(
    num_heads: int,
    num_queries_per_kv: int,
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    device: str,
) -> None:
    """Exercise test_contexted_kv_attention_alibi with dtype pinned to float32.

    Marked ``optional`` so it only runs when explicitly requested
    (``pytest --optional``); it is a thin wrapper that forwards every
    parameter unchanged to the ALiBi variant of the prefix-prefill test.
    """
    # Delegate with explicit keywords so the parameter mapping is obvious.
    test_contexted_kv_attention_alibi(num_heads=num_heads,
                                      num_queries_per_kv=num_queries_per_kv,
                                      head_size=head_size,
                                      dtype=dtype,
                                      kv_cache_dtype=kv_cache_dtype,
                                      device=device)