[Flashinfer] Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel (#23647)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
elvischenv
2025-09-09 11:53:07 +08:00
committed by GitHub
parent b6fbc15634
commit bba1042c6f
5 changed files with 22 additions and 11 deletions

View File

@@ -35,6 +35,7 @@ QUANT_DTYPES = [
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
(None, None, None),
(None, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, None),
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
]
@@ -44,6 +45,7 @@ NUM_HEADS = [(64, 8), (40, 8)]
HEAD_SIZE = [128]
KV_LAYOUT = ["HND"] # currently only HND is supported
BLOCK_SIZE = [16]
WINDOW_LEFT = [-1, 127]
SOFT_CAP = [None, 50.0]
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@@ -57,6 +59,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
@torch.inference_mode
def test_flashinfer_trtllm_decode_with_baseline(
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
head_size: int,
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)
output = torch.empty(ref_query.shape, dtype=dtype)
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
max_seq_len=max_seq_len,
bmm1_scale=q_scale * k_scale * sm_scale,
bmm2_scale=v_scale / o_scale,
window_left=window_left,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
@pytest.mark.parametrize("head_size", HEAD_SIZE)
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
@pytest.mark.parametrize("soft_cap", [None])
@torch.inference_mode
def test_flashinfer_trtllm_prefill_with_baseline(
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
head_size: int,
kv_layout: str,
block_size: int,
window_left: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
sm_scale=sm_scale,
q_data_type=dtype,
kv_data_type=dtype,
window_left=window_left,
logits_soft_cap=soft_cap)
output = torch.empty(ref_query.shape, dtype=dtype)
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
batch_size=batch_size,
cum_seq_lens_q=q_indptr,
cum_seq_lens_kv=kv_indptr,
window_left=window_left,
o_sf_scale=o_sf_scale,
out=output_trtllm,
)
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
rtol, atol = 4e-1, 1e0
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
rtol, atol = 5e-2, 7e-2
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
rtol, atol = 4e-2, 6e-2
else:
rtol, atol = 1e-2, 1e-2