[Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
This commit is contained in:
Carl Y
2026-04-02 18:47:04 -07:00
committed by GitHub
parent 1f5ec2889c
commit 3bc2734dd0
8 changed files with 516 additions and 70 deletions

View File

@@ -4,7 +4,12 @@
import pytest
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm._custom_ops import (
merge_attn_states as merge_attn_states_cuda,
)
from vllm._custom_ops import (
scaled_fp8_quant,
)
from vllm.platforms import current_platform
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
@@ -21,6 +26,7 @@ def merge_attn_states_torch(
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
prefill_tokens_with_context: int | None = None,
output_scale: torch.Tensor | None = None, # scalar, per-tensor FP8 scale
):
# Apply prefill_tokens_with_context mask if needed
if prefill_tokens_with_context is None:
@@ -49,9 +55,13 @@ def merge_attn_states_torch(
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
output.copy_(
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
output = prefix_output * p_scale * mask + suffix_output * (
s_scale * mask + (1 - mask)
)
if output_scale is not None:
shape = output.shape
output, _ = scaled_fp8_quant(output.float().view(-1, shape[-1]), output_scale)
output = output.view(shape)
return output, output_lse
@@ -102,18 +112,20 @@ def generate_markdown_table():
)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@pytest.mark.parametrize("input_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
prefill_tokens_with_context: int | None,
num_tokens: int,
num_query_heads: int,
head_size: int,
output_dtype: torch.dtype,
input_dtype: torch.dtype,
use_fp8: bool,
):
if not current_platform.is_cuda():
pytest.skip(
@@ -125,9 +137,18 @@ def test_merge_attn_states(
NUM_HEADS = num_query_heads
HEAD_SIZE = head_size
# When use_fp8 is set, inputs stay as input_dtype (bf16/fp16/fp32)
# and output becomes FP8.
output_dtype = input_dtype
output_scale = None
if use_fp8:
output_dtype = current_platform.fp8_dtype()
output_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")
print(
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"HEAD_SIZE:{HEAD_SIZE}, input_dtype: {input_dtype}, "
f"output_dtype: {output_dtype}, use_fp8: {use_fp8}, "
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
f"Device: {current_platform.get_device_name()}"
)
@@ -156,10 +177,10 @@ def test_merge_attn_states(
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
)
prefix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
)
suffix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
)
warmup_times = 2
@@ -183,6 +204,7 @@ def test_merge_attn_states(
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
output_scale,
)
torch.accelerator.synchronize()
@@ -196,6 +218,7 @@ def test_merge_attn_states(
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
output_scale,
)
end.record()
torch.accelerator.synchronize()
@@ -220,6 +243,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
output_scale,
)
torch.accelerator.synchronize()
@@ -233,6 +257,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
output_scale,
)
end.record()
torch.accelerator.synchronize()
@@ -254,6 +279,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
output_scale,
)
torch.accelerator.synchronize()
@@ -267,6 +293,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
output_scale,
)
end.record()
torch.accelerator.synchronize()
@@ -288,7 +315,19 @@ def test_merge_attn_states(
# Liger Kernel: Efficient Triton Kernels for LLM Training
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
# use rtol = 1e-2 for bfloat16.
rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
if use_fp8:
# Compare in dequantized space (multiply back by scale) so that
# absolute differences reflect real precision, not amplified FP8
# quantization steps.
atol, rtol = 1e-1, 1e-1
assert output_scale is not None
scale = output_scale.item()
elif output_dtype == torch.bfloat16:
atol, rtol = 1e-3, 1e-2
scale = 1.0
else:
atol, rtol = 1e-3, 1e-3
scale = 1.0
def diff(a: torch.Tensor, b: torch.Tensor):
max_diff = torch.max(torch.abs(a.float() - b.float()))
@@ -300,16 +339,26 @@ def test_merge_attn_states(
output_ref = output_ref_triton
output_lse_ref = output_lse_ref_triton
torch.testing.assert_close(
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
output_cuda.float() * scale,
output_ref.float() * scale,
atol=atol,
rtol=rtol,
)
print("Output all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
print(
"Output all match, max abs diff (dequantized):"
if use_fp8
else "Output all match, max abs diff:"
)
_diff = diff(output_ref.float() * scale, output_torch.float() * scale)
print(f"(Triton vs Torch) : {_diff}")
_diff = diff(output_torch.float() * scale, output_cuda.float() * scale)
print(f" (CUDA vs Torch) : {_diff}")
_diff = diff(output_ref.float() * scale, output_cuda.float() * scale)
print(f" (CUDA vs Triton): {_diff}")
print("-" * 100)
torch.testing.assert_close(
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
output_lse_cuda.float(), output_lse_ref.float(), atol=atol, rtol=rtol
)
print("Output LSE all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")