[Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
@@ -7,19 +7,29 @@
 #include "attention_dtypes.h"
 #include "attention_utils.cuh"
+#include "../quantization/w8a8/fp8/common.cuh"
+#include "../dispatch_utils.h"
 
 namespace vllm {
 
 // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
 // can be used to combine partial attention results (in the split-KV case)
-template <typename scalar_t, const uint NUM_THREADS>
+template <typename scalar_t, typename output_t, const uint NUM_THREADS,
+          bool USE_FP8_OUTPUT>
 __global__ void merge_attn_states_kernel(
-    scalar_t* output, float* output_lse, const scalar_t* prefix_output,
+    output_t* output, float* output_lse, const scalar_t* prefix_output,
     const float* prefix_lse, const scalar_t* suffix_output,
     const float* suffix_lse, const uint num_tokens, const uint num_heads,
     const uint head_size, const uint prefix_head_stride,
-    const uint output_head_stride, const uint prefix_num_tokens) {
-  using pack_128b_t = uint4;
+    const uint output_head_stride, const uint prefix_num_tokens,
+    const float* output_scale) {
+  // Inputs always load 128-bit packs (pack_size elements of scalar_t).
+  // Outputs store pack_size elements of output_t, which is smaller for FP8.
+  using input_pack_t = uint4;
+  using output_pack_t =
+      std::conditional_t<USE_FP8_OUTPUT,
+                         std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
+                         uint4>;
   const uint pack_size = 16 / sizeof(scalar_t);
   const uint threads_per_head = head_size / pack_size;
 
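
The pack-type selection above is easier to read with concrete sizes. Below is a minimal standalone sketch (not part of the patch) that mirrors the std::conditional_t logic, using unsigned int in place of CUDA's uint typedef, with static_asserts spelling out the resulting pack widths:

#include <cstdint>
#include <type_traits>
#include <cuda_runtime.h>  // uint2 / uint4 vector types

// Illustration only: inputs are always read as 128-bit packs, so
// pack_size = 16 / sizeof(scalar_t); with FP8 output every element narrows
// to 1 byte, so the store pack narrows to pack_size bytes.
template <typename scalar_t, bool USE_FP8_OUTPUT>
struct pack_types {
  using input_pack_t = uint4;  // always 16 bytes in
  using output_pack_t =
      std::conditional_t<USE_FP8_OUTPUT,
                         std::conditional_t<sizeof(scalar_t) == 4,
                                            unsigned int, uint2>,
                         uint4>;
  static constexpr unsigned pack_size = 16 / sizeof(scalar_t);
};

// half/bf16 input (2 B): 8 elements per pack; FP8 store pack = 8 B (uint2).
static_assert(pack_types<uint16_t, true>::pack_size == 8);
static_assert(sizeof(pack_types<uint16_t, true>::output_pack_t) == 8);
// float input (4 B): 4 elements per pack; FP8 store pack = 4 B (uint).
static_assert(pack_types<float, true>::pack_size == 4);
static_assert(sizeof(pack_types<float, true>::output_pack_t) == 4);
// Non-FP8 output keeps the full 128-bit store.
static_assert(sizeof(pack_types<uint16_t, false>::output_pack_t) == 16);
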
@@ -42,15 +52,36 @@ __global__ void merge_attn_states_kernel(
                                head_idx * output_head_stride;
   const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
   const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
-  scalar_t* output_head_ptr = output + dst_head_offset;
+  output_t* output_head_ptr = output + dst_head_offset;
 
+  // Pre-invert scale: multiplication is faster than division
+  float fp8_scale_inv = 1.0f;
+  if constexpr (USE_FP8_OUTPUT) {
+    fp8_scale_inv = 1.0f / *output_scale;
+  }
+
   // If token_idx >= prefix_num_tokens, just copy from suffix
   if (token_idx >= prefix_num_tokens) {
     if (pack_offset < head_size) {
-      pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
+      input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
           suffix_head_ptr)[pack_offset / pack_size];
-      reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
-          s_out_pack;
+
+      if constexpr (USE_FP8_OUTPUT) {
+        output_t o_out_pack[pack_size];
+#pragma unroll
+        for (uint i = 0; i < pack_size; ++i) {
+          const float val =
+              vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+          o_out_pack[i] =
+              vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
+        }
+        reinterpret_cast<output_pack_t*>(
+            output_head_ptr)[pack_offset / pack_size] =
+            *reinterpret_cast<output_pack_t*>(o_out_pack);
+      } else {
+        reinterpret_cast<output_pack_t*>(
+            output_head_ptr)[pack_offset / pack_size] = s_out_pack;
+      }
     }
     if (output_lse != nullptr && pack_idx == 0) {
       float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
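
Each FP8 store in the branch above quantizes the copied value with one multiply by the pre-inverted scale and then narrows to e4m3; the patch does this through vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv). A standalone device-side sketch of the same operation, written directly against CUDA's cuda_fp8.h types (an illustration, not the vLLM helper):

#include <cuda_fp8.h>

// Illustration only: static FP8-e4m3 quantization with a pre-inverted scale.
// Pre-inverting (scale_inv = 1.0f / *output_scale) turns the per-element
// division into a multiplication, which the kernel sets up once per thread.
__device__ __forceinline__ __nv_fp8_e4m3 quantize_e4m3(float val,
                                                        float scale_inv) {
  float x = val * scale_inv;
  // Clamp to the e4m3 finite range (+/-448) before narrowing so
  // out-of-range values saturate to the max finite value.
  x = fminf(fmaxf(x, -448.0f), 448.0f);
  return __nv_fp8_e4m3(x);
}
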
@@ -70,20 +101,34 @@ __global__ void merge_attn_states_kernel(
   /* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
   continuing the pipeline then yields NaN. Root cause: with chunked prefill
   a batch may be split into two chunks; if a request in that batch has no
-  prefix hit, every LSE entry for that request’s position is -inf, and at
+  prefix hit, every LSE entry for that request's position is -inf, and at
   this moment we merge cross-attention at first. For now we simply emit
   prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
   this problem.
   */
   if (std::isinf(max_lse)) {
     if (pack_offset < head_size) {
-      // Pack 128b load
-      pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
+      input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
           prefix_head_ptr)[pack_offset / pack_size];
 
-      // Pack 128b storage
-      reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
-          p_out_pack;
+      if constexpr (USE_FP8_OUTPUT) {
+        // Convert prefix values to FP8 (since -inf means no data,
+        // prefix_output is expected to be zeros)
+        output_t o_out_pack[pack_size];
+#pragma unroll
+        for (uint i = 0; i < pack_size; ++i) {
+          const float val =
+              vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+          o_out_pack[i] =
+              vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
+        }
+        reinterpret_cast<output_pack_t*>(
+            output_head_ptr)[pack_offset / pack_size] =
+            *reinterpret_cast<output_pack_t*>(o_out_pack);
+      } else {
+        reinterpret_cast<output_pack_t*>(
+            output_head_ptr)[pack_offset / pack_size] = p_out_pack;
+      }
     }
     // We only need to write to output_lse once per head.
     if (output_lse != nullptr && pack_idx == 0) {
@@ -101,30 +146,43 @@ __global__ void merge_attn_states_kernel(
     const float s_scale = s_se / out_se;
 
     if (pack_offset < head_size) {
-      // Pack 128b load
-      pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
+      input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
           prefix_head_ptr)[pack_offset / pack_size];
-      pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
+      input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
          suffix_head_ptr)[pack_offset / pack_size];
-      pack_128b_t o_out_pack;
 
+      // Compute merged values in float32
+      float o_out_f[pack_size];
 #pragma unroll
       for (uint i = 0; i < pack_size; ++i) {
-        // Always use float for FMA to keep high precision.
-        // half(uint16_t), bfloat16, float -> float.
         const float p_out_f =
            vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
         const float s_out_f =
            vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
-        // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
-        const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
-        // float -> half(uint16_t), bfloat16, float.
-        vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
+        o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
       }
 
-      // Pack 128b storage
-      reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
-          o_out_pack;
+      // Convert and store
+      if constexpr (USE_FP8_OUTPUT) {
+        output_t o_out_pack[pack_size];
+#pragma unroll
+        for (uint i = 0; i < pack_size; ++i) {
+          o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
+              o_out_f[i], fp8_scale_inv);
+        }
+        reinterpret_cast<output_pack_t*>(
+            output_head_ptr)[pack_offset / pack_size] =
+            *reinterpret_cast<output_pack_t*>(o_out_pack);
+      } else {
+        output_pack_t o_out_pack;
+#pragma unroll
+        for (uint i = 0; i < pack_size; ++i) {
+          vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
+                           o_out_f[i]);
+        }
+        reinterpret_cast<output_pack_t*>(
+            output_head_ptr)[pack_offset / pack_size] = o_out_pack;
+      }
     }
     // We only need to write to output_lse once per head.
     if (output_lse != nullptr && pack_idx == 0) {
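
The p_scale / s_scale blend above is the standard log-sum-exp merge of two partial softmax-attention results (section 2.2 of the linked paper): each partial output is weighted by the relative mass of its softmax normalizer. A scalar host-side sketch of that math, for reference only, with variable names mirroring the kernel:

#include <algorithm>
#include <cmath>

// Illustration only: merge two partial attention results for one element.
// p_out/s_out are the partial outputs, p_lse/s_lse their log-sum-exp
// normalizers; merged_lse is the log-sum-exp of the combined result.
void merge_one(float p_out, float p_lse, float s_out, float s_lse,
               float* merged_out, float* merged_lse) {
  const float max_lse = std::max(p_lse, s_lse);  // subtract the max for stability
  const float p_se = std::exp(p_lse - max_lse);
  const float s_se = std::exp(s_lse - max_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;
  *merged_out = p_out * p_scale + s_out * s_scale;
  *merged_lse = max_lse + std::log(out_se);
}
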
@@ -151,24 +209,26 @@ __global__ void merge_attn_states_kernel(
     } \
   }
 
-#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
+#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
+                                 USE_FP8_OUTPUT) \
   { \
-    vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
+    vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
+                                   USE_FP8_OUTPUT> \
         <<<grid, block, 0, stream>>>( \
-            reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
+            reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
            reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
            reinterpret_cast<float*>(prefix_lse.data_ptr()), \
            reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
            reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
            num_heads, head_size, prefix_head_stride, output_head_stride, \
-            prefix_num_tokens); \
+            prefix_num_tokens, output_scale_ptr); \
   }
 
 /*@brief Merges the attention states from prefix and suffix
  * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
  *
  * @param output [n,h,d] The output tensor to store the merged attention states.
- * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
+ * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
 * @param prefix_output [n,h,d] The prefix attention states.
 * @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
 * states.
@@ -180,19 +240,23 @@ __global__ void merge_attn_states_kernel(
 * is computed by merging prefix_output and suffix_output. For remaining tokens
 * (prefill_tokens_with_context <= token_idx < n), output is copied directly
 * from suffix_output.
+* @param output_scale Optional scalar tensor for FP8 static quantization.
+*        When provided, output must be FP8 dtype.
 */
 template <typename scalar_t>
 void merge_attn_states_launcher(
     torch::Tensor& output, std::optional<torch::Tensor> output_lse,
     const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
     const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
-    const std::optional<int64_t> prefill_tokens_with_context) {
+    const std::optional<int64_t> prefill_tokens_with_context,
+    const std::optional<torch::Tensor>& output_scale) {
   constexpr uint NUM_THREADS = 128;
   const uint num_tokens = output.size(0);
   const uint num_heads = output.size(1);
   const uint head_size = output.size(2);
   const uint prefix_head_stride = prefix_output.stride(1);
   const uint output_head_stride = output.stride(1);
+  // Thread mapping is based on input BF16 pack_size
   const uint pack_size = 16 / sizeof(scalar_t);
   TORCH_CHECK(head_size % pack_size == 0,
               "headsize must be multiple of pack_size:", pack_size);
@@ -208,6 +272,10 @@ void merge_attn_states_launcher(
   if (output_lse.has_value()) {
     output_lse_ptr = output_lse.value().data_ptr<float>();
   }
+  float* output_scale_ptr = nullptr;
+  if (output_scale.has_value()) {
+    output_scale_ptr = output_scale.value().data_ptr<float>();
+  }
   // Process one pack elements per thread. for float, the
   // pack_size is 4 for half/bf16, the pack_size is 8.
   const uint threads_per_head = head_size / pack_size;
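
For reference, the pack/thread arithmetic the launcher relies on works out as follows; head_size = 128 is an assumed example value, not taken from the patch:

// Illustration only: one thread handles one 128-bit input pack.
constexpr unsigned head_size = 128;                      // example value
constexpr unsigned pack_size_fp32 = 16 / sizeof(float);  // 4 elements per pack
constexpr unsigned pack_size_fp16 = 16 / 2;              // 8 for half/bf16
static_assert(head_size % pack_size_fp32 == 0);
static_assert(head_size / pack_size_fp32 == 32);  // threads_per_head, float input
static_assert(head_size / pack_size_fp16 == 16);  // threads_per_head, half/bf16 input
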
@@ -219,20 +287,44 @@ void merge_attn_states_launcher(
   const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
+  if (output_scale.has_value()) {
+    // FP8 output path - dispatch on output FP8 type
+    VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
+      LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
+    });
+  } else {
+    // Original BF16/FP16/FP32 output path
+    LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
+  }
 }
 
 #define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
   { \
     merge_attn_states_launcher<scalar_t>( \
         output, output_lse, prefix_output, prefix_lse, suffix_output, \
-        suffix_lse, prefill_tokens_with_context); \
+        suffix_lse, prefill_tokens_with_context, output_scale); \
   }
 
-void merge_attn_states(
-    torch::Tensor& output, std::optional<torch::Tensor> output_lse,
-    const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
-    const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
-    std::optional<int64_t> prefill_tokens_with_context = std::nullopt) {
-  DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
+void merge_attn_states(torch::Tensor& output,
+                       std::optional<torch::Tensor> output_lse,
+                       const torch::Tensor& prefix_output,
+                       const torch::Tensor& prefix_lse,
+                       const torch::Tensor& suffix_output,
+                       const torch::Tensor& suffix_lse,
+                       std::optional<int64_t> prefill_tokens_with_context,
+                       const std::optional<torch::Tensor>& output_scale) {
+  if (output_scale.has_value()) {
+    TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
+                    output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
+                "output must be FP8 when output_scale is provided, got: ",
+                output.scalar_type());
+  } else {
+    TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
+                "output dtype (", output.scalar_type(),
+                ") must match prefix_output dtype (",
+                prefix_output.scalar_type(), ") when output_scale is not set");
+  }
+  // Always dispatch on prefix_output (input) dtype
+  DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
+                           CALL_MERGE_ATTN_STATES_LAUNCHER);
 }
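
With the new trailing output_scale argument, a caller can have the merge write straight into an FP8 tensor instead of merging in BF16 and quantizing in a second pass. A hedged host-side usage sketch follows; the shapes, the example scale value, and the calling context are assumptions, while the function signature and the dtype contract come from the patch:

#include <optional>
#include <torch/torch.h>

// Declaration as introduced by the patch.
void merge_attn_states(torch::Tensor& output,
                       std::optional<torch::Tensor> output_lse,
                       const torch::Tensor& prefix_output,
                       const torch::Tensor& prefix_lse,
                       const torch::Tensor& suffix_output,
                       const torch::Tensor& suffix_lse,
                       std::optional<int64_t> prefill_tokens_with_context,
                       const std::optional<torch::Tensor>& output_scale);

void example_fp8_merge() {
  const int64_t n = 32, h = 16, d = 128;  // example sizes, not from the patch
  const auto cuda_bf16 =
      torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);
  const auto cuda_f32 = cuda_bf16.dtype(torch::kFloat32);

  auto prefix_output = torch::randn({n, h, d}, cuda_bf16);
  auto suffix_output = torch::randn({n, h, d}, cuda_bf16);
  auto prefix_lse = torch::randn({h, n}, cuda_f32);  // [h, n] per the docstring
  auto suffix_lse = torch::randn({h, n}, cuda_f32);

  // FP8 path: output must be an FP8 tensor when output_scale is provided.
  auto output = torch::empty({n, h, d},
                             cuda_bf16.dtype(at::ScalarType::Float8_e4m3fn));
  auto output_scale = torch::full({1}, 0.5f, cuda_f32);  // example static scale

  merge_attn_states(output, /*output_lse=*/std::nullopt, prefix_output,
                    prefix_lse, suffix_output, suffix_lse,
                    /*prefill_tokens_with_context=*/std::nullopt, output_scale);
}
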