// NOTE(review): this file was recovered from a copy in which every
// angle-bracket span ("<...>") had been stripped. Include names, template
// parameter lists, cast targets and the kernel launch config below are
// reconstructed from how the symbols are used; spots that could not be
// recovered unambiguously are flagged with NOTE(review).
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <limits>
#include <optional>
#include <type_traits>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "../quantization/w8a8/fp8/common.cuh"
#include "../dispatch_utils.h"

namespace vllm {

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
//
// Thread mapping: one thread handles one 128-bit input pack of one head of
// one token. The host launcher sizes the grid so that
//   num_tokens * num_heads * (head_size / pack_size)
// threads are covered; surplus threads exit at the bounds check.
//
// LSE tensors are laid out [num_heads, num_tokens] (indexed as
// head_idx * num_tokens + token_idx); output/prefix/suffix are
// [num_tokens, num_heads, head_size] with an explicit per-head stride.
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
          bool USE_FP8_OUTPUT>
__global__ void merge_attn_states_kernel(
    output_t* output, float* output_lse, const scalar_t* prefix_output,
    const float* prefix_lse, const scalar_t* suffix_output,
    const float* suffix_lse, const uint num_tokens, const uint num_heads,
    const uint head_size, const uint prefix_head_stride,
    const uint output_head_stride, const uint prefix_num_tokens,
    const float* output_scale) {
  // Inputs always load 128-bit packs (pack_size elements of scalar_t).
  // Outputs store pack_size elements of output_t, which is smaller for FP8.
  using input_pack_t = uint4;
  // NOTE(review): the original FP8 pack type was lost in transit. It must
  // span pack_size * sizeof(output_t) bytes: 8 bytes for FP8 output of
  // 16-bit scalars (pack_size == 8), 4 bytes for FP8 output of float
  // scalars (pack_size == 4). Reconstructed accordingly — confirm against
  // upstream.
  using output_pack_t = std::conditional_t<
      USE_FP8_OUTPUT,
      std::conditional_t<sizeof(scalar_t) == 2, uint2, uint>, uint4>;

  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;

  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  // prefix/suffix share one layout; output may have a different head stride.
  const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
                               head_idx * prefix_head_stride;
  const uint dst_head_offset = token_idx * num_heads * output_head_stride +
                               head_idx * output_head_stride;
  const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
  output_t* output_head_ptr = output + dst_head_offset;

  // Pre-invert scale: multiplication is faster than division
  float fp8_scale_inv = 1.0f;
  if constexpr (USE_FP8_OUTPUT) {
    fp8_scale_inv = 1.0f / *output_scale;
  }

  // If token_idx >= prefix_num_tokens, just copy from suffix
  if (token_idx >= prefix_num_tokens) {
    if (pack_offset < head_size) {
      input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
          suffix_head_ptr)[pack_offset / pack_size];
      if constexpr (USE_FP8_OUTPUT) {
        output_t o_out_pack[pack_size];
#pragma unroll
        for (uint i = 0; i < pack_size; ++i) {
          const float val = vllm::to_float(
              reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
          // NOTE(review): template args of scaled_fp8_conversion were
          // stripped; <true /*scale is inverted*/, output_t> matches the
          // fp8 common.cuh helper — confirm against upstream.
          o_out_pack[i] =
              vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
        }
        reinterpret_cast<output_pack_t*>(
            output_head_ptr)[pack_offset / pack_size] =
            *reinterpret_cast<output_pack_t*>(o_out_pack);
      } else {
        reinterpret_cast<output_pack_t*>(
            output_head_ptr)[pack_offset / pack_size] = s_out_pack;
      }
    }
    if (output_lse != nullptr && pack_idx == 0) {
      float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
      output_lse[head_idx * num_tokens + token_idx] = s_lse;
    }
    return;
  }

  // For tokens within prefix range, merge prefix and suffix
  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
  // Clamp +inf to -inf so a bogus LSE never dominates the max below.
  p_lse =
      std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse =
      std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
  const float max_lse = fmaxf(p_lse, s_lse);

  /* In certain edge cases, MLA can produce p_lse = s_lse = -inf; continuing
     the pipeline then yields NaN. Root cause: with chunked prefill a batch
     may be split into two chunks; if a request in that batch has no prefix
     hit, every LSE entry for that request's position is -inf, and at this
     moment we merge cross-attention at first. For now we simply emit
     prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
     this problem. */
  if (std::isinf(max_lse)) {
    if (pack_offset < head_size) {
      input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
          prefix_head_ptr)[pack_offset / pack_size];
      if constexpr (USE_FP8_OUTPUT) {
        // Convert prefix values to FP8 (since -inf means no data,
        // prefix_output is expected to be zeros)
        output_t o_out_pack[pack_size];
#pragma unroll
        for (uint i = 0; i < pack_size; ++i) {
          const float val = vllm::to_float(
              reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
          o_out_pack[i] =
              vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
        }
        reinterpret_cast<output_pack_t*>(
            output_head_ptr)[pack_offset / pack_size] =
            *reinterpret_cast<output_pack_t*>(o_out_pack);
      } else {
        reinterpret_cast<output_pack_t*>(
            output_head_ptr)[pack_offset / pack_size] = p_out_pack;
      }
    }
    // We only need to write to output_lse once per head.
    if (output_lse != nullptr && pack_idx == 0) {
      output_lse[head_idx * num_tokens + token_idx] = max_lse;
    }
    return;
  }

  // Standard log-sum-exp merge: rescale both partial softmax outputs by
  // their share of the combined normalizer.
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = expf(p_lse);
  const float s_se = expf(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;

  if (pack_offset < head_size) {
    input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
        prefix_head_ptr)[pack_offset / pack_size];
    input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
        suffix_head_ptr)[pack_offset / pack_size];

    // Compute merged values in float32
    float o_out_f[pack_size];
#pragma unroll
    for (uint i = 0; i < pack_size; ++i) {
      const float p_out_f =
          vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
      const float s_out_f =
          vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
      o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
    }

    // Convert and store
    if constexpr (USE_FP8_OUTPUT) {
      output_t o_out_pack[pack_size];
#pragma unroll
      for (uint i = 0; i < pack_size; ++i) {
        o_out_pack[i] =
            vllm::scaled_fp8_conversion<true, output_t>(o_out_f[i],
                                                        fp8_scale_inv);
      }
      reinterpret_cast<output_pack_t*>(
          output_head_ptr)[pack_offset / pack_size] =
          *reinterpret_cast<output_pack_t*>(o_out_pack);
    } else {
      output_pack_t o_out_pack;
#pragma unroll
      for (uint i = 0; i < pack_size; ++i) {
        vllm::from_float(reinterpret_cast<output_t*>(&o_out_pack)[i],
                         o_out_f[i]);
      }
      reinterpret_cast<output_pack_t*>(
          output_head_ptr)[pack_offset / pack_size] = o_out_pack;
    }
  }

  // We only need to write to output_lse once per head.
  if (output_lse != nullptr && pack_idx == 0) {
    float out_lse = logf(out_se) + max_lse;
    output_lse[head_idx * num_tokens + token_idx] = out_lse;
  }
}

}  // namespace vllm

// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
  {                                                                     \
    if (scalar_dtype == at::ScalarType::Float) {                        \
      fn(float);                                                        \
    } else if (scalar_dtype == at::ScalarType::Half) {                  \
      fn(uint16_t);                                                     \
    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
      fn(__nv_bfloat16);                                                \
    } else {                                                            \
      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
    }                                                                   \
  }

// Expands in merge_attn_states_launcher's scope: grid/block/stream and all
// pointer/size locals referenced below must already be defined at the use
// site.
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS,         \
                                 USE_FP8_OUTPUT)                          \
  {                                                                       \
    vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS,       \
                                   USE_FP8_OUTPUT>                        \
        <<<grid, block, 0, stream>>>(                                     \
            reinterpret_cast<output_t*>(output.data_ptr()),               \
            output_lse_ptr,                                               \
            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),        \
            reinterpret_cast<float*>(prefix_lse.data_ptr()),              \
            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),        \
            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens,  \
            num_heads, head_size, prefix_head_stride, output_head_stride, \
            prefix_num_tokens, output_scale_ptr);                         \
  }

/*@brief Merges the attention states from prefix and suffix
 * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
 *
 * @param output [n,h,d] The output tensor to store the merged attention
 * states.
 * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
 * states.
 * @param suffix_output [n,h,d] The suffix attention states.
 * @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
 * states.
 * @param prefill_tokens_with_context Number of prefill tokens with context
 * For the first p tokens (0 <= token_idx < prefill_tokens_with_context),
 * output is computed by merging prefix_output and suffix_output. For
 * remaining tokens (prefill_tokens_with_context <= token_idx < n), output is
 * copied directly from suffix_output.
 * @param output_scale Optional scalar tensor for FP8 static quantization.
 * When provided, output must be FP8 dtype.
 */
template <typename scalar_t>
void merge_attn_states_launcher(
    torch::Tensor& output, std::optional<torch::Tensor> output_lse,
    const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
    const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
    const std::optional<int64_t> prefill_tokens_with_context,
    const std::optional<torch::Tensor>& output_scale) {
  constexpr uint NUM_THREADS = 128;
  const uint num_tokens = output.size(0);
  const uint num_heads = output.size(1);
  const uint head_size = output.size(2);
  const uint prefix_head_stride = prefix_output.stride(1);
  const uint output_head_stride = output.stride(1);
  // Thread mapping is based on input BF16 pack_size
  const uint pack_size = 16 / sizeof(scalar_t);
  TORCH_CHECK(head_size % pack_size == 0,
              "headsize must be multiple of pack_size:", pack_size);
  // Without an explicit split point, merge every token.
  const uint prefix_num_tokens =
      prefill_tokens_with_context.has_value()
          ? static_cast<uint>(prefill_tokens_with_context.value())
          : num_tokens;
  TORCH_CHECK(prefix_num_tokens <= num_tokens,
              "prefix_num_tokens must be <= num_tokens");
  float* output_lse_ptr = nullptr;
  if (output_lse.has_value()) {
    output_lse_ptr = output_lse.value().data_ptr<float>();
  }
  float* output_scale_ptr = nullptr;
  if (output_scale.has_value()) {
    output_scale_ptr = output_scale.value().data_ptr<float>();
  }
  // Process one pack elements per thread. for float, the
  // pack_size is 4 for half/bf16, the pack_size is 8.
  const uint threads_per_head = head_size / pack_size;
  const uint total_threads = num_tokens * num_heads * threads_per_head;

  dim3 block(NUM_THREADS);
  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);

  const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  if (output_scale.has_value()) {
    // FP8 output path - dispatch on output FP8 type
    VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
      LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
    });
  } else {
    // Original BF16/FP16/FP32 output path
    LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
  }
}

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t)                        \
  {                                                                      \
    merge_attn_states_launcher<scalar_t>(                                \
        output, output_lse, prefix_output, prefix_lse, suffix_output,    \
        suffix_lse, prefill_tokens_with_context, output_scale);          \
  }

void merge_attn_states(torch::Tensor& output,
                       std::optional<torch::Tensor> output_lse,
                       const torch::Tensor& prefix_output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse,
                       std::optional<int64_t> prefill_tokens_with_context,
                       const std::optional<torch::Tensor>& output_scale) {
  if (output_scale.has_value()) {
    TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
                    output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
                "output must be FP8 when output_scale is provided, got: ",
                output.scalar_type());
  } else {
    TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
                "output dtype (", output.scalar_type(),
                ") must match prefix_output dtype (",
                prefix_output.scalar_type(),
                ") when output_scale is not set");
  }
  // Always dispatch on prefix_output (input) dtype
  DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
                           CALL_MERGE_ATTN_STATES_LAUNCHER);
}