[Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
|
||||
from vllm._custom_ops import (
|
||||
merge_attn_states as merge_attn_states_cuda,
|
||||
)
|
||||
from vllm._custom_ops import (
|
||||
scaled_fp8_quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states as merge_attn_states_triton,
|
||||
@@ -21,6 +26,7 @@ def merge_attn_states_torch(
|
||||
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
output_scale: torch.Tensor | None = None, # scalar, per-tensor FP8 scale
|
||||
):
|
||||
# Apply prefill_tokens_with_context mask if needed
|
||||
if prefill_tokens_with_context is None:
|
||||
@@ -49,9 +55,13 @@ def merge_attn_states_torch(
|
||||
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
output.copy_(
|
||||
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
|
||||
output = prefix_output * p_scale * mask + suffix_output * (
|
||||
s_scale * mask + (1 - mask)
|
||||
)
|
||||
if output_scale is not None:
|
||||
shape = output.shape
|
||||
output, _ = scaled_fp8_quant(output.float().view(-1, shape[-1]), output_scale)
|
||||
output = output.view(shape)
|
||||
return output, output_lse
|
||||
|
||||
|
||||
@@ -102,18 +112,20 @@ def generate_markdown_table():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fp8", [False, True])
|
||||
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
|
||||
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("output_dtype", DTYPES)
|
||||
@pytest.mark.parametrize("input_dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_attn_states(
|
||||
prefill_tokens_with_context: int | None,
|
||||
num_tokens: int,
|
||||
num_query_heads: int,
|
||||
head_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
input_dtype: torch.dtype,
|
||||
use_fp8: bool,
|
||||
):
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(
|
||||
@@ -125,9 +137,18 @@ def test_merge_attn_states(
|
||||
NUM_HEADS = num_query_heads
|
||||
HEAD_SIZE = head_size
|
||||
|
||||
# When use_fp8 is set, inputs stay as input_dtype (bf16/fp16/fp32)
|
||||
# and output becomes FP8.
|
||||
output_dtype = input_dtype
|
||||
output_scale = None
|
||||
if use_fp8:
|
||||
output_dtype = current_platform.fp8_dtype()
|
||||
output_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")
|
||||
|
||||
print(
|
||||
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, input_dtype: {input_dtype}, "
|
||||
f"output_dtype: {output_dtype}, use_fp8: {use_fp8}, "
|
||||
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
|
||||
f"Device: {current_platform.get_device_name()}"
|
||||
)
|
||||
@@ -156,10 +177,10 @@ def test_merge_attn_states(
|
||||
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
prefix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
suffix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
|
||||
warmup_times = 2
|
||||
@@ -183,6 +204,7 @@ def test_merge_attn_states(
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -196,6 +218,7 @@ def test_merge_attn_states(
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -220,6 +243,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -233,6 +257,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -254,6 +279,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -267,6 +293,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -288,7 +315,19 @@ def test_merge_attn_states(
|
||||
# Liger Kernel: Efficient Triton Kernels for LLM Training
|
||||
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
|
||||
# use rtol = 1e-2 for bfloat16.
|
||||
rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
|
||||
if use_fp8:
|
||||
# Compare in dequantized space (multiply back by scale) so that
|
||||
# absolute differences reflect real precision, not amplified FP8
|
||||
# quantization steps.
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
assert output_scale is not None
|
||||
scale = output_scale.item()
|
||||
elif output_dtype == torch.bfloat16:
|
||||
atol, rtol = 1e-3, 1e-2
|
||||
scale = 1.0
|
||||
else:
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
scale = 1.0
|
||||
|
||||
def diff(a: torch.Tensor, b: torch.Tensor):
|
||||
max_diff = torch.max(torch.abs(a.float() - b.float()))
|
||||
@@ -300,16 +339,26 @@ def test_merge_attn_states(
|
||||
output_ref = output_ref_triton
|
||||
output_lse_ref = output_lse_ref_triton
|
||||
torch.testing.assert_close(
|
||||
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
|
||||
output_cuda.float() * scale,
|
||||
output_ref.float() * scale,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
print("Output all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
|
||||
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
|
||||
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
|
||||
print(
|
||||
"Output all match, max abs diff (dequantized):"
|
||||
if use_fp8
|
||||
else "Output all match, max abs diff:"
|
||||
)
|
||||
_diff = diff(output_ref.float() * scale, output_torch.float() * scale)
|
||||
print(f"(Triton vs Torch) : {_diff}")
|
||||
_diff = diff(output_torch.float() * scale, output_cuda.float() * scale)
|
||||
print(f" (CUDA vs Torch) : {_diff}")
|
||||
_diff = diff(output_ref.float() * scale, output_cuda.float() * scale)
|
||||
print(f" (CUDA vs Triton): {_diff}")
|
||||
print("-" * 100)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
|
||||
output_lse_cuda.float(), output_lse_ref.float(), atol=atol, rtol=rtol
|
||||
)
|
||||
print("Output LSE all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
|
||||
|
||||
Reference in New Issue
Block a user