diff --git a/CMakeLists.txt b/CMakeLists.txt
index c46fb18d7..ec67ee8c3 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -282,6 +282,7 @@ endif()
 set(VLLM_EXT_SRC
   "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
   "csrc/cache_kernels.cu"
+  "csrc/cache_kernels_fused.cu"
   "csrc/attention/paged_attention_v1.cu"
   "csrc/attention/paged_attention_v2.cu"
   "csrc/attention/merge_attn_states.cu"
diff --git a/csrc/cache.h b/csrc/cache.h
index 42ccb5896..d14f46c34 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -27,6 +27,13 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                           const std::string& kv_cache_dtype,
                           torch::Tensor& scale);
 
+// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
+void concat_and_cache_mla_rope_fused(
+    torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
+    torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
+    torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
+    const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
+
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
diff --git a/csrc/cache_kernels_fused.cu b/csrc/cache_kernels_fused.cu
new file mode 100644
index 000000000..be037b2fd
--- /dev/null
+++ b/csrc/cache_kernels_fused.cu
@@ -0,0 +1,305 @@
+#include <torch/all.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include "cuda_compat.h"
+#include "dispatch_utils.h"
+
+#include "quantization/w8a8/fp8/common.cuh"
+#ifdef USE_ROCM
+  #include "quantization/w8a8/fp8/amd/quant_utils.cuh"
+#else
+  #include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
+#endif
+
+#ifdef USE_ROCM
+  #include <hip/hip_bf16.h>
+typedef __hip_bfloat16 __nv_bfloat16;
+#endif
+
+namespace vllm {
+
+// Fused RoPE + MLA KV-cache write. One thread block handles one token
+// (blockIdx.x == token index) and its threads stride over the work items:
+//   1. rotate every (x, y) pair of every q_pe head in place,
+//   2. rotate k_pe in place and store the rotated values at offset
+//      kv_lora_rank of the token's cache entry,
+//   3. copy kv_c into the first kv_lora_rank elements of the cache entry.
+// Steps 2 and 3 are skipped when the token's slot index is negative (padding).
+// No shared memory is used. All innermost dimensions are indexed with unit
+// stride.
+//
+// NOTE Be EXTRA careful with raw_kv_scalar_t, for __half and __nv_bfloat16 it's
+// using u16 as the backing type.
+template <typename qk_t, bool IS_NEOX, typename raw_kv_scalar_t,
+          typename cache_t, Fp8KVCacheDataType kv_dt>
+__global__ void concat_and_cache_mla_rope_fused_kernel(
+    const int64_t* __restrict__ positions,  // [num_tokens]
+    qk_t* __restrict__ q_pe,  // [num_tokens, num_q_heads, rot_dim]
+    qk_t* __restrict__ k_pe,  // [num_tokens, rot_dim]
+    const qk_t* __restrict__ kv_c,  // [num_tokens, kv_lora_rank]
+    const qk_t* __restrict__ rope_cos_sin_cache,  // [max_position, 2,
+                                                  // rot_dim // 2]
+    const int rot_dim, const int64_t q_pe_stride_token,
+    const int64_t q_pe_stride_head, const int64_t k_pe_stride,
+    const int64_t kv_c_stride, const int num_q_heads,
+    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank +
+                                     // rot_dim)]
+    const int64_t* __restrict__ kv_cache_slot_mapping,  // [num_tokens]
+    const int block_stride, const int entry_stride, const int kv_lora_rank,
+    const int block_size, const float* kv_cache_quant_scale) {
+  // Each thread block is responsible for one token.
+  const int64_t token_idx = blockIdx.x;
+  const int64_t pos = positions[token_idx];
+
+  const qk_t* cos_sin_ptr = rope_cos_sin_cache + pos * rot_dim;
+
+  const int embed_dim = rot_dim / 2;
+
+  // Q ROPE
+  const int nq = num_q_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    int head_idx = i / embed_dim;
+    int pair_idx = i % embed_dim;
+
+    // NOTE: Would be nice to have interleaved sin/cos so we could just load
+    // both at the same time.
+    qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx);
+    qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
+
+    qk_t* q_pe_head_ptr =
+        q_pe + token_idx * q_pe_stride_token + head_idx * q_pe_stride_head;
+
+    int pair_idx_x, pair_idx_y;
+    if constexpr (IS_NEOX) {
+      // GPT-NeoX style rotary embedding.
+      pair_idx_x = pair_idx;
+      pair_idx_y = embed_dim + pair_idx;
+    } else {
+      // GPT-J style rotary embedding.
+      pair_idx_x = pair_idx * 2;
+      pair_idx_y = pair_idx * 2 + 1;
+    }
+
+    qk_t x_src = q_pe_head_ptr[pair_idx_x];
+    qk_t y_src = q_pe_head_ptr[pair_idx_y];
+
+    qk_t x_dst = x_src * cos - y_src * sin;
+    qk_t y_dst = y_src * cos + x_src * sin;
+
+    q_pe_head_ptr[pair_idx_x] = x_dst;
+    q_pe_head_ptr[pair_idx_y] = y_dst;
+  }
+
+  const int64_t slot_idx = kv_cache_slot_mapping[token_idx];
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t entry_idx = slot_idx % block_size;
+
+  // NOTE: slot_idx can be -1 if the token is padded
+  if (slot_idx < 0) {
+    return;
+  }
+
+  // K with 1 HEAD
+  for (int i = threadIdx.x; i < embed_dim; i += blockDim.x) {
+    int pair_idx = i;
+
+    qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx);
+    qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
+
+    qk_t* k_pe_head_ptr = k_pe + token_idx * k_pe_stride;
+
+    int pair_idx_x, pair_idx_y;
+    if constexpr (IS_NEOX) {
+      // GPT-NeoX style rotary embedding.
+      pair_idx_x = pair_idx;
+      pair_idx_y = embed_dim + pair_idx;
+    } else {
+      // GPT-J style rotary embedding.
+      pair_idx_x = pair_idx * 2;
+      pair_idx_y = pair_idx * 2 + 1;
+    }
+
+    qk_t x_src = k_pe_head_ptr[pair_idx_x];
+    qk_t y_src = k_pe_head_ptr[pair_idx_y];
+
+    qk_t x_dst = x_src * cos - y_src * sin;
+    qk_t y_dst = y_src * cos + x_src * sin;
+
+    k_pe_head_ptr[pair_idx_x] = x_dst;
+    k_pe_head_ptr[pair_idx_y] = y_dst;
+
+    // NOTE Why is this monster necessary?
+    // When K is of type float16, the actual template replacement for
+    // raw_kv_scalar_t will be u16. That's why it's used at the last moment
+    // otherwise CUDA ALU would break.
+    const raw_kv_scalar_t raw_x_value =
+        *reinterpret_cast<const raw_kv_scalar_t*>(&x_dst);
+    const raw_kv_scalar_t raw_y_value =
+        *reinterpret_cast<const raw_kv_scalar_t*>(&y_dst);
+
+    cache_t* kv_cache_ptr = kv_cache + block_idx * block_stride +
+                            entry_idx * entry_stride + kv_lora_rank;
+
+    // MLA Cache Store
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      kv_cache_ptr[pair_idx_x] = raw_x_value;
+      kv_cache_ptr[pair_idx_y] = raw_y_value;
+    } else {
+      kv_cache_ptr[pair_idx_x] =
+          fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
+              raw_x_value, *kv_cache_quant_scale);
+      kv_cache_ptr[pair_idx_y] =
+          fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
+              raw_y_value, *kv_cache_quant_scale);
+    }
+  }
+
+  // NOPE
+  for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) {
+    const qk_t* src_ptr = kv_c + token_idx * kv_c_stride + i;
+    const raw_kv_scalar_t src_value =
+        *reinterpret_cast<const raw_kv_scalar_t*>(src_ptr);
+
+    cache_t* kv_cache_ptr =
+        kv_cache + block_idx * block_stride + entry_idx * entry_stride;
+
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      kv_cache_ptr[i] = src_value;
+    } else {
+      kv_cache_ptr[i] = fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
+          src_value, *kv_cache_quant_scale);
+    }
+  }
+}
+
+}  // namespace vllm
+
+// Instantiates and launches the fused kernel for one (raw KV type, cache type,
+// KV-cache dtype) combination; expects the host function's locals in scope.
+#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE)      \
+  do {                                                                         \
+    VLLM_DISPATCH_FLOATING_TYPES(q_pe.scalar_type(), "qk_scalar_type", [&] {   \
+      using qk_t = scalar_t;                                                   \
+      if (rope_is_neox) {                                                      \
+        vllm::concat_and_cache_mla_rope_fused_kernel<qk_t, true, RAW_KV_T,     \
+                                                     CACHE_T, KV_DTYPE>        \
+            <<<grid, block, 0, stream>>>(                                      \
+                positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(),          \
+                k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(),                  \
+                rope_cos_sin_cache.data_ptr<qk_t>(), rot_dim,                  \
+                q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \
+                num_q_heads, reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),  \
+                kv_cache_slot_mapping.data_ptr<int64_t>(), block_stride,       \
+                entry_stride, kv_lora_rank, block_size,                        \
+                kv_cache_quant_scale.data_ptr<float>());                       \
+      } else {                                                                 \
+        vllm::concat_and_cache_mla_rope_fused_kernel<qk_t, false, RAW_KV_T,    \
+                                                     CACHE_T, KV_DTYPE>        \
+            <<<grid, block, 0, stream>>>(                                      \
+                positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(),          \
+                k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(),                  \
+                rope_cos_sin_cache.data_ptr<qk_t>(), rot_dim,                  \
+                q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \
+                num_q_heads, reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),  \
+                kv_cache_slot_mapping.data_ptr<int64_t>(), block_stride,       \
+                entry_stride, kv_lora_rank, block_size,                        \
+                kv_cache_quant_scale.data_ptr<float>());                       \
+      }                                                                        \
+    });                                                                        \
+  } while (false)
+
+// Executes RoPE on q_pe and k_pe, then writes k_pe and kv_c in the kv cache.
+// q_pe and k_pe are modified in place.
+// Replaces DeepseekScalingRotaryEmbedding.self.rotary_emb and
+// concat_and_cache_mla.
+// Innermost dimensions of q_pe, k_pe, kv_c and kv_cache must have unit stride
+// and rope_cos_sin_cache must be contiguous. An empty batch is a no-op.
+void concat_and_cache_mla_rope_fused(
+    torch::Tensor& positions,           // [num_tokens] (may be padded longer)
+    torch::Tensor& q_pe,                // [num_tokens, num_q_heads, rot_dim]
+    torch::Tensor& k_pe,                // [num_tokens, rot_dim]
+    torch::Tensor& kv_c,                // [num_tokens, kv_lora_rank]
+    torch::Tensor& rope_cos_sin_cache,  // [max_position, rot_dim]
+    bool rope_is_neox,
+    torch::Tensor& kv_cache_slot_mapping,  // [num_tokens], < 0 means padding
+    torch::Tensor&
+        kv_cache,  // [num_blocks, block_size, (kv_lora_rank + rot_dim)]
+    const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale) {
+  // Check ranks before any size()/stride() access so a wrongly shaped input
+  // is reported by these checks rather than by an out-of-range dimension.
+  TORCH_CHECK_EQ(positions.dim(), 1);
+  TORCH_CHECK_EQ(q_pe.dim(), 3);
+  TORCH_CHECK_EQ(k_pe.dim(), 2);
+  TORCH_CHECK_EQ(kv_c.dim(), 2);
+  TORCH_CHECK_EQ(kv_cache.dim(), 3);
+
+  const int64_t num_tokens = q_pe.size(0);
+
+  const int num_q_heads = q_pe.size(1);
+  const int rot_dim = q_pe.size(2);
+  const int kv_lora_rank = kv_c.size(1);
+
+  TORCH_CHECK(positions.size(0) >=
+              num_tokens);  // CUDA Graphs might pad this for us
+  TORCH_CHECK_EQ(positions.scalar_type(), c10::ScalarType::Long);
+
+  // The kernel rotates (x, y) pairs, so the rotary dimension must be even.
+  TORCH_CHECK_EQ(rot_dim % 2, 0);
+  // The kernel indexes the innermost dimension with unit stride.
+  TORCH_CHECK_EQ(q_pe.stride(2), 1);
+
+  TORCH_CHECK_EQ(k_pe.size(0), num_tokens);
+  TORCH_CHECK_EQ(k_pe.size(1), rot_dim);
+  TORCH_CHECK_EQ(k_pe.stride(1), 1);
+  TORCH_CHECK_EQ(k_pe.scalar_type(), q_pe.scalar_type());
+
+  TORCH_CHECK_EQ(kv_c.size(0), num_tokens);
+  TORCH_CHECK_EQ(kv_c.stride(1), 1);
+  TORCH_CHECK_EQ(kv_c.scalar_type(), q_pe.scalar_type());
+
+  TORCH_CHECK_EQ(rope_cos_sin_cache.size(1), rot_dim);
+  // Rows are addressed as base + pos * rot_dim in the kernel.
+  TORCH_CHECK(rope_cos_sin_cache.is_contiguous());
+  TORCH_CHECK_EQ(rope_cos_sin_cache.scalar_type(), q_pe.scalar_type());
+
+  TORCH_CHECK_EQ(kv_cache_slot_mapping.size(0), num_tokens);
+  TORCH_CHECK_EQ(kv_cache_slot_mapping.scalar_type(), c10::ScalarType::Long);
+
+  TORCH_CHECK_EQ(kv_cache.size(2), kv_lora_rank + rot_dim);
+  TORCH_CHECK_EQ(kv_cache.stride(2), 1);
+
+  TORCH_CHECK_EQ(kv_cache_quant_scale.numel(), 1);
+  TORCH_CHECK_EQ(kv_cache_quant_scale.scalar_type(), c10::ScalarType::Float);
+
+  // A grid dimension of 0 is an invalid launch configuration, and there is
+  // nothing to rotate or cache for an empty batch.
+  if (num_tokens == 0) {
+    return;
+  }
+
+  int64_t q_pe_stride_token = q_pe.stride(0);
+  int64_t q_pe_stride_head = q_pe.stride(1);
+
+  int64_t k_pe_stride = k_pe.stride(0);
+  int64_t kv_c_stride = kv_c.stride(0);
+
+  int block_size = kv_cache.size(1);
+
+  int block_stride = kv_cache.stride(0);
+  int entry_stride = kv_cache.stride(1);
+
+  int rope_block_size = std::min(num_q_heads * rot_dim / 2, 512);
+  int mla_block_size = kv_lora_rank;
+  int thread_block_size =
+      std::min(std::max(rope_block_size, mla_block_size), 512);
+
+  dim3 grid(num_tokens, 1, 1);
+  dim3 block(thread_block_size, 1, 1);
+
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(positions));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
+                             CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED);
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 38ff4e54a..864be7a26 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -737,6 +737,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       " Tensor scale) -> ()");
   cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
 
+  // Rotate Q and K, then write to kv cache for MLA
+  cache_ops.def(
+      "concat_and_cache_mla_rope_fused("
+      " Tensor positions,"
+      " Tensor! q_pe,"
+      " Tensor! k_pe,"
+      " Tensor kv_c,"
+      " Tensor cos_sin_cache,"
+      " bool is_neox,"
+      " Tensor slot_mapping,"
+      " Tensor! kv_cache,"
+      " str kv_cache_dtype,"
+      " Tensor kv_cache_scale) -> ()");
+  cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
+                 &concat_and_cache_mla_rope_fused);
+
   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
       "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
diff --git a/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py b/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
new file mode 100644
index 000000000..1a08eca4f
--- /dev/null
+++ b/tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tests for fused MLA KV-cache write and RoPE fused kernel
+"""
+
+import random
+
+import pytest
+import torch
+
+from tests.kernels.allclose_default import get_default_atol, get_default_rtol
+from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.platforms import current_platform
+
+
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
+@pytest.mark.parametrize("is_neox_style", [False, True])
+@pytest.mark.parametrize("seq_len", [11, 42])
+@pytest.mark.parametrize("qk_rope_head_dim", [64, 128])
+@pytest.mark.parametrize("num_q_heads", [128])
+@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
+@pytest.mark.parametrize("kv_lora_rank", [512])
+@pytest.mark.parametrize("num_blocks", [64])
+@pytest.mark.parametrize("block_size", [16, 64, 256])
+@pytest.mark.parametrize("seed", [0])
+@pytest.mark.parametrize(
+    "device", [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+)
+@torch.inference_mode()
+def test_concat_and_cache_mla_rope_fused(
+    dtype: torch.dtype,
+    is_neox_style: bool,
+    seq_len: int,
+    qk_rope_head_dim: int,
+    num_q_heads: int,
+    kv_cache_dtype: str,
+    kv_lora_rank: int,
+    num_blocks: int,
+    block_size: int,
+    seed: int,
+    device: str,
+    max_position: int = 8192,
+    base: float = 10000,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    rope = RotaryEmbedding(
+        qk_rope_head_dim,
+        qk_rope_head_dim,
+        max_position,
+        base,
+        is_neox_style,
+        torch.float32,
+    )
+
+    rope = rope.to(dtype=dtype, device=torch.get_default_device())
+
+    positions = torch.randint(0, max_position, (seq_len,))
+
+    query = torch.randn(seq_len, num_q_heads, qk_rope_head_dim, dtype=dtype)
+    key = torch.randn(seq_len, 1, qk_rope_head_dim + kv_lora_rank, dtype=dtype)
+
+    k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device)
+    kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device)
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
+    assert ref_k_pe is not None
+
+    ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device)
+    ref_k_rope = ref_k_pe[..., :qk_rope_head_dim]
+
+    total_available_slots = num_blocks * block_size
+    total_needed_slots = seq_len
+    assert total_available_slots >= total_needed_slots, "Not enough kv slots!"
+
+    slot_mapping_lst = random.sample(range(total_available_slots), total_needed_slots)
+    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
+
+    entry_size = kv_lora_rank + qk_rope_head_dim
+
+    kv_cache_scale = torch.tensor([0.1], dtype=torch.float32, device=device)
+
+    kv_cache = torch.zeros(
+        num_blocks,
+        block_size,
+        entry_size,
+        dtype=torch.uint8 if kv_cache_dtype == "fp8" else dtype,
+        device=device,
+    )
+
+    ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
+
+    for i in range(seq_len):
+        slot = slot_mapping[i].item()
+        block_idx = slot // block_size
+        block_offset = slot % block_size
+        ref_temp[block_idx, block_offset] = torch.cat((kv_c[i], ref_k_rope[i]), -1)
+
+    if kv_cache_dtype == "fp8":
+        ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
+        ops.convert_fp8(
+            ref_kv_cache, ref_temp, kv_cache_scale.item(), kv_dtype=kv_cache_dtype
+        )
+    else:
+        ref_kv_cache = ref_temp
+
+    opcheck(
+        torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused,
+        (
+            positions,
+            query,
+            k_pe,
+            kv_c,
+            rope.cos_sin_cache,
+            is_neox_style,
+            slot_mapping,
+            kv_cache,
+            kv_cache_dtype,
+            kv_cache_scale,
+        ),
+        test_utils=DEFAULT_OPCHECK_TEST_UTILS,
+    )
+
+    ops.concat_and_cache_mla_rope_fused(
+        positions,
+        query,
+        k_pe,
+        kv_c,
+        rope.cos_sin_cache,
+        is_neox_style,
+        slot_mapping,
+        kv_cache,
+        kv_cache_dtype,
+        kv_cache_scale,
+    )
+
+    if kv_cache_dtype == "fp8":
+        result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
+        ops.convert_fp8(
+            result_temp,
+            kv_cache.contiguous(),
+            kv_cache_scale.item(),
+            kv_dtype=kv_cache_dtype,
+        )
+        expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
+        ops.convert_fp8(
+            expected_temp, ref_kv_cache, kv_cache_scale.item(), kv_dtype=kv_cache_dtype
+        )
+        torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
+    else:
+        torch.testing.assert_close(kv_cache, ref_kv_cache)
+
+    torch.testing.assert_close(
+        query, ref_q_pe, atol=get_default_atol(query), rtol=get_default_rtol(query)
+    )
+
+    # The op also rotates k_pe in place; compare it with the reference RoPE.
+    torch.testing.assert_close(
+        k_pe, ref_k_rope, atol=get_default_atol(k_pe), rtol=get_default_rtol(k_pe)
+    )
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index c588798ae..86d6e309b 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2418,6 +2418,32 @@ def concat_and_cache_mla(
     )
 
 
+def concat_and_cache_mla_rope_fused(
+    positions: torch.Tensor,
+    q_pe: torch.Tensor,
+    k_pe: torch.Tensor,
+    kv_c: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    is_neox: bool,
+    slot_mapping: torch.Tensor,
+    kv_cache: torch.Tensor,
+    kv_cache_dtype: str,
+    kv_cache_scale: torch.Tensor,
+) -> None:
+    torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused(
+        positions,
+        q_pe,
+        k_pe,
+        kv_c,
+        cos_sin_cache,
+        is_neox,
+        slot_mapping,
+        kv_cache,
+        kv_cache_dtype,
+        kv_cache_scale,
+    )
+
+
 def swap_blocks(
     src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
 ) -> None: