Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)

Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com>
Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
Eldar Kurtić authored on 2026-01-22 21:29:57 +01:00, committed by GitHub
parent 955b43a5a5, commit 44f08af3a7
18 changed files with 558 additions and 263 deletions


@@ -202,7 +202,8 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t block_stride, const int64_t page_stride,
     const int64_t head_stride, const int64_t key_stride,
     const int64_t value_stride, const int num_heads, const int head_size,
-    const int block_size, const float* k_scale, const float* v_scale) {
+    const int block_size, const float* k_scale, const float* v_scale,
+    const int kv_scale_stride) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
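Note: the new kv_scale_stride argument encodes the scale layout as a
broadcast stride: 0 when k_scale/v_scale hold a single per-tensor value, 1
when they hold one value per attention head. A minimal standalone sketch of
the indexing convention (names are illustrative, not from this diff):

  #include <cstdio>

  // scale[head * stride] reads scale[0] for every head when stride == 0
  // (per-tensor) and scale[head] when stride == 1 (per-attn-head).
  static float scale_for_head(const float* scale, int head, int stride) {
    return scale[head * stride];
  }

  int main() {
    const float per_tensor[1] = {0.5f};
    const float per_head[4] = {0.5f, 0.25f, 1.0f, 2.0f};
    for (int h = 0; h < 4; ++h)
      std::printf("head %d: per-tensor %.2f  per-head %.2f\n", h,
                  scale_for_head(per_tensor, h, 0),
                  scale_for_head(per_head, h, 1));
    return 0;
  }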
@@ -226,21 +227,23 @@ __global__ void reshape_and_cache_flash_kernel(
   // this is true for the NHD layout where `head_stride == head_size`
   const bool is_contiguous_heads = (head_stride == head_size);
-  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
-  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
   constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
-  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
-  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
-  if (is_contiguous_heads) {
-    // NHD layout
+  if (is_contiguous_heads && kv_scale_stride == 0) {
+    // NHD layout and k/v_scales are [1] (i.e. single scale for all heads)
     // kv cache: [num_blocks, block_size, num_heads, head_size]
+    float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+    float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+    CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+    CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
     vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
                                        blockDim.x, k_op);
     vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
                                        threadIdx.x, blockDim.x, v_op);
   } else {
+    // HND layout OR k/v_scales are [num_heads] (i.e. per-attn-head)
     // HND layout: heads are strided, but each head_size segment is contiguous
     // kv cache: [num_blocks, num_heads, block_size, head_size]
     const int lane = threadIdx.x & 31;  // 0..31 within warp
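Note: the fused fast path is now gated on both conditions. NHD layout makes a
token's heads one contiguous run, and a per-tensor scale lets a single
CopyWithScaleOp, which captures exactly one float, cover that whole run;
per-head scales therefore take the strided path even for NHD caches. A hedged
sketch of the dispatch predicate (the real kernel inlines this condition):

  static bool can_fuse_whole_token_copy(bool is_contiguous_heads,
                                        int kv_scale_stride) {
    // One functor carries one scale, so a single vectorized copy over
    // num_heads * head_size elements is valid only when every head shares
    // that scale (stride 0) and heads sit back-to-back in memory (NHD).
    return is_contiguous_heads && kv_scale_stride == 0;
  }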
@@ -256,6 +259,16 @@ __global__ void reshape_and_cache_flash_kernel(
     cache_t* __restrict__ v_dst_h =
         value_dst + static_cast<int64_t>(head) * head_stride;
+    float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
+                            ? 0.f
+                            : k_scale[head * kv_scale_stride];
+    float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
+                            ? 0.f
+                            : v_scale[head * kv_scale_stride];
+    CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+    CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
     // within each head, let the 32 threads of the warp perform the vector
     // copy
     vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
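Note: scale lookup and functor construction move inside the per-head loop
because each CopyWithScaleOp captures a single scale. A simplified stand-in
for the functor (hypothetical, not vLLM's implementation; the
divide-on-quantize convention, where the stored scale is the dequantization
factor, is an assumption here):

  template <typename dst_t, typename src_t>
  struct CopyWithScaleOpSketch {
    float scale;  // one scale per functor, hence one functor per head
    void operator()(dst_t& dst, const src_t& src) const {
      dst = static_cast<dst_t>(static_cast<float>(src) / scale);
    }
  };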
@@ -605,7 +618,8 @@ void reshape_and_cache(
       slot_mapping.data_ptr<int64_t>(), block_stride, page_stride,          \
       head_stride, key_stride, value_stride, num_heads, head_size,          \
       block_size, reinterpret_cast<const float*>(k_scale.data_ptr()),       \
-      reinterpret_cast<const float*>(v_scale.data_ptr()));
+      reinterpret_cast<const float*>(v_scale.data_ptr()),                   \
+      kv_scale_stride);

 void reshape_and_cache_flash(
     torch::Tensor& key,  // [num_tokens, num_heads, head_size]
@@ -614,8 +628,9 @@ void reshape_and_cache_flash(
     torch::Tensor&
         value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
-    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
-    torch::Tensor& v_scale) {
+    const std::string& kv_cache_dtype,
+    torch::Tensor& k_scale,    // [1] or [num_heads]
+    torch::Tensor& v_scale) {  // [1] or [num_heads]
   // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
   // slot_mapping.size(0) because of padding for CUDA graphs.
   // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
@@ -638,6 +653,12 @@ void reshape_and_cache_flash(
   int64_t head_stride = key_cache.stride(2);
   TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
+  TORCH_CHECK(k_scale.sizes() == v_scale.sizes(),
+              "k_scale and v_scale must have the same shape");
+  TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
+              "k_scale and v_scale must be of shape [1] or [num_heads]");
+  int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0;
+
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * head_size, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
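Note: caller-side view of the two scale shapes the checks above accept and
the stride each one produces (illustrative libtorch sketch, not part of this
commit):

  #include <torch/torch.h>
  #include <cstdio>

  int main() {
    const int64_t num_heads = 8;
    auto opts = torch::TensorOptions().dtype(torch::kFloat32);
    auto per_tensor = torch::full({1}, 0.5, opts);        // kv_scale_stride 0
    auto per_head = torch::full({num_heads}, 0.5, opts);  // kv_scale_stride 1
    for (const auto& s : {per_tensor, per_head}) {
      int kv_scale_stride = (s.numel() > 1) ? 1 : 0;
      std::printf("numel=%lld -> kv_scale_stride=%d\n",
                  static_cast<long long>(s.numel()), kv_scale_stride);
    }
    return 0;
  }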