Files
vllm/csrc/cache_kernels_fused.cu
PatrykSaffer 80fead8bf6 Fuse RoPE and MLA KV-cache write (#25774)
Signed-off-by: Patryk Saffer <patryk.saffer99@gmail.com>
Signed-off-by: PatrykSaffer <patryk.saffer@mistral.ai>
Co-authored-by: Patryk Saffer <patryk.saffer99@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
2026-01-09 19:18:37 -08:00

280 lines
11 KiB
Plaintext

#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/w8a8/fp8/common.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
#else
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
#endif
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif
namespace vllm {

// Computes the in-head element offsets of rotary pair `pair_idx` for a head of
// width 2 * embed_dim.
//   GPT-NeoX style: element pair_idx is paired with pair_idx + embed_dim.
//   GPT-J style:    element 2 * pair_idx is paired with 2 * pair_idx + 1.
template <bool IS_NEOX>
__device__ __forceinline__ void mla_rope_fused_pair_offsets(const int pair_idx,
                                                            const int embed_dim,
                                                            int& x_off,
                                                            int& y_off) {
  if constexpr (IS_NEOX) {
    x_off = pair_idx;
    y_off = embed_dim + pair_idx;
  } else {
    x_off = pair_idx * 2;
    y_off = pair_idx * 2 + 1;
  }
}

// Rotates (head_ptr[x_off], head_ptr[y_off]) by the angle whose cosine/sine
// are cos_v/sin_v, writes the result back in place and also returns the
// rotated values through x_dst/y_dst (needed for the cache store).
template <typename qk_t>
__device__ __forceinline__ void mla_rope_fused_rotate_pair(
    qk_t* head_ptr, const int x_off, const int y_off, const qk_t cos_v,
    const qk_t sin_v, qk_t& x_dst, qk_t& y_dst) {
  const qk_t x_src = head_ptr[x_off];
  const qk_t y_src = head_ptr[y_off];
  x_dst = x_src * cos_v - y_src * sin_v;
  y_dst = y_src * cos_v + x_src * sin_v;
  head_ptr[x_off] = x_dst;
  head_ptr[y_off] = y_dst;
}

// Writes one element into the KV cache. For kAuto the raw value is stored
// as-is; otherwise it is converted with fp8::scaled_convert using *scale.
template <typename raw_kv_scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__device__ __forceinline__ void mla_rope_fused_store(
    cache_t* dst, const raw_kv_scalar_t value, const float* scale) {
  if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
    *dst = value;
  } else {
    *dst = fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(value, *scale);
  }
}

// NOTE Be EXTRA careful with raw_kv_scalar_t, for __half and __nv_bfloat16 it's
// using u16 as the backing type.
//
// Fused RoPE + MLA KV-cache write.
// Launch layout: one thread block per token (token_idx == blockIdx.x); the
// threads of a block stride over that token's work.
// Per token:
//   1. q_pe is rotated in place for every query head.
//   2. k_pe (a single head) is rotated in place.
//   3. Only when kv_cache_slot_mapping[token] >= 0: the rotated k_pe is
//      stored at offset kv_lora_rank of the cache entry and kv_c at offset 0,
//      converted to FP8 when kv_dt != kAuto.
// rope_cos_sin_cache is addressed as rope_cos_sin_cache + pos * rot_dim, so it
// must be contiguous: rot_dim / 2 cosines followed by rot_dim / 2 sines.
template <typename qk_t, bool IS_NEOX, typename raw_kv_scalar_t,
          typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_rope_fused_kernel(
    const int64_t* __restrict__ positions,  // [num_tokens]
    qk_t* __restrict__ q_pe,  // [num_tokens, num_q_heads, rot_dim]
    qk_t* __restrict__ k_pe,  // [num_tokens, rot_dim]
    const qk_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
    const qk_t* __restrict__ rope_cos_sin_cache,  // [max_position, 2,
                                                  //  rot_dim // 2]
    const int rot_dim, const int64_t q_pe_stride_token,
    const int64_t q_pe_stride_head, const int64_t k_pe_stride,
    const int64_t kv_c_stride, const int num_q_heads,
    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank +
                                     //  rot_dim)]
    const int64_t* __restrict__ kv_cache_slot_mapping,  // [num_tokens]
    const int block_stride, const int entry_stride, const int kv_lora_rank,
    const int block_size, const float* kv_cache_quant_scale) {
  // Each thread block is responsible for one token.
  const int64_t token_idx = blockIdx.x;
  const int64_t pos = positions[token_idx];
  const qk_t* cos_sin_ptr = rope_cos_sin_cache + pos * rot_dim;
  const int embed_dim = rot_dim / 2;

  // Q RoPE: every (head, pair) combination of this token.
  const int nq = num_q_heads * embed_dim;
  qk_t* q_pe_token_ptr = q_pe + token_idx * q_pe_stride_token;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int pair_idx = i % embed_dim;
    // NOTE: Would be nice to have interleaved sin/cos so we could just load
    // both at the same time.
    const qk_t cos_v = VLLM_LDG(cos_sin_ptr + pair_idx);
    const qk_t sin_v = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
    int x_off, y_off;
    mla_rope_fused_pair_offsets<IS_NEOX>(pair_idx, embed_dim, x_off, y_off);
    qk_t x_dst, y_dst;
    mla_rope_fused_rotate_pair(q_pe_token_ptr + head_idx * q_pe_stride_head,
                               x_off, y_off, cos_v, sin_v, x_dst, y_dst);
  }

  // NOTE: slot_idx can be -1 if the token is padded. Such a token still has
  // its k_pe rotated (k_pe is an in-place output, exactly like q_pe), but
  // nothing is written to the KV cache for it.
  const int64_t slot_idx = kv_cache_slot_mapping[token_idx];
  const bool write_cache = slot_idx >= 0;
  // Form the cache-entry pointer only for valid slots so that no
  // out-of-range pointer is ever computed for padded tokens.
  cache_t* kv_cache_entry_ptr = nullptr;
  if (write_cache) {
    const int64_t block_idx = slot_idx / block_size;
    const int64_t entry_idx = slot_idx % block_size;
    kv_cache_entry_ptr =
        kv_cache + block_idx * block_stride + entry_idx * entry_stride;
  }

  // K RoPE (1 head), followed by the rope part of the MLA cache store.
  qk_t* k_pe_token_ptr = k_pe + token_idx * k_pe_stride;
  for (int pair_idx = threadIdx.x; pair_idx < embed_dim;
       pair_idx += blockDim.x) {
    const qk_t cos_v = VLLM_LDG(cos_sin_ptr + pair_idx);
    const qk_t sin_v = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
    int x_off, y_off;
    mla_rope_fused_pair_offsets<IS_NEOX>(pair_idx, embed_dim, x_off, y_off);
    qk_t x_dst, y_dst;
    mla_rope_fused_rotate_pair(k_pe_token_ptr, x_off, y_off, cos_v, sin_v,
                               x_dst, y_dst);
    if (!write_cache) {
      continue;
    }
    // NOTE Why is this monster necessary?
    // When K is of type float16, the actual template replacement for
    // raw_kv_scalar_t will be u16. That's why it's used at the last moment
    // otherwise CUDA ALU would break.
    const raw_kv_scalar_t raw_x_value =
        *reinterpret_cast<const raw_kv_scalar_t*>(&x_dst);
    const raw_kv_scalar_t raw_y_value =
        *reinterpret_cast<const raw_kv_scalar_t*>(&y_dst);
    // The rope part of the entry follows the kv_lora_rank latent values.
    cache_t* rope_cache_ptr = kv_cache_entry_ptr + kv_lora_rank;
    mla_rope_fused_store<raw_kv_scalar_t, cache_t, kv_dt>(
        rope_cache_ptr + x_off, raw_x_value, kv_cache_quant_scale);
    mla_rope_fused_store<raw_kv_scalar_t, cache_t, kv_dt>(
        rope_cache_ptr + y_off, raw_y_value, kv_cache_quant_scale);
  }

  if (!write_cache) {
    return;
  }

  // NOPE: copy kv_c into the first kv_lora_rank elements of the entry.
  const qk_t* kv_c_token_ptr = kv_c + token_idx * kv_c_stride;
  for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) {
    const raw_kv_scalar_t src_value =
        *reinterpret_cast<const raw_kv_scalar_t*>(kv_c_token_ptr + i);
    mla_rope_fused_store<raw_kv_scalar_t, cache_t, kv_dt>(
        kv_cache_entry_ptr + i, src_value, kv_cache_quant_scale);
  }
}
}  // namespace vllm
// Launch helper handed to DISPATCH_BY_KV_CACHE_DTYPE, which supplies the
// (RAW_KV_T, CACHE_T, KV_DTYPE) combination. The element type of q_pe picks
// qk_t, and rope_is_neox picks the IS_NEOX specialization. Both
// specializations have the same parameter list, so the chosen one is
// launched through a single kernel pointer.
// The enclosing scope must define grid, block, stream, rope_is_neox, the
// tensors and the size/stride locals referenced below.
#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE)     \
  do {                                                                       \
    VLLM_DISPATCH_FLOATING_TYPES(q_pe.scalar_type(), "qk_scalar_type", [&] { \
      using qk_t = scalar_t;                                                 \
      auto* const mla_rope_kernel =                                          \
          rope_is_neox ? &vllm::concat_and_cache_mla_rope_fused_kernel<      \
                             qk_t, true, RAW_KV_T, CACHE_T, KV_DTYPE>        \
                       : &vllm::concat_and_cache_mla_rope_fused_kernel<      \
                             qk_t, false, RAW_KV_T, CACHE_T, KV_DTYPE>;      \
      mla_rope_kernel<<<grid, block, 0, stream>>>(                           \
          positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(),              \
          k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(),                      \
          rope_cos_sin_cache.data_ptr<qk_t>(), rot_dim, q_pe_stride_token,   \
          q_pe_stride_head, k_pe_stride, kv_c_stride, num_q_heads,           \
          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                   \
          kv_cache_slot_mapping.data_ptr<int64_t>(), block_stride,           \
          entry_stride, kv_lora_rank, block_size,                            \
          kv_cache_quant_scale.data_ptr<float>());                           \
    });                                                                      \
  } while (false)
// Executes RoPE on q_pe and k_pe, then writes k_pe and kv_c in the kv cache.
// q_pe and k_pe are modified in place.
// Replaces DeepseekScalingRotaryEmbedding.self.rotary_emb and
// concat_and_cache_mla.
//
// Requirements (validated below):
//   - q_pe, k_pe, kv_c and kv_cache must have unit stride in their last
//     dimension; their outer strides are forwarded to the kernel, so
//     non-contiguous slices such as q[..., nope_dim:] are accepted.
//   - rope_cos_sin_cache must be contiguous: row p holds rot_dim / 2 cosines
//     followed by rot_dim / 2 sines.
//   - rot_dim must be even.
//   - kv_cache_slot_mapping must have exactly num_tokens entries; an entry of
//     -1 skips the cache write for that token.
//   - kv_cache_quant_scale is a single float; it is only read for FP8 caches.
// An empty batch (num_tokens == 0) is a no-op.
void concat_and_cache_mla_rope_fused(
    torch::Tensor& positions,           // [>= num_tokens]
    torch::Tensor& q_pe,                // [num_tokens, num_q_heads, rot_dim]
    torch::Tensor& k_pe,                // [num_tokens, rot_dim]
    torch::Tensor& kv_c,                // [num_tokens, kv_lora_rank]
    torch::Tensor& rope_cos_sin_cache,  // [max_position, rot_dim]
    bool rope_is_neox,
    torch::Tensor& kv_cache_slot_mapping,  // [num_tokens]
    torch::Tensor&
        kv_cache,  // [num_blocks, block_size, (kv_lora_rank + rot_dim)]
    const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale) {
  // Ranks first, so that the size() reads below cannot go out of range.
  TORCH_CHECK_EQ(positions.dim(), 1);
  TORCH_CHECK_EQ(q_pe.dim(), 3);
  TORCH_CHECK_EQ(k_pe.dim(), 2);
  TORCH_CHECK_EQ(kv_c.dim(), 2);
  TORCH_CHECK_EQ(rope_cos_sin_cache.dim(), 2);
  TORCH_CHECK_EQ(kv_cache_slot_mapping.dim(), 1);
  TORCH_CHECK_EQ(kv_cache.dim(), 3);

  const int64_t num_tokens = q_pe.size(0);
  const int num_q_heads = q_pe.size(1);
  const int rot_dim = q_pe.size(2);
  const int kv_lora_rank = kv_c.size(1);

  // The kernel rotates rot_dim / 2 pairs; an odd rot_dim would silently leave
  // the last element unrotated.
  TORCH_CHECK(rot_dim % 2 == 0, "rot_dim must be even, got ", rot_dim);

  // Shapes.
  TORCH_CHECK(positions.size(0) >=
              num_tokens);  // CUDA Graphs might pad this for us
  TORCH_CHECK_EQ(k_pe.size(0), num_tokens);
  TORCH_CHECK_EQ(k_pe.size(1), rot_dim);
  TORCH_CHECK_EQ(kv_c.size(0), num_tokens);
  TORCH_CHECK_EQ(rope_cos_sin_cache.size(1), rot_dim);
  TORCH_CHECK_EQ(kv_cache_slot_mapping.size(0), num_tokens);
  TORCH_CHECK_EQ(kv_cache.size(2), kv_lora_rank + rot_dim);
  TORCH_CHECK_EQ(kv_cache_quant_scale.numel(), 1);

  // Dtypes.
  TORCH_CHECK_EQ(positions.scalar_type(), c10::ScalarType::Long);
  TORCH_CHECK_EQ(k_pe.scalar_type(), q_pe.scalar_type());
  TORCH_CHECK_EQ(kv_c.scalar_type(), q_pe.scalar_type());
  TORCH_CHECK_EQ(rope_cos_sin_cache.scalar_type(), q_pe.scalar_type());
  TORCH_CHECK_EQ(kv_cache_slot_mapping.scalar_type(), c10::ScalarType::Long);
  TORCH_CHECK_EQ(kv_cache_quant_scale.scalar_type(), c10::ScalarType::Float);

  // Memory layout: the kernel indexes the innermost dimension with unit
  // stride and only receives the outer strides.
  TORCH_CHECK_EQ(q_pe.stride(2), 1);
  TORCH_CHECK_EQ(k_pe.stride(1), 1);
  TORCH_CHECK_EQ(kv_c.stride(1), 1);
  TORCH_CHECK_EQ(kv_cache.stride(2), 1);
  // The kernel addresses row `pos` as rope_cos_sin_cache + pos * rot_dim.
  TORCH_CHECK(rope_cos_sin_cache.is_contiguous(),
              "rope_cos_sin_cache must be contiguous");

  // Nothing to do. Returning here also avoids launching an empty grid, which
  // CUDA rejects as an invalid launch configuration.
  if (num_tokens == 0) {
    return;
  }

  int64_t q_pe_stride_token = q_pe.stride(0);
  int64_t q_pe_stride_head = q_pe.stride(1);
  int64_t k_pe_stride = k_pe.stride(0);
  int64_t kv_c_stride = kv_c.stride(0);
  int block_size = kv_cache.size(1);
  int block_stride = kv_cache.stride(0);
  int entry_stride = kv_cache.stride(1);

  // One block per token; enough threads to cover the larger of the Q-rope
  // work and the kv_c copy, capped at 512 (the kernel strides over the rest).
  int rope_block_size = std::min(num_q_heads * rot_dim / 2, 512);
  int mla_block_size = kv_lora_rank;
  int thread_block_size =
      std::min(std::max(rope_block_size, mla_block_size), 512);
  dim3 grid(num_tokens, 1, 1);
  dim3 block(thread_block_size, 1, 1);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(positions));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
                             CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED);
}