[Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)

Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
This commit is contained in:
Carl Y
2026-04-02 18:47:04 -07:00
committed by GitHub
parent 1f5ec2889c
commit 3bc2734dd0
8 changed files with 516 additions and 70 deletions
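
The kernel now accepts an optional per-tensor scale and writes FP8-quantized output directly, removing the separate quantization launch that previously ran after the merge. A minimal sketch of the new call pattern, mirroring the benchmark added below (shapes and the 0.1 scale are illustrative values, not requirements):

import torch
from vllm._custom_ops import merge_attn_states
from vllm.platforms import current_platform

n, h, d = 256, 32, 128
prefix_out = torch.randn(n, h, d, dtype=torch.bfloat16, device="cuda")
suffix_out = torch.randn(n, h, d, dtype=torch.bfloat16, device="cuda")
prefix_lse = torch.randn(h, n, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(h, n, dtype=torch.float32, device="cuda")

# Fused path: the output buffer is already FP8 and output_scale triggers
# in-kernel quantization (merge + quantize in a single kernel launch).
scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
out_fp8 = torch.empty(n, h, d, dtype=current_platform.fp8_dtype(), device="cuda")
merge_attn_states(out_fp8, prefix_out, prefix_lse, suffix_out, suffix_lse,
                  output_scale=scale)

# Previous unfused path: merge into a bf16 buffer, then quantize in a second step.
out_bf16 = torch.empty(n, h, d, dtype=torch.bfloat16, device="cuda")
merge_attn_states(out_bf16, prefix_out, prefix_lse, suffix_out, suffix_lse)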

View File

@@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark: Fused FP8 output quantization in merge_attn_states
Compares fused vs unfused approaches for producing FP8-quantized merged
attention output:
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant
Usage:
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
"""
import argparse
import itertools
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.triton_utils import triton
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
)
# ---------------------------------------------------------------------------
# Configuration defaults
# ---------------------------------------------------------------------------
NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]
# (label, num_heads, head_size) — num_heads is for TP=1
HEAD_CONFIGS = [
("DeepSeek-V3 MLA", 128, 128),
("Llama-70B", 64, 128),
("Llama-8B", 32, 128),
]
TP_SIZES = [1, 2, 4, 8]
INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
QUANTILES = [0.5, 0.2, 0.8]
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def short_dtype(dtype: torch.dtype) -> str:
return str(dtype).removeprefix("torch.")
def make_inputs(
num_tokens: int,
num_heads: int,
head_size: int,
dtype: torch.dtype,
):
"""Create random prefix/suffix outputs and LSEs."""
prefix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
suffix_output = torch.randn(
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
)
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
# Sprinkle some inf values to exercise edge-case paths
mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
prefix_lse[mask] = float("inf")
mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
suffix_lse[mask2] = float("inf")
return prefix_output, suffix_output, prefix_lse, suffix_lse
def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
"""Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
applying TP division to num_heads and skipping invalid combos."""
configs = []
for (_, nh, hs), nt, dtype, tp in itertools.product(
head_configs, num_tokens_list, input_dtypes, tp_sizes
):
nh_tp = nh // tp
if nh_tp >= 1:
configs.append((nt, nh_tp, hs, short_dtype(dtype)))
return configs
def parse_args():
parser = argparse.ArgumentParser(
description="Benchmark merge_attn_states fused FP8 quantization"
)
parser.add_argument(
"--num-tokens",
type=int,
nargs="+",
default=None,
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
)
parser.add_argument(
"--tp",
type=int,
nargs="+",
default=None,
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
)
parser.add_argument(
"--dtype",
type=str,
nargs="+",
default=None,
help="Input dtypes (e.g. bfloat16 float16 float32). "
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
)
return parser.parse_args()
# ---------------------------------------------------------------------------
# Parse args and build configs before decorators
# ---------------------------------------------------------------------------
args = parse_args()
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
tp_sizes = args.tp if args.tp else TP_SIZES
if args.dtype:
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
else:
input_dtypes = INPUT_DTYPES
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
torch._dynamo.config.recompile_limit = 8888
# ---------------------------------------------------------------------------
# Benchmark function
# ---------------------------------------------------------------------------
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
x_vals=configs,
line_arg="provider",
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
ylabel="us",
plot_name="merge_attn_states FP8 (fused vs unfused)",
args={},
)
)
@default_vllm_config()
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
input_dtype = getattr(torch, dtype_str)
fp8_dtype = current_platform.fp8_dtype()
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
num_tokens, num_heads, head_size, input_dtype
)
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
if provider == "fused_cuda":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_cuda(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "fused_triton":
output = torch.empty(
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
)
fn = lambda: merge_attn_states_triton(
output,
prefix_out,
prefix_lse,
suffix_out,
suffix_lse,
output_scale=output_scale,
)
elif provider == "unfused_cuda":
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_cuda(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
else: # unfused_triton
merge_buf = torch.empty(
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
)
quant_fp8 = QuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
column_major_scales=False,
)
quant_input = merge_buf.view(-1, head_size)
compiled_quant = torch.compile(
quant_fp8.forward_native, fullgraph=True, dynamic=False
)
def unfused_fn():
merge_attn_states_triton(
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
)
compiled_quant(quant_input, output_scale)
fn = unfused_fn
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
device_name = current_platform.get_device_name()
print(f"Device: {device_name}")
print(f"Token counts: {num_tokens_list}")
print(f"TP sizes: {tp_sizes}")
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
benchmark.run(print_data=True)
if __name__ == "__main__":
with torch.inference_mode():
main()
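
For context on what the unfused baselines measure: QuantFP8 in static per-tensor mode applies the same quantization the fused kernels now perform inline, namely divide by the scale, saturate to the FP8 range, and cast. A rough standalone equivalent (a sketch assuming the e4m3 dtype used on CUDA, not the QuantFP8 implementation itself):

import torch

def static_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor,
                         fp8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # q = clamp(x / scale, fp8_min, fp8_max), then cast to the FP8 dtype;
    # dequantization is simply q.float() * scale.
    info = torch.finfo(fp8_dtype)
    return (x.float() / scale).clamp(info.min, info.max).to(fp8_dtype)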

View File

@@ -7,19 +7,29 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/w8a8/fp8/common.cuh"
#include "../dispatch_utils.h"
namespace vllm {
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, const uint NUM_THREADS>
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
bool USE_FP8_OUTPUT>
__global__ void merge_attn_states_kernel(
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
output_t* output, float* output_lse, const scalar_t* prefix_output,
const float* prefix_lse, const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens, const uint num_heads,
const uint head_size, const uint prefix_head_stride,
const uint output_head_stride, const uint prefix_num_tokens) {
using pack_128b_t = uint4;
const uint output_head_stride, const uint prefix_num_tokens,
const float* output_scale) {
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
// Outputs store pack_size elements of output_t, which is smaller for FP8.
using input_pack_t = uint4;
using output_pack_t =
std::conditional_t<USE_FP8_OUTPUT,
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
uint4>;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;
@@ -42,15 +52,36 @@ __global__ void merge_attn_states_kernel(
head_idx * output_head_stride;
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
scalar_t* output_head_ptr = output + dst_head_offset;
output_t* output_head_ptr = output + dst_head_offset;
// Pre-invert scale: multiplication is faster than division
float fp8_scale_inv = 1.0f;
if constexpr (USE_FP8_OUTPUT) {
fp8_scale_inv = 1.0f / *output_scale;
}
// If token_idx >= prefix_num_tokens, just copy from suffix
if (token_idx >= prefix_num_tokens) {
if (pack_offset < head_size) {
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
s_out_pack;
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
}
}
if (output_lse != nullptr && pack_idx == 0) {
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
@@ -70,20 +101,34 @@ __global__ void merge_attn_states_kernel(
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
continuing the pipeline then yields NaN. Root cause: with chunked prefill
a batch may be split into two chunks; if a request in that batch has no
prefix hit, every LSE entry for that requests position is -inf, and at
prefix hit, every LSE entry for that request's position is -inf, and at
this point we merge cross-attention first. For now we simply emit
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
this problem.
*/
if (std::isinf(max_lse)) {
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
prefix_head_ptr)[pack_offset / pack_size];
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
p_out_pack;
if constexpr (USE_FP8_OUTPUT) {
// Convert prefix values to FP8 (since -inf means no data,
// prefix_output is expected to be zeros)
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
const float val =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
o_out_pack[i] =
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
}
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -101,30 +146,43 @@ __global__ void merge_attn_states_kernel(
const float s_scale = s_se / out_se;
if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;
// Compute merged values in float32
float o_out_f[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f =
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
}
// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
o_out_pack;
// Convert and store
if constexpr (USE_FP8_OUTPUT) {
output_t o_out_pack[pack_size];
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
o_out_f[i], fp8_scale_inv);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] =
*reinterpret_cast<output_pack_t*>(o_out_pack);
} else {
output_pack_t o_out_pack;
#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f[i]);
}
reinterpret_cast<output_pack_t*>(
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
}
}
// We only need to write to output_lse once per head.
if (output_lse != nullptr && pack_idx == 0) {
@@ -151,24 +209,26 @@ __global__ void merge_attn_states_kernel(
} \
}
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT) \
{ \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
USE_FP8_OUTPUT> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
num_heads, head_size, prefix_head_stride, output_head_stride, \
prefix_num_tokens); \
prefix_num_tokens, output_scale_ptr); \
}
/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
* states.
@@ -180,19 +240,23 @@ __global__ void merge_attn_states_kernel(
* is computed by merging prefix_output and suffix_output. For remaining tokens
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
* from suffix_output.
* @param output_scale Optional scalar tensor for FP8 static quantization.
* When provided, output must be FP8 dtype.
*/
template <typename scalar_t>
void merge_attn_states_launcher(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context) {
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint prefix_head_stride = prefix_output.stride(1);
const uint output_head_stride = output.stride(1);
// Thread mapping is based on the input dtype's pack_size
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
@@ -208,6 +272,10 @@ void merge_attn_states_launcher(
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
float* output_scale_ptr = nullptr;
if (output_scale.has_value()) {
output_scale_ptr = output_scale.value().data_ptr<float>();
}
// Each thread processes one pack of elements. For float the pack_size is
// 4; for half/bf16, the pack_size is 8.
const uint threads_per_head = head_size / pack_size;
@@ -219,20 +287,44 @@ void merge_attn_states_launcher(
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
auto stream = at::cuda::getCurrentCUDAStream();
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
if (output_scale.has_value()) {
// FP8 output path - dispatch on output FP8 type
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
});
} else {
// Original BF16/FP16/FP32 output path
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
}
}
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>( \
output, output_lse, prefix_output, prefix_lse, suffix_output, \
suffix_lse, prefill_tokens_with_context); \
suffix_lse, prefill_tokens_with_context, output_scale); \
}
void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context = std::nullopt) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse,
std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale) {
if (output_scale.has_value()) {
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
"output must be FP8 when output_scale is provided, got: ",
output.scalar_type());
} else {
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
"output dtype (", output.scalar_type(),
") must match prefix_output dtype (",
prefix_output.scalar_type(), ") when output_scale is not set");
}
// Always dispatch on prefix_output (input) dtype
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
CALL_MERGE_ATTN_STATES_LAUNCHER);
}
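
A note on the vectorization change above: threads still load 128-bit packs of the input dtype, but with FP8 output each stored pack shrinks to pack_size bytes, so the store type depends on the input element size (uint for float32 inputs, uint2 for half/bf16). The scale is also inverted once per thread so the per-element conversion is a multiply rather than a divide. A quick sketch of the width arithmetic (plain Python, illustrative only):

# Input element size -> elements per 128-bit load -> FP8 store width
for in_bytes, name in ((4, "float32"), (2, "half/bfloat16")):
    pack_size = 16 // in_bytes          # 4 for float32, 8 for half/bf16
    fp8_store_bytes = pack_size * 1     # 4 bytes (uint) or 8 bytes (uint2)
    print(f"{name}: pack_size={pack_size}, fp8 store={fp8_store_bytes} bytes")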

View File

@@ -57,7 +57,8 @@ void merge_attn_states(
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
const std::optional<int64_t> prefill_tokens_with_context);
const std::optional<int64_t> prefill_tokens_with_context,
const std::optional<torch::Tensor>& output_scale = std::nullopt);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]

View File

@@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse,"
" int!? prefill_tokens_with_context) -> ()");
" int!? prefill_tokens_with_context,"
" Tensor? output_scale=None) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(

View File

@@ -4,7 +4,12 @@
import pytest
import torch
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
from vllm._custom_ops import (
merge_attn_states as merge_attn_states_cuda,
)
from vllm._custom_ops import (
scaled_fp8_quant,
)
from vllm.platforms import current_platform
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states as merge_attn_states_triton,
@@ -21,6 +26,7 @@ def merge_attn_states_torch(
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
prefill_tokens_with_context: int | None = None,
output_scale: torch.Tensor | None = None, # scalar, per-tensor FP8 scale
):
# Apply prefill_tokens_with_context mask if needed
if prefill_tokens_with_context is None:
@@ -49,9 +55,13 @@ def merge_attn_states_torch(
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
output.copy_(
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
output = prefix_output * p_scale * mask + suffix_output * (
s_scale * mask + (1 - mask)
)
if output_scale is not None:
shape = output.shape
output, _ = scaled_fp8_quant(output.float().view(-1, shape[-1]), output_scale)
output = output.view(shape)
return output, output_lse
@@ -102,18 +112,20 @@ def generate_markdown_table():
)
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("output_dtype", DTYPES)
@pytest.mark.parametrize("input_dtype", DTYPES)
@torch.inference_mode()
def test_merge_attn_states(
prefill_tokens_with_context: int | None,
num_tokens: int,
num_query_heads: int,
head_size: int,
output_dtype: torch.dtype,
input_dtype: torch.dtype,
use_fp8: bool,
):
if not current_platform.is_cuda():
pytest.skip(
@@ -125,9 +137,18 @@ def test_merge_attn_states(
NUM_HEADS = num_query_heads
HEAD_SIZE = head_size
# When use_fp8 is set, inputs stay as input_dtype (bf16/fp16/fp32)
# and output becomes FP8.
output_dtype = input_dtype
output_scale = None
if use_fp8:
output_dtype = current_platform.fp8_dtype()
output_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")
print(
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
f"HEAD_SIZE:{HEAD_SIZE}, input_dtype: {input_dtype}, "
f"output_dtype: {output_dtype}, use_fp8: {use_fp8}, "
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
f"Device: {current_platform.get_device_name()}"
)
@@ -156,10 +177,10 @@ def test_merge_attn_states(
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
)
prefix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
)
suffix_output = torch.randn(
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
)
warmup_times = 2
@@ -183,6 +204,7 @@ def test_merge_attn_states(
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
output_scale,
)
torch.accelerator.synchronize()
@@ -196,6 +218,7 @@ def test_merge_attn_states(
suffix_lse_torch,
output_lse_torch,
prefill_tokens_with_context,
output_scale,
)
end.record()
torch.accelerator.synchronize()
@@ -220,6 +243,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
output_scale,
)
torch.accelerator.synchronize()
@@ -233,6 +257,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_ref_triton,
prefill_tokens_with_context,
output_scale,
)
end.record()
torch.accelerator.synchronize()
@@ -254,6 +279,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
output_scale,
)
torch.accelerator.synchronize()
@@ -267,6 +293,7 @@ def test_merge_attn_states(
suffix_lse,
output_lse_cuda,
prefill_tokens_with_context,
output_scale,
)
end.record()
torch.accelerator.synchronize()
@@ -288,7 +315,19 @@ def test_merge_attn_states(
# Liger Kernel: Efficient Triton Kernels for LLM Training
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
# use rtol = 1e-2 for bfloat16.
rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
if use_fp8:
# Compare in dequantized space (multiply back by scale) so that
# absolute differences reflect real precision, not amplified FP8
# quantization steps.
atol, rtol = 1e-1, 1e-1
assert output_scale is not None
scale = output_scale.item()
elif output_dtype == torch.bfloat16:
atol, rtol = 1e-3, 1e-2
scale = 1.0
else:
atol, rtol = 1e-3, 1e-3
scale = 1.0
def diff(a: torch.Tensor, b: torch.Tensor):
max_diff = torch.max(torch.abs(a.float() - b.float()))
@@ -300,16 +339,26 @@ def test_merge_attn_states(
output_ref = output_ref_triton
output_lse_ref = output_lse_ref_triton
torch.testing.assert_close(
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
output_cuda.float() * scale,
output_ref.float() * scale,
atol=atol,
rtol=rtol,
)
print("Output all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
print(
"Output all match, max abs diff (dequantized):"
if use_fp8
else "Output all match, max abs diff:"
)
_diff = diff(output_ref.float() * scale, output_torch.float() * scale)
print(f"(Triton vs Torch) : {_diff}")
_diff = diff(output_torch.float() * scale, output_cuda.float() * scale)
print(f" (CUDA vs Torch) : {_diff}")
_diff = diff(output_ref.float() * scale, output_cuda.float() * scale)
print(f" (CUDA vs Triton): {_diff}")
print("-" * 100)
torch.testing.assert_close(
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
output_lse_cuda.float(), output_lse_ref.float(), atol=atol, rtol=rtol
)
print("Output LSE all match, max abs diff:")
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")

View File

@@ -265,6 +265,7 @@ def merge_attn_states(
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
prefill_tokens_with_context: int | None = None,
output_scale: torch.Tensor | None = None,
) -> None:
torch.ops._C.merge_attn_states(
output,
@@ -274,6 +275,7 @@ def merge_attn_states(
suffix_output,
suffix_lse,
prefill_tokens_with_context,
output_scale,
)

View File

@@ -14,6 +14,7 @@ def merge_attn_states(
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
prefill_tokens_with_context: int | None = None,
output_scale: torch.Tensor | None = None,
) -> None:
"""Merge partial attention outputs from prefix (KV cache) and suffix
(new tokens) into a single output tensor using the log-sum-exp (LSE)
@@ -41,27 +42,37 @@ def merge_attn_states(
>= this value are decode or context-free prefill tokens whose
output is taken directly from suffix_output. If None, all tokens
are treated as having context.
output_scale: Optional scalar tensor for FP8 static quantization.
When provided, output must be FP8 dtype.
"""
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
# does not support FP8 dtype, fallback to use Triton kernel.
def supported_dtypes(o: torch.Tensor) -> bool:
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
# does not support FP8 dtype for inputs, so we fall back to the Triton kernel.
# However, when output_scale is provided, the inputs are still BF16/FP16
# and the output is FP8; both CUDA and Triton support this case.
# FP8 output requires output_scale to be set.
if output.dtype not in (torch.float32, torch.half, torch.bfloat16):
assert output_scale is not None, (
f"output_scale is required when output is {output.dtype}"
)
def supported_dtypes(prefix: torch.Tensor) -> bool:
return prefix.dtype in [torch.float32, torch.half, torch.bfloat16]
# NOTE(DefTruth): Currently, the custom merge_attn_states CUDA
# kernel loads/stores 128b (16 bytes) per memory transaction within a
# thread. Namely, head_size (headdim) must be a multiple of
# pack_size (float32 -> 4, half/bfloat16 -> 8).
def supported_headdim(o: torch.Tensor) -> bool:
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
if o.dtype == torch.float32:
# pack_size based on input dtype (float32 -> 4, half/bfloat16 -> 8).
def supported_headdim(prefix: torch.Tensor) -> bool:
headdim = prefix.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
if prefix.dtype == torch.float32:
return headdim % 4 == 0
return headdim % 8 == 0
if (
current_platform.is_cuda()
and supported_dtypes(output)
and supported_headdim(output)
and supported_dtypes(prefix_output)
and supported_headdim(prefix_output)
):
from vllm._custom_ops import merge_attn_states
@@ -73,9 +84,12 @@ def merge_attn_states(
suffix_lse,
output_lse,
prefill_tokens_with_context,
output_scale,
)
else:
from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.v1.attention.ops.triton_merge_attn_states import (
merge_attn_states,
)
return merge_attn_states(
output,
@@ -85,4 +99,5 @@ def merge_attn_states(
suffix_lse,
output_lse,
prefill_tokens_with_context,
output_scale,
)
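
For readers new to the merge itself: the prefix and suffix partial outputs are combined with softmax-style log-sum-exp weights (section 2.2 of the paper cited in the kernels). A plain-torch reference of that weighting, following the shapes in the docstring above but ignoring the -inf edge cases the kernels handle (a sketch, not the shipped implementation):

import torch

def merge_reference(p_out, p_lse, s_out, s_lse):
    # p_out/s_out: [num_tokens, num_heads, head_size]; p_lse/s_lse: [num_heads, num_tokens]
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    p_scale = (p_se / out_se).transpose(0, 1).unsqueeze(-1)  # [num_tokens, num_heads, 1]
    s_scale = (s_se / out_se).transpose(0, 1).unsqueeze(-1)
    merged = p_out * p_scale + s_out * s_scale
    merged_lse = max_lse + torch.log(out_se)                 # merged log-sum-exp
    return merged, merged_lse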

View File

@@ -3,8 +3,11 @@
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
float8_info = torch.finfo(current_platform.fp8_dtype())
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
# can be used to combine partial attention results (in the split-KV case)
@@ -16,14 +19,15 @@ def merge_attn_states(
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
prefill_tokens_with_context: int | None = None,
output_scale: torch.Tensor | None = None,
) -> None:
num_tokens = output.shape[0]
num_query_heads = output.shape[1]
head_size = output.shape[2]
padded_head_size = triton.next_power_of_2(head_size)
# We do not assume the output stride on num_heads is the same as that of
# `suffix_output` and `prefix_output`, as they might be padded by the attention
# backend.
# `suffix_output` and `prefix_output`, as they might be padded by the
# attention backend.
prefix_head_stride = prefix_output.stride(1)
output_head_stride = output.stride(1)
@@ -41,10 +45,12 @@ def merge_attn_states(
suffix_lse,
prefix_head_stride,
output_head_stride,
output_scale,
head_size,
padded_head_size,
output_lse is not None,
prefill_tokens_with_context,
output_scale is not None,
)
@@ -58,10 +64,14 @@ def merge_attn_states_kernel(
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
prefix_head_stride,
output_head_stride,
output_scale, # scale tensor or None
HEAD_SIZE: tl.constexpr,
PADDED_HEAD_SIZE: tl.constexpr,
OUTPUT_LSE: tl.constexpr,
prefill_tokens_with_context: tl.constexpr,
USE_FP8: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
):
token_idx = tl.program_id(0)
num_tokens = tl.num_programs(0)
@@ -87,6 +97,12 @@ def merge_attn_states_kernel(
+ head_arange,
mask=head_mask,
)
if USE_FP8:
s_out = s_out * (1.0 / tl.load(output_scale))
s_out = tl.clamp(s_out, FP8_MIN, FP8_MAX)
s_out = s_out.to(output.dtype.element_ty)
tl.store(
output
+ token_idx * num_heads * output_head_stride
@@ -143,6 +159,12 @@ def merge_attn_states_kernel(
p_scale = p_se / out_se
s_scale = s_se / out_se
out = p_out * p_scale + s_out * s_scale
if USE_FP8:
out = out * (1.0 / tl.load(output_scale))
out = tl.clamp(out, FP8_MIN, FP8_MAX)
out = out.to(output.dtype.element_ty)
tl.store(
output
+ token_idx * num_heads * output_head_stride