Add llmcompressor fp8 kv-cache quant (per-tensor and per-attn_head) (#30141)
Signed-off-by: Eldar Kurtic <8884008+eldarkurtic@users.noreply.github.com> Signed-off-by: eldarkurtic <8884008+eldarkurtic@users.noreply.github.com>
This commit is contained in:
@@ -202,7 +202,8 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t block_stride, const int64_t page_stride,
|
||||
const int64_t head_stride, const int64_t key_stride,
|
||||
const int64_t value_stride, const int num_heads, const int head_size,
|
||||
const int block_size, const float* k_scale, const float* v_scale) {
|
||||
const int block_size, const float* k_scale, const float* v_scale,
|
||||
const int kv_scale_stride) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@@ -226,21 +227,23 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
// this is true for the NHD layout where `head_stride == head_size`
|
||||
const bool is_contiguous_heads = (head_stride == head_size);
|
||||
|
||||
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
|
||||
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
|
||||
constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
|
||||
if (is_contiguous_heads) {
|
||||
// NHD layout
|
||||
|
||||
if (is_contiguous_heads && kv_scale_stride == 0) {
|
||||
// NHD layout and k/v_scales are [1] (i.e. single scale for all heads)
|
||||
// kv cache: [num_blocks, block_size, num_heads, head_size]
|
||||
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
|
||||
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
|
||||
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
|
||||
|
||||
vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
|
||||
blockDim.x, k_op);
|
||||
|
||||
vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
|
||||
threadIdx.x, blockDim.x, v_op);
|
||||
|
||||
} else {
|
||||
// HND layout OR k/v_scales are [num_heads] (i.e. per-attn-head)
|
||||
// HND layout: heads are strided, but each head_size segment is contiguous
|
||||
// kv cache: [num_blocks, num_heads, block_size, head_size]
|
||||
const int lane = threadIdx.x & 31; // 0..31 within warp
|
||||
@@ -256,6 +259,16 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
cache_t* __restrict__ v_dst_h =
|
||||
value_dst + static_cast<int64_t>(head) * head_stride;
|
||||
|
||||
float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
|
||||
? 0.f
|
||||
: k_scale[head * kv_scale_stride];
|
||||
float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto)
|
||||
? 0.f
|
||||
: v_scale[head * kv_scale_stride];
|
||||
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
|
||||
CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
|
||||
|
||||
// within each head, let the 32 threads of the warp perform the vector
|
||||
// copy
|
||||
vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
|
||||
@@ -605,7 +618,8 @@ void reshape_and_cache(
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, page_stride, \
|
||||
head_stride, key_stride, value_stride, num_heads, head_size, \
|
||||
block_size, reinterpret_cast<const float*>(k_scale.data_ptr()), \
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()));
|
||||
reinterpret_cast<const float*>(v_scale.data_ptr()), \
|
||||
kv_scale_stride);
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@@ -614,8 +628,9 @@ void reshape_and_cache_flash(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& k_scale, // [1] or [num_heads]
|
||||
torch::Tensor& v_scale) { // [1] or [num_heads]
|
||||
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||
@@ -638,6 +653,12 @@ void reshape_and_cache_flash(
|
||||
int64_t head_stride = key_cache.stride(2);
|
||||
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
||||
|
||||
TORCH_CHECK(k_scale.sizes() == v_scale.sizes(),
|
||||
"k_scale and v_scale must have the same shape");
|
||||
TORCH_CHECK(k_scale.numel() == 1 || k_scale.numel() == num_heads,
|
||||
"k_scale and v_scale must be of shape [1] or [num_heads]");
|
||||
int kv_scale_stride = (k_scale.numel() > 1) ? 1 : 0;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
|
||||
Reference in New Issue
Block a user