[Flashinfer] Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel (#23647)
Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
This commit is contained in:
@@ -35,6 +35,7 @@ QUANT_DTYPES = [
|
||||
# (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
|
||||
(None, None, None),
|
||||
(None, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, None),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
|
||||
(FP8_DTYPE, FP8_DTYPE, FP4_DTYPE),
|
||||
]
|
||||
@@ -44,6 +45,7 @@ NUM_HEADS = [(64, 8), (40, 8)]
|
||||
HEAD_SIZE = [128]
|
||||
KV_LAYOUT = ["HND"] # currently only HND is supported
|
||||
BLOCK_SIZE = [16]
|
||||
WINDOW_LEFT = [-1, 127]
|
||||
SOFT_CAP = [None, 50.0]
|
||||
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
@@ -57,6 +59,7 @@ NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAP)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
@@ -69,6 +72,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
head_size: int,
|
||||
kv_layout: str,
|
||||
block_size: int,
|
||||
window_left: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
@@ -155,6 +159,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
@@ -188,6 +193,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
max_seq_len=max_seq_len,
|
||||
bmm1_scale=q_scale * k_scale * sm_scale,
|
||||
bmm2_scale=v_scale / o_scale,
|
||||
window_left=window_left,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
@@ -222,6 +228,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZE)
|
||||
@pytest.mark.parametrize("kv_layout", KV_LAYOUT)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZE)
|
||||
@pytest.mark.parametrize("window_left", WINDOW_LEFT)
|
||||
@pytest.mark.parametrize("soft_cap", [None])
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
@@ -234,6 +241,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
head_size: int,
|
||||
kv_layout: str,
|
||||
block_size: int,
|
||||
window_left: int,
|
||||
soft_cap: Optional[float],
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
@@ -334,6 +342,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
sm_scale=sm_scale,
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
window_left=window_left,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = torch.empty(ref_query.shape, dtype=dtype)
|
||||
@@ -371,6 +380,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
batch_size=batch_size,
|
||||
cum_seq_lens_q=q_indptr,
|
||||
cum_seq_lens_kv=kv_indptr,
|
||||
window_left=window_left,
|
||||
o_sf_scale=o_sf_scale,
|
||||
out=output_trtllm,
|
||||
)
|
||||
@@ -390,6 +400,8 @@ def test_flashinfer_trtllm_prefill_with_baseline(
|
||||
rtol, atol = 4e-1, 1e0
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == FP8_DTYPE:
|
||||
rtol, atol = 5e-2, 7e-2
|
||||
elif q_quant_dtype == FP8_DTYPE and o_quant_dtype == dtype:
|
||||
rtol, atol = 4e-2, 6e-2
|
||||
else:
|
||||
rtol, atol = 1e-2, 1e-2
|
||||
|
||||
|
||||
Reference in New Issue
Block a user