diff --git a/benchmarks/fused_kernels/merge_attn_states_benchmarks.py b/benchmarks/fused_kernels/merge_attn_states_benchmarks.py new file mode 100644 index 000000000..26b04299b --- /dev/null +++ b/benchmarks/fused_kernels/merge_attn_states_benchmarks.py @@ -0,0 +1,264 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Benchmark: Fused FP8 output quantization in merge_attn_states + +Compares fused vs unfused approaches for producing FP8-quantized merged +attention output: + 1. Fused CUDA -- single CUDA kernel (merge + FP8 quant) + 2. Fused Triton -- single Triton kernel (merge + FP8 quant) + 3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant + 4. Unfused Triton -- Triton merge + torch.compiled FP8 quant + +Usage: + python benchmarks/fused_kernels/merge_attn_states_benchmarks.py + python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8 + python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16 +""" + +import argparse +import itertools + +import torch + +from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda +from vllm.benchmarks.lib.utils import default_vllm_config +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 +from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape +from vllm.platforms import current_platform +from vllm.triton_utils import triton +from vllm.v1.attention.ops.triton_merge_attn_states import ( + merge_attn_states as merge_attn_states_triton, +) + +# --------------------------------------------------------------------------- +# Configuration defaults +# --------------------------------------------------------------------------- + +NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096] + +# (label, num_heads, head_size) — num_heads is for TP=1 +HEAD_CONFIGS = [ + ("DeepSeek-V3 MLA", 128, 128), + ("Llama-70B", 64, 128), + ("Llama-8B", 32, 128), +] + +TP_SIZES = [1, 2, 4, 8] + +INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16] + +QUANTILES = [0.5, 0.2, 0.8] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def short_dtype(dtype: torch.dtype) -> str: + return str(dtype).removeprefix("torch.") + + +def make_inputs( + num_tokens: int, + num_heads: int, + head_size: int, + dtype: torch.dtype, +): + """Create random prefix/suffix outputs and LSEs.""" + prefix_output = torch.randn( + (num_tokens, num_heads, head_size), dtype=dtype, device="cuda" + ) + suffix_output = torch.randn( + (num_tokens, num_heads, head_size), dtype=dtype, device="cuda" + ) + prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda") + suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda") + # Sprinkle some inf values to exercise edge-case paths + mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05 + prefix_lse[mask] = float("inf") + mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05 + suffix_lse[mask2] = float("inf") + return prefix_output, suffix_output, prefix_lse, suffix_lse + + +def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes): + """Build (num_tokens, num_heads, head_size, dtype_str) config tuples, + applying TP division to num_heads and skipping invalid combos.""" + configs = [] + for (_, nh, hs), nt, dtype, tp in itertools.product( + head_configs, num_tokens_list, input_dtypes, tp_sizes + 
): + nh_tp = nh // tp + if nh_tp >= 1: + configs.append((nt, nh_tp, hs, short_dtype(dtype))) + return configs + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark merge_attn_states fused FP8 quantization" + ) + parser.add_argument( + "--num-tokens", + type=int, + nargs="+", + default=None, + help=f"Override token counts (default: {NUM_TOKENS_LIST})", + ) + parser.add_argument( + "--tp", + type=int, + nargs="+", + default=None, + help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})", + ) + parser.add_argument( + "--dtype", + type=str, + nargs="+", + default=None, + help="Input dtypes (e.g. bfloat16 float16 float32). " + f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}", + ) + return parser.parse_args() + + +# --------------------------------------------------------------------------- +# Parse args and build configs before decorators +# --------------------------------------------------------------------------- + +args = parse_args() + +num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST +tp_sizes = args.tp if args.tp else TP_SIZES + +if args.dtype: + from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE + + input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype] +else: + input_dtypes = INPUT_DTYPES + +configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes) + +torch._dynamo.config.recompile_limit = 8888 + + +# --------------------------------------------------------------------------- +# Benchmark function +# --------------------------------------------------------------------------- + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "num_heads", "head_size", "dtype_str"], + x_vals=configs, + line_arg="provider", + line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"], + line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"], + styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")], + ylabel="us", + plot_name="merge_attn_states FP8 (fused vs unfused)", + args={}, + ) +) +@default_vllm_config() +def benchmark(num_tokens, num_heads, head_size, dtype_str, provider): + input_dtype = getattr(torch, dtype_str) + fp8_dtype = current_platform.fp8_dtype() + prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs( + num_tokens, num_heads, head_size, input_dtype + ) + output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda") + + if provider == "fused_cuda": + output = torch.empty( + (num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda" + ) + fn = lambda: merge_attn_states_cuda( + output, + prefix_out, + prefix_lse, + suffix_out, + suffix_lse, + output_scale=output_scale, + ) + elif provider == "fused_triton": + output = torch.empty( + (num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda" + ) + fn = lambda: merge_attn_states_triton( + output, + prefix_out, + prefix_lse, + suffix_out, + suffix_lse, + output_scale=output_scale, + ) + elif provider == "unfused_cuda": + merge_buf = torch.empty( + (num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda" + ) + quant_fp8 = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + column_major_scales=False, + ) + quant_input = merge_buf.view(-1, head_size) + compiled_quant = torch.compile( + quant_fp8.forward_native, fullgraph=True, dynamic=False + ) + + def unfused_fn(): + merge_attn_states_cuda( + merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse + ) + compiled_quant(quant_input, 
output_scale) + + fn = unfused_fn + else: # unfused_triton + merge_buf = torch.empty( + (num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda" + ) + quant_fp8 = QuantFP8( + static=True, + group_shape=GroupShape.PER_TENSOR, + column_major_scales=False, + ) + quant_input = merge_buf.view(-1, head_size) + compiled_quant = torch.compile( + quant_fp8.forward_native, fullgraph=True, dynamic=False + ) + + def unfused_fn(): + merge_attn_states_triton( + merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse + ) + compiled_quant(quant_input, output_scale) + + fn = unfused_fn + + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES) + return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + device_name = current_platform.get_device_name() + print(f"Device: {device_name}") + print(f"Token counts: {num_tokens_list}") + print(f"TP sizes: {tp_sizes}") + print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}") + print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}") + benchmark.run(print_data=True) + + +if __name__ == "__main__": + with torch.inference_mode(): + main() diff --git a/csrc/attention/merge_attn_states.cu b/csrc/attention/merge_attn_states.cu index f6c1bf617..75f066e80 100644 --- a/csrc/attention/merge_attn_states.cu +++ b/csrc/attention/merge_attn_states.cu @@ -7,19 +7,29 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#include "../quantization/w8a8/fp8/common.cuh" +#include "../dispatch_utils.h" namespace vllm { // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 // can be used to combine partial attention results (in the split-KV case) -template +template __global__ void merge_attn_states_kernel( - scalar_t* output, float* output_lse, const scalar_t* prefix_output, + output_t* output, float* output_lse, const scalar_t* prefix_output, const float* prefix_lse, const scalar_t* suffix_output, const float* suffix_lse, const uint num_tokens, const uint num_heads, const uint head_size, const uint prefix_head_stride, - const uint output_head_stride, const uint prefix_num_tokens) { - using pack_128b_t = uint4; + const uint output_head_stride, const uint prefix_num_tokens, + const float* output_scale) { + // Inputs always load 128-bit packs (pack_size elements of scalar_t). + // Outputs store pack_size elements of output_t, which is smaller for FP8. 
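+  // For example, with bf16 inputs pack_size = 16 / sizeof(bf16) = 8, so each
+  // input pack is 16 bytes while an FP8 output pack of 8 elements is 8 bytes.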
+ using input_pack_t = uint4; + using output_pack_t = + std::conditional_t, + uint4>; const uint pack_size = 16 / sizeof(scalar_t); const uint threads_per_head = head_size / pack_size; @@ -42,15 +52,36 @@ __global__ void merge_attn_states_kernel( head_idx * output_head_stride; const scalar_t* prefix_head_ptr = prefix_output + src_head_offset; const scalar_t* suffix_head_ptr = suffix_output + src_head_offset; - scalar_t* output_head_ptr = output + dst_head_offset; + output_t* output_head_ptr = output + dst_head_offset; + + // Pre-invert scale: multiplication is faster than division + float fp8_scale_inv = 1.0f; + if constexpr (USE_FP8_OUTPUT) { + fp8_scale_inv = 1.0f / *output_scale; + } // If token_idx >= prefix_num_tokens, just copy from suffix if (token_idx >= prefix_num_tokens) { if (pack_offset < head_size) { - pack_128b_t s_out_pack = reinterpret_cast( + input_pack_t s_out_pack = reinterpret_cast( suffix_head_ptr)[pack_offset / pack_size]; - reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = - s_out_pack; + + if constexpr (USE_FP8_OUTPUT) { + output_t o_out_pack[pack_size]; +#pragma unroll + for (uint i = 0; i < pack_size; ++i) { + const float val = + vllm::to_float(reinterpret_cast(&s_out_pack)[i]); + o_out_pack[i] = + vllm::scaled_fp8_conversion(val, fp8_scale_inv); + } + reinterpret_cast( + output_head_ptr)[pack_offset / pack_size] = + *reinterpret_cast(o_out_pack); + } else { + reinterpret_cast( + output_head_ptr)[pack_offset / pack_size] = s_out_pack; + } } if (output_lse != nullptr && pack_idx == 0) { float s_lse = suffix_lse[head_idx * num_tokens + token_idx]; @@ -70,20 +101,34 @@ __global__ void merge_attn_states_kernel( /* In certain edge cases, MLA can produce p_lse = s_lse = -inf; continuing the pipeline then yields NaN. Root cause: with chunked prefill a batch may be split into two chunks; if a request in that batch has no - prefix hit, every LSE entry for that request’s position is -inf, and at + prefix hit, every LSE entry for that request's position is -inf, and at this moment we merge cross-attention at first. For now we simply emit prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix this problem. */ if (std::isinf(max_lse)) { if (pack_offset < head_size) { - // Pack 128b load - pack_128b_t p_out_pack = reinterpret_cast( + input_pack_t p_out_pack = reinterpret_cast( prefix_head_ptr)[pack_offset / pack_size]; - // Pack 128b storage - reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = - p_out_pack; + if constexpr (USE_FP8_OUTPUT) { + // Convert prefix values to FP8 (since -inf means no data, + // prefix_output is expected to be zeros) + output_t o_out_pack[pack_size]; +#pragma unroll + for (uint i = 0; i < pack_size; ++i) { + const float val = + vllm::to_float(reinterpret_cast(&p_out_pack)[i]); + o_out_pack[i] = + vllm::scaled_fp8_conversion(val, fp8_scale_inv); + } + reinterpret_cast( + output_head_ptr)[pack_offset / pack_size] = + *reinterpret_cast(o_out_pack); + } else { + reinterpret_cast( + output_head_ptr)[pack_offset / pack_size] = p_out_pack; + } } // We only need to write to output_lse once per head. 
if (output_lse != nullptr && pack_idx == 0) { @@ -101,30 +146,43 @@ __global__ void merge_attn_states_kernel( const float s_scale = s_se / out_se; if (pack_offset < head_size) { - // Pack 128b load - pack_128b_t p_out_pack = reinterpret_cast( + input_pack_t p_out_pack = reinterpret_cast( prefix_head_ptr)[pack_offset / pack_size]; - pack_128b_t s_out_pack = reinterpret_cast( + input_pack_t s_out_pack = reinterpret_cast( suffix_head_ptr)[pack_offset / pack_size]; - pack_128b_t o_out_pack; + // Compute merged values in float32 + float o_out_f[pack_size]; #pragma unroll for (uint i = 0; i < pack_size; ++i) { - // Always use float for FMA to keep high precision. - // half(uint16_t), bfloat16, float -> float. const float p_out_f = vllm::to_float(reinterpret_cast(&p_out_pack)[i]); const float s_out_f = vllm::to_float(reinterpret_cast(&s_out_pack)[i]); - // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale) - const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale); - // float -> half(uint16_t), bfloat16, float. - vllm::from_float(reinterpret_cast(&o_out_pack)[i], o_out_f); + o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale); } - // Pack 128b storage - reinterpret_cast(output_head_ptr)[pack_offset / pack_size] = - o_out_pack; + // Convert and store + if constexpr (USE_FP8_OUTPUT) { + output_t o_out_pack[pack_size]; +#pragma unroll + for (uint i = 0; i < pack_size; ++i) { + o_out_pack[i] = vllm::scaled_fp8_conversion( + o_out_f[i], fp8_scale_inv); + } + reinterpret_cast( + output_head_ptr)[pack_offset / pack_size] = + *reinterpret_cast(o_out_pack); + } else { + output_pack_t o_out_pack; +#pragma unroll + for (uint i = 0; i < pack_size; ++i) { + vllm::from_float(reinterpret_cast(&o_out_pack)[i], + o_out_f[i]); + } + reinterpret_cast( + output_head_ptr)[pack_offset / pack_size] = o_out_pack; + } } // We only need to write to output_lse once per head. if (output_lse != nullptr && pack_idx == 0) { @@ -151,24 +209,26 @@ __global__ void merge_attn_states_kernel( } \ } -#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \ +#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \ + USE_FP8_OUTPUT) \ { \ - vllm::merge_attn_states_kernel \ + vllm::merge_attn_states_kernel \ <<>>( \ - reinterpret_cast(output.data_ptr()), output_lse_ptr, \ + reinterpret_cast(output.data_ptr()), output_lse_ptr, \ reinterpret_cast(prefix_output.data_ptr()), \ reinterpret_cast(prefix_lse.data_ptr()), \ reinterpret_cast(suffix_output.data_ptr()), \ reinterpret_cast(suffix_lse.data_ptr()), num_tokens, \ num_heads, head_size, prefix_head_stride, output_head_stride, \ - prefix_num_tokens); \ + prefix_num_tokens, output_scale_ptr); \ } /*@brief Merges the attention states from prefix and suffix * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d * * @param output [n,h,d] The output tensor to store the merged attention states. - * @param output_lse [h,d] Optional tensor to store the log-sum-exp values. + * @param output_lse [h,n] Optional tensor to store the log-sum-exp values. * @param prefix_output [n,h,d] The prefix attention states. * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention * states. @@ -180,19 +240,23 @@ __global__ void merge_attn_states_kernel( * is computed by merging prefix_output and suffix_output. For remaining tokens * (prefill_tokens_with_context <= token_idx < n), output is copied directly * from suffix_output. + * @param output_scale Optional scalar tensor for FP8 static quantization. + * When provided, output must be FP8 dtype. 
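+ * When output_scale is provided, each merged value is multiplied by
+ *   1 / output_scale and saturated to the representable FP8 range before
+ *   the store, i.e. roughly out_fp8 = clamp(out / output_scale, fp8_min, fp8_max).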
*/ template void merge_attn_states_launcher( torch::Tensor& output, std::optional output_lse, const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse, - const std::optional prefill_tokens_with_context) { + const std::optional prefill_tokens_with_context, + const std::optional& output_scale) { constexpr uint NUM_THREADS = 128; const uint num_tokens = output.size(0); const uint num_heads = output.size(1); const uint head_size = output.size(2); const uint prefix_head_stride = prefix_output.stride(1); const uint output_head_stride = output.stride(1); + // Thread mapping is based on input BF16 pack_size const uint pack_size = 16 / sizeof(scalar_t); TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size); @@ -208,6 +272,10 @@ void merge_attn_states_launcher( if (output_lse.has_value()) { output_lse_ptr = output_lse.value().data_ptr(); } + float* output_scale_ptr = nullptr; + if (output_scale.has_value()) { + output_scale_ptr = output_scale.value().data_ptr(); + } // Process one pack elements per thread. for float, the // pack_size is 4 for half/bf16, the pack_size is 8. const uint threads_per_head = head_size / pack_size; @@ -219,20 +287,44 @@ void merge_attn_states_launcher( const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device()); auto stream = at::cuda::getCurrentCUDAStream(); - LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS); + if (output_scale.has_value()) { + // FP8 output path - dispatch on output FP8 type + VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] { + LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true); + }); + } else { + // Original BF16/FP16/FP32 output path + LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false); + } } #define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \ { \ merge_attn_states_launcher( \ output, output_lse, prefix_output, prefix_lse, suffix_output, \ - suffix_lse, prefill_tokens_with_context); \ + suffix_lse, prefill_tokens_with_context, output_scale); \ } -void merge_attn_states( - torch::Tensor& output, std::optional output_lse, - const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, - const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse, - std::optional prefill_tokens_with_context = std::nullopt) { - DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER); +void merge_attn_states(torch::Tensor& output, + std::optional output_lse, + const torch::Tensor& prefix_output, + const torch::Tensor& prefix_lse, + const torch::Tensor& suffix_output, + const torch::Tensor& suffix_lse, + std::optional prefill_tokens_with_context, + const std::optional& output_scale) { + if (output_scale.has_value()) { + TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn || + output.scalar_type() == at::ScalarType::Float8_e4m3fnuz, + "output must be FP8 when output_scale is provided, got: ", + output.scalar_type()); + } else { + TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(), + "output dtype (", output.scalar_type(), + ") must match prefix_output dtype (", + prefix_output.scalar_type(), ") when output_scale is not set"); + } + // Always dispatch on prefix_output (input) dtype + DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(), + CALL_MERGE_ATTN_STATES_LAUNCHER); } diff --git a/csrc/ops.h b/csrc/ops.h index 20351a3e4..cc5842223 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -57,7 +57,8 @@ void merge_attn_states( torch::Tensor& output, 
std::optional output_lse, const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse, - const std::optional prefill_tokens_with_context); + const std::optional prefill_tokens_with_context, + const std::optional& output_scale = std::nullopt); #ifndef USE_ROCM void convert_vertical_slash_indexes( torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 0354df666..1beab5257 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor prefix_lse," " Tensor suffix_output," " Tensor suffix_lse," - " int!? prefill_tokens_with_context) -> ()"); + " int!? prefill_tokens_with_context," + " Tensor? output_scale=None) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); #ifndef USE_ROCM ops.def( diff --git a/tests/kernels/attention/test_merge_attn_states.py b/tests/kernels/attention/test_merge_attn_states.py index c1b71d93e..40af84887 100644 --- a/tests/kernels/attention/test_merge_attn_states.py +++ b/tests/kernels/attention/test_merge_attn_states.py @@ -4,7 +4,12 @@ import pytest import torch -from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda +from vllm._custom_ops import ( + merge_attn_states as merge_attn_states_cuda, +) +from vllm._custom_ops import ( + scaled_fp8_quant, +) from vllm.platforms import current_platform from vllm.v1.attention.ops.triton_merge_attn_states import ( merge_attn_states as merge_attn_states_triton, @@ -21,6 +26,7 @@ def merge_attn_states_torch( suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS] output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS] prefill_tokens_with_context: int | None = None, + output_scale: torch.Tensor | None = None, # scalar, per-tensor FP8 scale ): # Apply prefill_tokens_with_context mask if needed if prefill_tokens_with_context is None: @@ -49,9 +55,13 @@ def merge_attn_states_torch( s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS] p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1] - output.copy_( - prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask)) + output = prefix_output * p_scale * mask + suffix_output * ( + s_scale * mask + (1 - mask) ) + if output_scale is not None: + shape = output.shape + output, _ = scaled_fp8_quant(output.float().view(-1, shape[-1]), output_scale) + output = output.view(shape) return output, output_lse @@ -102,18 +112,20 @@ def generate_markdown_table(): ) +@pytest.mark.parametrize("use_fp8", [False, True]) @pytest.mark.parametrize("prefill_tokens_with_context", [None, 128]) @pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS) @pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) -@pytest.mark.parametrize("output_dtype", DTYPES) +@pytest.mark.parametrize("input_dtype", DTYPES) @torch.inference_mode() def test_merge_attn_states( prefill_tokens_with_context: int | None, num_tokens: int, num_query_heads: int, head_size: int, - output_dtype: torch.dtype, + input_dtype: torch.dtype, + use_fp8: bool, ): if not current_platform.is_cuda(): pytest.skip( @@ -125,9 +137,18 @@ def test_merge_attn_states( NUM_HEADS = num_query_heads HEAD_SIZE = head_size + # When use_fp8 is set, inputs stay as input_dtype (bf16/fp16/fp32) + # and output becomes 
FP8. + output_dtype = input_dtype + output_scale = None + if use_fp8: + output_dtype = current_platform.fp8_dtype() + output_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda") + print( f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, " - f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, " + f"HEAD_SIZE:{HEAD_SIZE}, input_dtype: {input_dtype}, " + f"output_dtype: {output_dtype}, use_fp8: {use_fp8}, " f"prefill_tokens_with_context: {prefill_tokens_with_context}, " f"Device: {current_platform.get_device_name()}" ) @@ -156,10 +177,10 @@ def test_merge_attn_states( (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda" ) prefix_output = torch.randn( - (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda" ) suffix_output = torch.randn( - (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda" + (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda" ) warmup_times = 2 @@ -183,6 +204,7 @@ def test_merge_attn_states( suffix_lse_torch, output_lse_torch, prefill_tokens_with_context, + output_scale, ) torch.accelerator.synchronize() @@ -196,6 +218,7 @@ def test_merge_attn_states( suffix_lse_torch, output_lse_torch, prefill_tokens_with_context, + output_scale, ) end.record() torch.accelerator.synchronize() @@ -220,6 +243,7 @@ def test_merge_attn_states( suffix_lse, output_lse_ref_triton, prefill_tokens_with_context, + output_scale, ) torch.accelerator.synchronize() @@ -233,6 +257,7 @@ def test_merge_attn_states( suffix_lse, output_lse_ref_triton, prefill_tokens_with_context, + output_scale, ) end.record() torch.accelerator.synchronize() @@ -254,6 +279,7 @@ def test_merge_attn_states( suffix_lse, output_lse_cuda, prefill_tokens_with_context, + output_scale, ) torch.accelerator.synchronize() @@ -267,6 +293,7 @@ def test_merge_attn_states( suffix_lse, output_lse_cuda, prefill_tokens_with_context, + output_scale, ) end.record() torch.accelerator.synchronize() @@ -288,7 +315,19 @@ def test_merge_attn_states( # Liger Kernel: Efficient Triton Kernels for LLM Training # https://arxiv.org/pdf/2410.10989, 3.3 Correctness # use rtol = 1e-2 for bfloat16. - rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3 + if use_fp8: + # Compare in dequantized space (multiply back by scale) so that + # absolute differences reflect real precision, not amplified FP8 + # quantization steps. 
+ atol, rtol = 1e-1, 1e-1 + assert output_scale is not None + scale = output_scale.item() + elif output_dtype == torch.bfloat16: + atol, rtol = 1e-3, 1e-2 + scale = 1.0 + else: + atol, rtol = 1e-3, 1e-3 + scale = 1.0 def diff(a: torch.Tensor, b: torch.Tensor): max_diff = torch.max(torch.abs(a.float() - b.float())) @@ -300,16 +339,26 @@ def test_merge_attn_states( output_ref = output_ref_triton output_lse_ref = output_lse_ref_triton torch.testing.assert_close( - output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol + output_cuda.float() * scale, + output_ref.float() * scale, + atol=atol, + rtol=rtol, ) - print("Output all match, max abs diff:") - print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}") - print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}") - print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}") + print( + "Output all match, max abs diff (dequantized):" + if use_fp8 + else "Output all match, max abs diff:" + ) + _diff = diff(output_ref.float() * scale, output_torch.float() * scale) + print(f"(Triton vs Torch) : {_diff}") + _diff = diff(output_torch.float() * scale, output_cuda.float() * scale) + print(f" (CUDA vs Torch) : {_diff}") + _diff = diff(output_ref.float() * scale, output_cuda.float() * scale) + print(f" (CUDA vs Triton): {_diff}") print("-" * 100) torch.testing.assert_close( - output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol + output_lse_cuda.float(), output_lse_ref.float(), atol=atol, rtol=rtol ) print("Output LSE all match, max abs diff:") print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9cc023138..65eca3208 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -265,6 +265,7 @@ def merge_attn_states( suffix_lse: torch.Tensor, output_lse: torch.Tensor | None = None, prefill_tokens_with_context: int | None = None, + output_scale: torch.Tensor | None = None, ) -> None: torch.ops._C.merge_attn_states( output, @@ -274,6 +275,7 @@ def merge_attn_states( suffix_output, suffix_lse, prefill_tokens_with_context, + output_scale, ) diff --git a/vllm/v1/attention/ops/merge_attn_states.py b/vllm/v1/attention/ops/merge_attn_states.py index 270f65d5e..cf4338fb1 100644 --- a/vllm/v1/attention/ops/merge_attn_states.py +++ b/vllm/v1/attention/ops/merge_attn_states.py @@ -14,6 +14,7 @@ def merge_attn_states( suffix_lse: torch.Tensor, output_lse: torch.Tensor | None = None, prefill_tokens_with_context: int | None = None, + output_scale: torch.Tensor | None = None, ) -> None: """Merge partial attention outputs from prefix (KV cache) and suffix (new tokens) into a single output tensor using the log-sum-exp (LSE) @@ -41,27 +42,37 @@ def merge_attn_states( >= this value are decode or context-free prefill tokens whose output is taken directly from suffix_output. If None, all tokens are treated as having context. + output_scale: Optional scalar tensor for FP8 static quantization. + When provided, output must be FP8 dtype. """ # NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel - # does not support FP8 dtype, fallback to use Triton kernel. - def supported_dtypes(o: torch.Tensor) -> bool: - return o.dtype in [torch.float32, torch.half, torch.bfloat16] + # does not support FP8 dtype for inputs, fallback to use Triton kernel. + # However, when output_scale is provided, the inputs are still BF16/FP16 + # and the output is FP8 — both CUDA and Triton support this. + # FP8 output requires output_scale to be set. 
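+    # Typical fused-FP8 call: output is pre-allocated as an FP8 tensor and
+    # output_scale holds the static per-tensor scale, e.g.
+    #   merge_attn_states(fp8_out, prefix_out, prefix_lse,
+    #                     suffix_out, suffix_lse, output_scale=scale)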
+ if output.dtype not in (torch.float32, torch.half, torch.bfloat16): + assert output_scale is not None, ( + f"output_scale is required when output is {output.dtype}" + ) + + def supported_dtypes(prefix: torch.Tensor) -> bool: + return prefix.dtype in [torch.float32, torch.half, torch.bfloat16] # NOTE(DefTruth): Currently, custom merge_attn_states CUDA # kernel load/store 128b(16 bytes) per memory issue within # thread. Namely, the headsize(headdim) must be multiple of - # pack_size (float32 -> 4, half/bfloat16 -> 8). - def supported_headdim(o: torch.Tensor) -> bool: - headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] - if o.dtype == torch.float32: + # pack_size based on input dtype (float32 -> 4, half/bfloat16 -> 8). + def supported_headdim(prefix: torch.Tensor) -> bool: + headdim = prefix.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] + if prefix.dtype == torch.float32: return headdim % 4 == 0 return headdim % 8 == 0 if ( current_platform.is_cuda() - and supported_dtypes(output) - and supported_headdim(output) + and supported_dtypes(prefix_output) + and supported_headdim(prefix_output) ): from vllm._custom_ops import merge_attn_states @@ -73,9 +84,12 @@ def merge_attn_states( suffix_lse, output_lse, prefill_tokens_with_context, + output_scale, ) else: - from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states + from vllm.v1.attention.ops.triton_merge_attn_states import ( + merge_attn_states, + ) return merge_attn_states( output, @@ -85,4 +99,5 @@ def merge_attn_states( suffix_lse, output_lse, prefill_tokens_with_context, + output_scale, ) diff --git a/vllm/v1/attention/ops/triton_merge_attn_states.py b/vllm/v1/attention/ops/triton_merge_attn_states.py index f5b4fbe0b..14a52ada9 100644 --- a/vllm/v1/attention/ops/triton_merge_attn_states.py +++ b/vllm/v1/attention/ops/triton_merge_attn_states.py @@ -3,8 +3,11 @@ import torch +from vllm.platforms import current_platform from vllm.triton_utils import tl, triton +float8_info = torch.finfo(current_platform.fp8_dtype()) + # Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005 # can be used to combine partial attention results (in the split-KV case) @@ -16,14 +19,15 @@ def merge_attn_states( suffix_lse: torch.Tensor, output_lse: torch.Tensor | None = None, prefill_tokens_with_context: int | None = None, + output_scale: torch.Tensor | None = None, ) -> None: num_tokens = output.shape[0] num_query_heads = output.shape[1] head_size = output.shape[2] padded_head_size = triton.next_power_of_2(head_size) # We assume the output stride on num_head is not always as same as the - # `suffix_output` and `prefix_output`, as them might be padded by the attention - # backend. + # `suffix_output` and `prefix_output`, as them might be padded by the + # attention backend. 
prefix_head_stride = prefix_output.stride(1) output_head_stride = output.stride(1) @@ -41,10 +45,12 @@ def merge_attn_states( suffix_lse, prefix_head_stride, output_head_stride, + output_scale, head_size, padded_head_size, output_lse is not None, prefill_tokens_with_context, + output_scale is not None, ) @@ -58,10 +64,14 @@ def merge_attn_states_kernel( suffix_lse, # [NUM_HEADS, NUM_TOKENS] prefix_head_stride, output_head_stride, + output_scale, # scale tensor or None HEAD_SIZE: tl.constexpr, PADDED_HEAD_SIZE: tl.constexpr, OUTPUT_LSE: tl.constexpr, prefill_tokens_with_context: tl.constexpr, + USE_FP8: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): token_idx = tl.program_id(0) num_tokens = tl.num_programs(0) @@ -87,6 +97,12 @@ def merge_attn_states_kernel( + head_arange, mask=head_mask, ) + + if USE_FP8: + s_out = s_out * (1.0 / tl.load(output_scale)) + s_out = tl.clamp(s_out, FP8_MIN, FP8_MAX) + s_out = s_out.to(output.dtype.element_ty) + tl.store( output + token_idx * num_heads * output_head_stride @@ -143,6 +159,12 @@ def merge_attn_states_kernel( p_scale = p_se / out_se s_scale = s_se / out_se out = p_out * p_scale + s_out * s_scale + + if USE_FP8: + out = out * (1.0 / tl.load(output_scale)) + out = tl.clamp(out, FP8_MIN, FP8_MAX) + out = out.to(output.dtype.element_ty) + tl.store( output + token_idx * num_heads * output_head_stride