[Bugfix] Fix chunked prefill with model dtype float32 on Turing Devices (#9850)
Signed-off-by: Wallas Santos <wallashss@ibm.com> Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -40,6 +40,13 @@ def test_contexted_kv_attention(
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
|
||||
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
||||
89):
|
||||
pytest.skip(
|
||||
'Triton limitation: fp8e4nv data type is not supported on CUDA'
|
||||
' arch < 89')
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi(
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
|
||||
if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability(
|
||||
89):
|
||||
pytest.skip(
|
||||
'Triton limitation: fp8e4nv data type is not supported on CUDA'
|
||||
' arch < 89')
|
||||
|
||||
current_platform.seed_everything(0)
|
||||
torch.set_default_device(device)
|
||||
|
||||
@@ -462,3 +476,52 @@ def test_contexted_kv_attention_alibi(
|
||||
print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms")
|
||||
atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6
|
||||
torch.testing.assert_close(output, output_ref, atol=atol, rtol=0)
|
||||
|
||||
|
||||
# These tests are optional to only run when explicitly invoked
|
||||
#
|
||||
# pytest -v -s --optional \
|
||||
# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32
|
||||
#
|
||||
# These tests are useful to test model dtype float32 on Turing devices.
|
||||
# We skip them to not increase the time when running tests on CI
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_f32(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
sliding_window: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size,
|
||||
sliding_window, dtype, kv_cache_dtype, device)
|
||||
|
||||
|
||||
@pytest.mark.optional
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", [torch.float32])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention_alibi_f32(
|
||||
num_heads: int,
|
||||
num_queries_per_kv: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size,
|
||||
dtype, kv_cache_dtype, device)
|
||||
|
||||
Reference in New Issue
Block a user