[Kernel] Fuse FP8 output quantization into merge_attn_states (#36518)
Signed-off-by: Carl You <4531192+carlyou@users.noreply.github.com>
This commit is contained in:
264
benchmarks/fused_kernels/merge_attn_states_benchmarks.py
Normal file
264
benchmarks/fused_kernels/merge_attn_states_benchmarks.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark: Fused FP8 output quantization in merge_attn_states
|
||||
|
||||
Compares fused vs unfused approaches for producing FP8-quantized merged
|
||||
attention output:
|
||||
1. Fused CUDA -- single CUDA kernel (merge + FP8 quant)
|
||||
2. Fused Triton -- single Triton kernel (merge + FP8 quant)
|
||||
3. Unfused CUDA -- CUDA merge + torch.compiled FP8 quant
|
||||
4. Unfused Triton -- Triton merge + torch.compiled FP8 quant
|
||||
|
||||
Usage:
|
||||
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py
|
||||
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --tp 1 4 8
|
||||
python benchmarks/fused_kernels/merge_attn_states_benchmarks.py --dtype bfloat16
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
|
||||
from vllm.benchmarks.lib.utils import default_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import triton
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states as merge_attn_states_triton,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
NUM_TOKENS_LIST = [1, 16, 64, 256, 1024, 4096]
|
||||
|
||||
# (label, num_heads, head_size) — num_heads is for TP=1
|
||||
HEAD_CONFIGS = [
|
||||
("DeepSeek-V3 MLA", 128, 128),
|
||||
("Llama-70B", 64, 128),
|
||||
("Llama-8B", 32, 128),
|
||||
]
|
||||
|
||||
TP_SIZES = [1, 2, 4, 8]
|
||||
|
||||
INPUT_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
QUANTILES = [0.5, 0.2, 0.8]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def short_dtype(dtype: torch.dtype) -> str:
|
||||
return str(dtype).removeprefix("torch.")
|
||||
|
||||
|
||||
def make_inputs(
|
||||
num_tokens: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
"""Create random prefix/suffix outputs and LSEs."""
|
||||
prefix_output = torch.randn(
|
||||
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
|
||||
)
|
||||
suffix_output = torch.randn(
|
||||
(num_tokens, num_heads, head_size), dtype=dtype, device="cuda"
|
||||
)
|
||||
prefix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
|
||||
suffix_lse = torch.randn(num_heads, num_tokens, dtype=torch.float32, device="cuda")
|
||||
# Sprinkle some inf values to exercise edge-case paths
|
||||
mask = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
|
||||
prefix_lse[mask] = float("inf")
|
||||
mask2 = torch.rand(num_heads, num_tokens, device="cuda") < 0.05
|
||||
suffix_lse[mask2] = float("inf")
|
||||
return prefix_output, suffix_output, prefix_lse, suffix_lse
|
||||
|
||||
|
||||
def build_configs(head_configs, num_tokens_list, input_dtypes, tp_sizes):
|
||||
"""Build (num_tokens, num_heads, head_size, dtype_str) config tuples,
|
||||
applying TP division to num_heads and skipping invalid combos."""
|
||||
configs = []
|
||||
for (_, nh, hs), nt, dtype, tp in itertools.product(
|
||||
head_configs, num_tokens_list, input_dtypes, tp_sizes
|
||||
):
|
||||
nh_tp = nh // tp
|
||||
if nh_tp >= 1:
|
||||
configs.append((nt, nh_tp, hs, short_dtype(dtype)))
|
||||
return configs
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark merge_attn_states fused FP8 quantization"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-tokens",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help=f"Override token counts (default: {NUM_TOKENS_LIST})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp",
|
||||
type=int,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help=f"TP sizes to simulate (divides num_heads) (default: {TP_SIZES})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
nargs="+",
|
||||
default=None,
|
||||
help="Input dtypes (e.g. bfloat16 float16 float32). "
|
||||
f"Default: {[short_dtype(d) for d in INPUT_DTYPES]}",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parse args and build configs before decorators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
args = parse_args()
|
||||
|
||||
num_tokens_list = args.num_tokens if args.num_tokens else NUM_TOKENS_LIST
|
||||
tp_sizes = args.tp if args.tp else TP_SIZES
|
||||
|
||||
if args.dtype:
|
||||
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
|
||||
|
||||
input_dtypes = [STR_DTYPE_TO_TORCH_DTYPE[d] for d in args.dtype]
|
||||
else:
|
||||
input_dtypes = INPUT_DTYPES
|
||||
|
||||
configs = build_configs(HEAD_CONFIGS, num_tokens_list, input_dtypes, tp_sizes)
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["num_tokens", "num_heads", "head_size", "dtype_str"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["fused_cuda", "fused_triton", "unfused_cuda", "unfused_triton"],
|
||||
line_names=["Fused CUDA", "Fused Triton", "Unfused CUDA", "Unfused Triton"],
|
||||
styles=[("blue", "-"), ("green", "-"), ("blue", "--"), ("green", "--")],
|
||||
ylabel="us",
|
||||
plot_name="merge_attn_states FP8 (fused vs unfused)",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
@default_vllm_config()
|
||||
def benchmark(num_tokens, num_heads, head_size, dtype_str, provider):
|
||||
input_dtype = getattr(torch, dtype_str)
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
prefix_out, suffix_out, prefix_lse, suffix_lse = make_inputs(
|
||||
num_tokens, num_heads, head_size, input_dtype
|
||||
)
|
||||
output_scale = torch.tensor([0.1], dtype=torch.float32, device="cuda")
|
||||
|
||||
if provider == "fused_cuda":
|
||||
output = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
|
||||
)
|
||||
fn = lambda: merge_attn_states_cuda(
|
||||
output,
|
||||
prefix_out,
|
||||
prefix_lse,
|
||||
suffix_out,
|
||||
suffix_lse,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
elif provider == "fused_triton":
|
||||
output = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=fp8_dtype, device="cuda"
|
||||
)
|
||||
fn = lambda: merge_attn_states_triton(
|
||||
output,
|
||||
prefix_out,
|
||||
prefix_lse,
|
||||
suffix_out,
|
||||
suffix_lse,
|
||||
output_scale=output_scale,
|
||||
)
|
||||
elif provider == "unfused_cuda":
|
||||
merge_buf = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
quant_fp8 = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
column_major_scales=False,
|
||||
)
|
||||
quant_input = merge_buf.view(-1, head_size)
|
||||
compiled_quant = torch.compile(
|
||||
quant_fp8.forward_native, fullgraph=True, dynamic=False
|
||||
)
|
||||
|
||||
def unfused_fn():
|
||||
merge_attn_states_cuda(
|
||||
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
|
||||
)
|
||||
compiled_quant(quant_input, output_scale)
|
||||
|
||||
fn = unfused_fn
|
||||
else: # unfused_triton
|
||||
merge_buf = torch.empty(
|
||||
(num_tokens, num_heads, head_size), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
quant_fp8 = QuantFP8(
|
||||
static=True,
|
||||
group_shape=GroupShape.PER_TENSOR,
|
||||
column_major_scales=False,
|
||||
)
|
||||
quant_input = merge_buf.view(-1, head_size)
|
||||
compiled_quant = torch.compile(
|
||||
quant_fp8.forward_native, fullgraph=True, dynamic=False
|
||||
)
|
||||
|
||||
def unfused_fn():
|
||||
merge_attn_states_triton(
|
||||
merge_buf, prefix_out, prefix_lse, suffix_out, suffix_lse
|
||||
)
|
||||
compiled_quant(quant_input, output_scale)
|
||||
|
||||
fn = unfused_fn
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=QUANTILES)
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms # us
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main():
|
||||
device_name = current_platform.get_device_name()
|
||||
print(f"Device: {device_name}")
|
||||
print(f"Token counts: {num_tokens_list}")
|
||||
print(f"TP sizes: {tp_sizes}")
|
||||
print(f"Input dtypes: {[short_dtype(d) for d in input_dtypes]}")
|
||||
print(f"Head configs: {[(c[0], c[1], c[2]) for c in HEAD_CONFIGS]}")
|
||||
benchmark.run(print_data=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
with torch.inference_mode():
|
||||
main()
|
||||
@@ -7,19 +7,29 @@
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "../quantization/w8a8/fp8/common.cuh"
|
||||
#include "../dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
// can be used to combine partial attention results (in the split-KV case)
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
template <typename scalar_t, typename output_t, const uint NUM_THREADS,
|
||||
bool USE_FP8_OUTPUT>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
output_t* output, float* output_lse, const scalar_t* prefix_output,
|
||||
const float* prefix_lse, const scalar_t* suffix_output,
|
||||
const float* suffix_lse, const uint num_tokens, const uint num_heads,
|
||||
const uint head_size, const uint prefix_head_stride,
|
||||
const uint output_head_stride, const uint prefix_num_tokens) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint output_head_stride, const uint prefix_num_tokens,
|
||||
const float* output_scale) {
|
||||
// Inputs always load 128-bit packs (pack_size elements of scalar_t).
|
||||
// Outputs store pack_size elements of output_t, which is smaller for FP8.
|
||||
using input_pack_t = uint4;
|
||||
using output_pack_t =
|
||||
std::conditional_t<USE_FP8_OUTPUT,
|
||||
std::conditional_t<sizeof(scalar_t) == 4, uint, uint2>,
|
||||
uint4>;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
@@ -42,15 +52,36 @@ __global__ void merge_attn_states_kernel(
|
||||
head_idx * output_head_stride;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
|
||||
scalar_t* output_head_ptr = output + dst_head_offset;
|
||||
output_t* output_head_ptr = output + dst_head_offset;
|
||||
|
||||
// Pre-invert scale: multiplication is faster than division
|
||||
float fp8_scale_inv = 1.0f;
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
fp8_scale_inv = 1.0f / *output_scale;
|
||||
}
|
||||
|
||||
// If token_idx >= prefix_num_tokens, just copy from suffix
|
||||
if (token_idx >= prefix_num_tokens) {
|
||||
if (pack_offset < head_size) {
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
s_out_pack;
|
||||
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
const float val =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
o_out_pack[i] =
|
||||
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = s_out_pack;
|
||||
}
|
||||
}
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
@@ -70,20 +101,34 @@ __global__ void merge_attn_states_kernel(
|
||||
/* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
|
||||
continuing the pipeline then yields NaN. Root cause: with chunked prefill
|
||||
a batch may be split into two chunks; if a request in that batch has no
|
||||
prefix hit, every LSE entry for that request’s position is -inf, and at
|
||||
prefix hit, every LSE entry for that request's position is -inf, and at
|
||||
this moment we merge cross-attention at first. For now we simply emit
|
||||
prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
|
||||
this problem.
|
||||
*/
|
||||
if (std::isinf(max_lse)) {
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
p_out_pack;
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
// Convert prefix values to FP8 (since -inf means no data,
|
||||
// prefix_output is expected to be zeros)
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
const float val =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
o_out_pack[i] =
|
||||
vllm::scaled_fp8_conversion<true, output_t>(val, fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = p_out_pack;
|
||||
}
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
@@ -101,30 +146,43 @@ __global__ void merge_attn_states_kernel(
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t p_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
|
||||
input_pack_t s_out_pack = reinterpret_cast<const input_pack_t*>(
|
||||
suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
// Compute merged values in float32
|
||||
float o_out_f[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f =
|
||||
vllm::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
o_out_f[i] = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
|
||||
o_out_pack;
|
||||
// Convert and store
|
||||
if constexpr (USE_FP8_OUTPUT) {
|
||||
output_t o_out_pack[pack_size];
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
o_out_pack[i] = vllm::scaled_fp8_conversion<true, output_t>(
|
||||
o_out_f[i], fp8_scale_inv);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] =
|
||||
*reinterpret_cast<output_pack_t*>(o_out_pack);
|
||||
} else {
|
||||
output_pack_t o_out_pack;
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
vllm::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
|
||||
o_out_f[i]);
|
||||
}
|
||||
reinterpret_cast<output_pack_t*>(
|
||||
output_head_ptr)[pack_offset / pack_size] = o_out_pack;
|
||||
}
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
@@ -151,24 +209,26 @@ __global__ void merge_attn_states_kernel(
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, output_t, NUM_THREADS, \
|
||||
USE_FP8_OUTPUT) \
|
||||
{ \
|
||||
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS> \
|
||||
vllm::merge_attn_states_kernel<scalar_t, output_t, NUM_THREADS, \
|
||||
USE_FP8_OUTPUT> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<output_t*>(output.data_ptr()), output_lse_ptr, \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), num_tokens, \
|
||||
num_heads, head_size, prefix_head_stride, output_head_stride, \
|
||||
prefix_num_tokens); \
|
||||
prefix_num_tokens, output_scale_ptr); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||
* @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
@@ -180,19 +240,23 @@ __global__ void merge_attn_states_kernel(
|
||||
* is computed by merging prefix_output and suffix_output. For remaining tokens
|
||||
* (prefill_tokens_with_context <= token_idx < n), output is copied directly
|
||||
* from suffix_output.
|
||||
* @param output_scale Optional scalar tensor for FP8 static quantization.
|
||||
* When provided, output must be FP8 dtype.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context) {
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint prefix_head_stride = prefix_output.stride(1);
|
||||
const uint output_head_stride = output.stride(1);
|
||||
// Thread mapping is based on input BF16 pack_size
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0,
|
||||
"headsize must be multiple of pack_size:", pack_size);
|
||||
@@ -208,6 +272,10 @@ void merge_attn_states_launcher(
|
||||
if (output_lse.has_value()) {
|
||||
output_lse_ptr = output_lse.value().data_ptr<float>();
|
||||
}
|
||||
float* output_scale_ptr = nullptr;
|
||||
if (output_scale.has_value()) {
|
||||
output_scale_ptr = output_scale.value().data_ptr<float>();
|
||||
}
|
||||
// Process one pack elements per thread. for float, the
|
||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
@@ -219,20 +287,44 @@ void merge_attn_states_launcher(
|
||||
const c10::cuda::OptionalCUDAGuard device_guard(prefix_output.device());
|
||||
auto stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
if (output_scale.has_value()) {
|
||||
// FP8 output path - dispatch on output FP8 type
|
||||
VLLM_DISPATCH_FP8_TYPES(output.scalar_type(), "merge_attn_states_fp8", [&] {
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, fp8_t, NUM_THREADS, true);
|
||||
});
|
||||
} else {
|
||||
// Original BF16/FP16/FP32 output path
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, scalar_t, NUM_THREADS, false);
|
||||
}
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ \
|
||||
merge_attn_states_launcher<scalar_t>( \
|
||||
output, output_lse, prefix_output, prefix_lse, suffix_output, \
|
||||
suffix_lse, prefill_tokens_with_context); \
|
||||
suffix_lse, prefill_tokens_with_context, output_scale); \
|
||||
}
|
||||
|
||||
void merge_attn_states(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
std::optional<int64_t> prefill_tokens_with_context = std::nullopt) {
|
||||
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
void merge_attn_states(torch::Tensor& output,
|
||||
std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output,
|
||||
const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output,
|
||||
const torch::Tensor& suffix_lse,
|
||||
std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale) {
|
||||
if (output_scale.has_value()) {
|
||||
TORCH_CHECK(output.scalar_type() == at::ScalarType::Float8_e4m3fn ||
|
||||
output.scalar_type() == at::ScalarType::Float8_e4m3fnuz,
|
||||
"output must be FP8 when output_scale is provided, got: ",
|
||||
output.scalar_type());
|
||||
} else {
|
||||
TORCH_CHECK(output.scalar_type() == prefix_output.scalar_type(),
|
||||
"output dtype (", output.scalar_type(),
|
||||
") must match prefix_output dtype (",
|
||||
prefix_output.scalar_type(), ") when output_scale is not set");
|
||||
}
|
||||
// Always dispatch on prefix_output (input) dtype
|
||||
DISPATCH_BY_SCALAR_DTYPE(prefix_output.dtype(),
|
||||
CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
|
||||
@@ -57,7 +57,8 @@ void merge_attn_states(
|
||||
torch::Tensor& output, std::optional<torch::Tensor> output_lse,
|
||||
const torch::Tensor& prefix_output, const torch::Tensor& prefix_lse,
|
||||
const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse,
|
||||
const std::optional<int64_t> prefill_tokens_with_context);
|
||||
const std::optional<int64_t> prefill_tokens_with_context,
|
||||
const std::optional<torch::Tensor>& output_scale = std::nullopt);
|
||||
#ifndef USE_ROCM
|
||||
void convert_vertical_slash_indexes(
|
||||
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
|
||||
|
||||
@@ -73,7 +73,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor prefix_lse,"
|
||||
" Tensor suffix_output,"
|
||||
" Tensor suffix_lse,"
|
||||
" int!? prefill_tokens_with_context) -> ()");
|
||||
" int!? prefill_tokens_with_context,"
|
||||
" Tensor? output_scale=None) -> ()");
|
||||
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
|
||||
#ifndef USE_ROCM
|
||||
ops.def(
|
||||
|
||||
@@ -4,7 +4,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm._custom_ops import merge_attn_states as merge_attn_states_cuda
|
||||
from vllm._custom_ops import (
|
||||
merge_attn_states as merge_attn_states_cuda,
|
||||
)
|
||||
from vllm._custom_ops import (
|
||||
scaled_fp8_quant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states as merge_attn_states_triton,
|
||||
@@ -21,6 +26,7 @@ def merge_attn_states_torch(
|
||||
suffix_lse: torch.Tensor, # [NUM_HEADS, NUM_TOKENS]
|
||||
output_lse: torch.Tensor | None = None, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
output_scale: torch.Tensor | None = None, # scalar, per-tensor FP8 scale
|
||||
):
|
||||
# Apply prefill_tokens_with_context mask if needed
|
||||
if prefill_tokens_with_context is None:
|
||||
@@ -49,9 +55,13 @@ def merge_attn_states_torch(
|
||||
s_scale = s_lse_exp / out_se # [NUM_HEADS, NUM_TOKENS]
|
||||
p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
output.copy_(
|
||||
prefix_output * p_scale * mask + suffix_output * (s_scale * mask + (1 - mask))
|
||||
output = prefix_output * p_scale * mask + suffix_output * (
|
||||
s_scale * mask + (1 - mask)
|
||||
)
|
||||
if output_scale is not None:
|
||||
shape = output.shape
|
||||
output, _ = scaled_fp8_quant(output.float().view(-1, shape[-1]), output_scale)
|
||||
output = output.view(shape)
|
||||
return output, output_lse
|
||||
|
||||
|
||||
@@ -102,18 +112,20 @@ def generate_markdown_table():
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_fp8", [False, True])
|
||||
@pytest.mark.parametrize("prefill_tokens_with_context", [None, 128])
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
|
||||
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("output_dtype", DTYPES)
|
||||
@pytest.mark.parametrize("input_dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_attn_states(
|
||||
prefill_tokens_with_context: int | None,
|
||||
num_tokens: int,
|
||||
num_query_heads: int,
|
||||
head_size: int,
|
||||
output_dtype: torch.dtype,
|
||||
input_dtype: torch.dtype,
|
||||
use_fp8: bool,
|
||||
):
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(
|
||||
@@ -125,9 +137,18 @@ def test_merge_attn_states(
|
||||
NUM_HEADS = num_query_heads
|
||||
HEAD_SIZE = head_size
|
||||
|
||||
# When use_fp8 is set, inputs stay as input_dtype (bf16/fp16/fp32)
|
||||
# and output becomes FP8.
|
||||
output_dtype = input_dtype
|
||||
output_scale = None
|
||||
if use_fp8:
|
||||
output_dtype = current_platform.fp8_dtype()
|
||||
output_scale = torch.tensor([0.05], dtype=torch.float32, device="cuda")
|
||||
|
||||
print(
|
||||
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, input_dtype: {input_dtype}, "
|
||||
f"output_dtype: {output_dtype}, use_fp8: {use_fp8}, "
|
||||
f"prefill_tokens_with_context: {prefill_tokens_with_context}, "
|
||||
f"Device: {current_platform.get_device_name()}"
|
||||
)
|
||||
@@ -156,10 +177,10 @@ def test_merge_attn_states(
|
||||
(NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
prefix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
suffix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=input_dtype, device="cuda"
|
||||
)
|
||||
|
||||
warmup_times = 2
|
||||
@@ -183,6 +204,7 @@ def test_merge_attn_states(
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -196,6 +218,7 @@ def test_merge_attn_states(
|
||||
suffix_lse_torch,
|
||||
output_lse_torch,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -220,6 +243,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -233,6 +257,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_ref_triton,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -254,6 +279,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
@@ -267,6 +293,7 @@ def test_merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse_cuda,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
end.record()
|
||||
torch.accelerator.synchronize()
|
||||
@@ -288,7 +315,19 @@ def test_merge_attn_states(
|
||||
# Liger Kernel: Efficient Triton Kernels for LLM Training
|
||||
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
|
||||
# use rtol = 1e-2 for bfloat16.
|
||||
rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
|
||||
if use_fp8:
|
||||
# Compare in dequantized space (multiply back by scale) so that
|
||||
# absolute differences reflect real precision, not amplified FP8
|
||||
# quantization steps.
|
||||
atol, rtol = 1e-1, 1e-1
|
||||
assert output_scale is not None
|
||||
scale = output_scale.item()
|
||||
elif output_dtype == torch.bfloat16:
|
||||
atol, rtol = 1e-3, 1e-2
|
||||
scale = 1.0
|
||||
else:
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
scale = 1.0
|
||||
|
||||
def diff(a: torch.Tensor, b: torch.Tensor):
|
||||
max_diff = torch.max(torch.abs(a.float() - b.float()))
|
||||
@@ -300,16 +339,26 @@ def test_merge_attn_states(
|
||||
output_ref = output_ref_triton
|
||||
output_lse_ref = output_lse_ref_triton
|
||||
torch.testing.assert_close(
|
||||
output_cuda.float(), output_ref.float(), atol=1e-3, rtol=rtol
|
||||
output_cuda.float() * scale,
|
||||
output_ref.float() * scale,
|
||||
atol=atol,
|
||||
rtol=rtol,
|
||||
)
|
||||
print("Output all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
|
||||
print(f" (CUDA vs Torch) : {diff(output_torch, output_cuda)}")
|
||||
print(f" (CUDA vs Triton): {diff(output_ref, output_cuda)}")
|
||||
print(
|
||||
"Output all match, max abs diff (dequantized):"
|
||||
if use_fp8
|
||||
else "Output all match, max abs diff:"
|
||||
)
|
||||
_diff = diff(output_ref.float() * scale, output_torch.float() * scale)
|
||||
print(f"(Triton vs Torch) : {_diff}")
|
||||
_diff = diff(output_torch.float() * scale, output_cuda.float() * scale)
|
||||
print(f" (CUDA vs Torch) : {_diff}")
|
||||
_diff = diff(output_ref.float() * scale, output_cuda.float() * scale)
|
||||
print(f" (CUDA vs Triton): {_diff}")
|
||||
print("-" * 100)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_lse_cuda.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
|
||||
output_lse_cuda.float(), output_lse_ref.float(), atol=atol, rtol=rtol
|
||||
)
|
||||
print("Output LSE all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
|
||||
|
||||
@@ -265,6 +265,7 @@ def merge_attn_states(
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
torch.ops._C.merge_attn_states(
|
||||
output,
|
||||
@@ -274,6 +275,7 @@ def merge_attn_states(
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ def merge_attn_states(
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
"""Merge partial attention outputs from prefix (KV cache) and suffix
|
||||
(new tokens) into a single output tensor using the log-sum-exp (LSE)
|
||||
@@ -41,27 +42,37 @@ def merge_attn_states(
|
||||
>= this value are decode or context-free prefill tokens whose
|
||||
output is taken directly from suffix_output. If None, all tokens
|
||||
are treated as having context.
|
||||
output_scale: Optional scalar tensor for FP8 static quantization.
|
||||
When provided, output must be FP8 dtype.
|
||||
"""
|
||||
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA kernel
|
||||
# does not support FP8 dtype, fallback to use Triton kernel.
|
||||
def supported_dtypes(o: torch.Tensor) -> bool:
|
||||
return o.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||
# does not support FP8 dtype for inputs, fallback to use Triton kernel.
|
||||
# However, when output_scale is provided, the inputs are still BF16/FP16
|
||||
# and the output is FP8 — both CUDA and Triton support this.
|
||||
# FP8 output requires output_scale to be set.
|
||||
if output.dtype not in (torch.float32, torch.half, torch.bfloat16):
|
||||
assert output_scale is not None, (
|
||||
f"output_scale is required when output is {output.dtype}"
|
||||
)
|
||||
|
||||
def supported_dtypes(prefix: torch.Tensor) -> bool:
|
||||
return prefix.dtype in [torch.float32, torch.half, torch.bfloat16]
|
||||
|
||||
# NOTE(DefTruth): Currently, custom merge_attn_states CUDA
|
||||
# kernel load/store 128b(16 bytes) per memory issue within
|
||||
# thread. Namely, the headsize(headdim) must be multiple of
|
||||
# pack_size (float32 -> 4, half/bfloat16 -> 8).
|
||||
def supported_headdim(o: torch.Tensor) -> bool:
|
||||
headdim = o.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
if o.dtype == torch.float32:
|
||||
# pack_size based on input dtype (float32 -> 4, half/bfloat16 -> 8).
|
||||
def supported_headdim(prefix: torch.Tensor) -> bool:
|
||||
headdim = prefix.shape[2] # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
if prefix.dtype == torch.float32:
|
||||
return headdim % 4 == 0
|
||||
return headdim % 8 == 0
|
||||
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and supported_dtypes(output)
|
||||
and supported_headdim(output)
|
||||
and supported_dtypes(prefix_output)
|
||||
and supported_headdim(prefix_output)
|
||||
):
|
||||
from vllm._custom_ops import merge_attn_states
|
||||
|
||||
@@ -73,9 +84,12 @@ def merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
else:
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import merge_attn_states
|
||||
from vllm.v1.attention.ops.triton_merge_attn_states import (
|
||||
merge_attn_states,
|
||||
)
|
||||
|
||||
return merge_attn_states(
|
||||
output,
|
||||
@@ -85,4 +99,5 @@ def merge_attn_states(
|
||||
suffix_lse,
|
||||
output_lse,
|
||||
prefill_tokens_with_context,
|
||||
output_scale,
|
||||
)
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
# Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
# can be used to combine partial attention results (in the split-KV case)
|
||||
@@ -16,14 +19,15 @@ def merge_attn_states(
|
||||
suffix_lse: torch.Tensor,
|
||||
output_lse: torch.Tensor | None = None,
|
||||
prefill_tokens_with_context: int | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
# We assume the output stride on num_head is not always as same as the
|
||||
# `suffix_output` and `prefix_output`, as them might be padded by the attention
|
||||
# backend.
|
||||
# `suffix_output` and `prefix_output`, as them might be padded by the
|
||||
# attention backend.
|
||||
prefix_head_stride = prefix_output.stride(1)
|
||||
output_head_stride = output.stride(1)
|
||||
|
||||
@@ -41,10 +45,12 @@ def merge_attn_states(
|
||||
suffix_lse,
|
||||
prefix_head_stride,
|
||||
output_head_stride,
|
||||
output_scale,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
prefill_tokens_with_context,
|
||||
output_scale is not None,
|
||||
)
|
||||
|
||||
|
||||
@@ -58,10 +64,14 @@ def merge_attn_states_kernel(
|
||||
suffix_lse, # [NUM_HEADS, NUM_TOKENS]
|
||||
prefix_head_stride,
|
||||
output_head_stride,
|
||||
output_scale, # scale tensor or None
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
prefill_tokens_with_context: tl.constexpr,
|
||||
USE_FP8: tl.constexpr,
|
||||
FP8_MIN: tl.constexpr = float8_info.min,
|
||||
FP8_MAX: tl.constexpr = float8_info.max,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
@@ -87,6 +97,12 @@ def merge_attn_states_kernel(
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
|
||||
if USE_FP8:
|
||||
s_out = s_out * (1.0 / tl.load(output_scale))
|
||||
s_out = tl.clamp(s_out, FP8_MIN, FP8_MAX)
|
||||
s_out = s_out.to(output.dtype.element_ty)
|
||||
|
||||
tl.store(
|
||||
output
|
||||
+ token_idx * num_heads * output_head_stride
|
||||
@@ -143,6 +159,12 @@ def merge_attn_states_kernel(
|
||||
p_scale = p_se / out_se
|
||||
s_scale = s_se / out_se
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
|
||||
if USE_FP8:
|
||||
out = out * (1.0 / tl.load(output_scale))
|
||||
out = tl.clamp(out, FP8_MIN, FP8_MAX)
|
||||
out = out.to(output.dtype.element_ty)
|
||||
|
||||
tl.store(
|
||||
output
|
||||
+ token_idx * num_heads * output_head_stride
|
||||
|
||||
Reference in New Issue
Block a user