Add unit tests for fp8 output fusion of triton_attn (#34228)
Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
c50e105a88
commit
e24663c5a9
@@ -10,7 +10,7 @@ from vllm.utils.math_utils import next_power_of_2
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.ops.triton_unified_attention import unified_attention
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
NUM_HEADS = [(4, 4), (8, 2), (5, 1)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16]
|
||||
|
||||
@@ -20,6 +20,8 @@ QDTYPES = (
|
||||
if not current_platform.is_rocm()
|
||||
else [None, torch.float8_e4m3fnuz]
|
||||
)
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
# one value large enough to test overflow in index calculation.
|
||||
# one value small enough to test the schema op check
|
||||
NUM_BLOCKS = [32768, 2048]
|
||||
@@ -217,3 +219,127 @@ def test_triton_unified_attn(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_lens",
|
||||
[
|
||||
[(1, 1328), (5, 18), (129, 463)],
|
||||
[(1, 523), (1, 37), (1, 2011)],
|
||||
[(1, 1)] * 533,
|
||||
[(533, 533)] * 533,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256])
|
||||
@pytest.mark.parametrize("soft_cap", [None, 50.0])
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("seq_threshold_3D", SEQ_THRESHOLD_3D_VALUES)
|
||||
@torch.inference_mode()
|
||||
def test_triton_unified_attn_fp16_input_fp8_output(
|
||||
seq_lens: list[tuple[int, int]],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
sliding_window: int | None,
|
||||
block_size: int,
|
||||
soft_cap: float | None,
|
||||
num_blocks: int,
|
||||
seq_threshold_3D: int,
|
||||
) -> None:
|
||||
"""Test with fp16 input and fp8 output using output_scale."""
|
||||
torch.set_default_device("cuda")
|
||||
|
||||
set_random_seed(0)
|
||||
num_seqs = len(seq_lens)
|
||||
query_lens = [x[0] for x in seq_lens]
|
||||
kv_lens = [x[1] for x in seq_lens]
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_query_len = max(query_lens)
|
||||
max_kv_len = max(kv_lens)
|
||||
window_size = (sliding_window - 1, 0) if sliding_window is not None else (-1, -1)
|
||||
scale = head_size**-0.5
|
||||
|
||||
dtype = torch.float16
|
||||
query = torch.randn(sum(query_lens), num_query_heads, head_size, dtype=dtype)
|
||||
key_cache = torch.randn(
|
||||
num_blocks, block_size, num_kv_heads, head_size, dtype=dtype
|
||||
)
|
||||
value_cache = torch.randn_like(key_cache)
|
||||
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(
|
||||
0, num_blocks, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
output = torch.empty(sum(query_lens), num_query_heads, head_size, dtype=FP8_DTYPE)
|
||||
|
||||
output_scale = torch.tensor(0.5, dtype=torch.float32)
|
||||
|
||||
num_par_softmax_segments = 16
|
||||
head_size_padded = next_power_of_2(head_size)
|
||||
softmax_segm_output = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments, head_size_padded),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_max = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
softmax_segm_expsum = torch.empty(
|
||||
(seq_threshold_3D, num_query_heads, num_par_softmax_segments),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
unified_attention(
|
||||
q=query,
|
||||
k=key_cache,
|
||||
v=value_cache,
|
||||
out=output,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
seqused_k=kv_lens_tensor,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
softmax_scale=scale,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
softcap=soft_cap if soft_cap is not None else 0,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
output_scale=output_scale,
|
||||
seq_threshold_3D=seq_threshold_3D,
|
||||
num_par_softmax_segments=num_par_softmax_segments,
|
||||
softmax_segm_output=softmax_segm_output,
|
||||
softmax_segm_max=softmax_segm_max,
|
||||
softmax_segm_expsum=softmax_segm_expsum,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
query=query,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
soft_cap=soft_cap,
|
||||
)
|
||||
|
||||
output_fp16 = output.to(torch.float32) * output_scale.item()
|
||||
output_fp16 = output_fp16.to(torch.float16)
|
||||
|
||||
atol, rtol = 2e-1, 2e-1
|
||||
(
|
||||
torch.testing.assert_close(output_fp16, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output_fp16 - ref_output))}",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user