Fuse RoPE and MLA KV-cache write (#25774)
Signed-off-by: Patryk Saffer <patryk.saffer99@gmail.com> Signed-off-by: PatrykSaffer <patryk.saffer@mistral.ai> Co-authored-by: Patryk Saffer <patryk.saffer99@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -282,6 +282,7 @@ endif()
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/cache_kernels_fused.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
|
||||
@@ -27,6 +27,13 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
|
||||
const std::string& kv_cache_dtype,
|
||||
torch::Tensor& scale);
|
||||
|
||||
// NOTE: k_pe and kv_c order is flipped compared to concat_and_cache_mla
|
||||
void concat_and_cache_mla_rope_fused(
|
||||
torch::Tensor& positions, torch::Tensor& q_pe, torch::Tensor& k_pe,
|
||||
torch::Tensor& kv_c, torch::Tensor& rope_cos_sin_cache, bool rope_is_neox,
|
||||
torch::Tensor& kv_cache_slot_mapping, torch::Tensor& kv_cache,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const double scale, const std::string& kv_cache_dtype);
|
||||
|
||||
279
csrc/cache_kernels_fused.cu
Normal file
279
csrc/cache_kernels_fused.cu
Normal file
@@ -0,0 +1,279 @@
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "quantization/w8a8/fp8/common.cuh"
|
||||
#ifdef USE_ROCM
|
||||
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
|
||||
#else
|
||||
#include "quantization/w8a8/fp8/nvidia/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// NOTE Be EXTRA careful with raw_kv_scalar_t, for __half and __nv_bfloat16 it's
|
||||
// using u16 as the backing type.
|
||||
template <typename qk_t, bool IS_NEOX, typename raw_kv_scalar_t,
|
||||
typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_mla_rope_fused_kernel(
|
||||
const int64_t* __restrict__ positions, // [num_tokens]
|
||||
qk_t* __restrict__ q_pe, // [num_tokens, num_q_heads, rot_dim]
|
||||
qk_t* __restrict__ k_pe, // [num_tokens, rot_dim]
|
||||
const qk_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const qk_t* __restrict__ rope_cos_sin_cache, // [max_position, 2,
|
||||
// rot_dim // 2]
|
||||
const int rot_dim, const int64_t q_pe_stride_token,
|
||||
const int64_t q_pe_stride_head, const int64_t k_pe_stride,
|
||||
const int64_t kv_c_stride, const int num_q_heads,
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// rot_dim)]
|
||||
const int64_t* __restrict__ kv_cache_slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int entry_stride, const int kv_lora_rank,
|
||||
const int block_size, const float* kv_cache_quant_scale) {
|
||||
// Each thread block is responsible for one token.
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t pos = positions[token_idx];
|
||||
|
||||
const qk_t* cos_sin_ptr = rope_cos_sin_cache + pos * rot_dim;
|
||||
|
||||
const int embed_dim = rot_dim / 2;
|
||||
|
||||
// Q ROPE
|
||||
const int nq = num_q_heads * embed_dim;
|
||||
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
|
||||
int head_idx = i / embed_dim;
|
||||
int pair_idx = i % embed_dim;
|
||||
|
||||
// NOTE: Would be nice to have interleaved sin/cos so we could just load
|
||||
// both at the same time.
|
||||
qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx);
|
||||
qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
|
||||
|
||||
qk_t* q_pe_head_ptr =
|
||||
q_pe + token_idx * q_pe_stride_token + head_idx * q_pe_stride_head;
|
||||
|
||||
int pair_idx_x, pair_idx_y;
|
||||
if constexpr (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
pair_idx_x = pair_idx;
|
||||
pair_idx_y = embed_dim + pair_idx;
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
pair_idx_x = pair_idx * 2;
|
||||
pair_idx_y = pair_idx * 2 + 1;
|
||||
}
|
||||
|
||||
qk_t x_src = q_pe_head_ptr[pair_idx_x];
|
||||
qk_t y_src = q_pe_head_ptr[pair_idx_y];
|
||||
|
||||
qk_t x_dst = x_src * cos - y_src * sin;
|
||||
qk_t y_dst = y_src * cos + x_src * sin;
|
||||
|
||||
q_pe_head_ptr[pair_idx_x] = x_dst;
|
||||
q_pe_head_ptr[pair_idx_y] = y_dst;
|
||||
}
|
||||
|
||||
const int64_t slot_idx = kv_cache_slot_mapping[token_idx];
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t entry_idx = slot_idx % block_size;
|
||||
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// K with 1 HEAD
|
||||
for (int i = threadIdx.x; i < embed_dim; i += blockDim.x) {
|
||||
int pair_idx = i;
|
||||
|
||||
qk_t cos = VLLM_LDG(cos_sin_ptr + pair_idx);
|
||||
qk_t sin = VLLM_LDG(cos_sin_ptr + pair_idx + embed_dim);
|
||||
|
||||
qk_t* k_pe_head_ptr = k_pe + token_idx * k_pe_stride;
|
||||
|
||||
int pair_idx_x, pair_idx_y;
|
||||
if constexpr (IS_NEOX) {
|
||||
// GPT-NeoX style rotary embedding.
|
||||
pair_idx_x = pair_idx;
|
||||
pair_idx_y = embed_dim + pair_idx;
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
pair_idx_x = pair_idx * 2;
|
||||
pair_idx_y = pair_idx * 2 + 1;
|
||||
}
|
||||
|
||||
qk_t x_src = k_pe_head_ptr[pair_idx_x];
|
||||
qk_t y_src = k_pe_head_ptr[pair_idx_y];
|
||||
|
||||
qk_t x_dst = x_src * cos - y_src * sin;
|
||||
qk_t y_dst = y_src * cos + x_src * sin;
|
||||
|
||||
k_pe_head_ptr[pair_idx_x] = x_dst;
|
||||
k_pe_head_ptr[pair_idx_y] = y_dst;
|
||||
|
||||
// NOTE Why is this monster necessary?
|
||||
// When K is of type float16, the actual template replacement for
|
||||
// raw_kv_scalar_t with be u16. That's why it's used at the last moment
|
||||
// otherwise CUDA ALU would break.
|
||||
const raw_kv_scalar_t raw_x_value =
|
||||
*reinterpret_cast<const raw_kv_scalar_t*>(&x_dst);
|
||||
const raw_kv_scalar_t raw_y_value =
|
||||
*reinterpret_cast<const raw_kv_scalar_t*>(&y_dst);
|
||||
|
||||
cache_t* kv_cache_ptr = kv_cache + block_idx * block_stride +
|
||||
entry_idx * entry_stride + kv_lora_rank;
|
||||
|
||||
// MLA Cache Store
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
kv_cache_ptr[pair_idx_x] = raw_x_value;
|
||||
kv_cache_ptr[pair_idx_y] = raw_y_value;
|
||||
} else {
|
||||
kv_cache_ptr[pair_idx_x] =
|
||||
fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
|
||||
raw_x_value, *kv_cache_quant_scale);
|
||||
kv_cache_ptr[pair_idx_y] =
|
||||
fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
|
||||
raw_y_value, *kv_cache_quant_scale);
|
||||
}
|
||||
}
|
||||
|
||||
// NOPE
|
||||
for (int i = threadIdx.x; i < kv_lora_rank; i += blockDim.x) {
|
||||
const qk_t* src_ptr = kv_c + token_idx * kv_c_stride + i;
|
||||
const raw_kv_scalar_t src_value =
|
||||
*reinterpret_cast<const raw_kv_scalar_t*>(src_ptr);
|
||||
|
||||
cache_t* kv_cache_ptr =
|
||||
kv_cache + block_idx * block_stride + entry_idx * entry_stride;
|
||||
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
kv_cache_ptr[i] = src_value;
|
||||
} else {
|
||||
kv_cache_ptr[i] = fp8::scaled_convert<cache_t, raw_kv_scalar_t, kv_dt>(
|
||||
src_value, *kv_cache_quant_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#define CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED(RAW_KV_T, CACHE_T, KV_DTYPE) \
|
||||
do { \
|
||||
VLLM_DISPATCH_FLOATING_TYPES(q_pe.scalar_type(), "qk_scalar_type", [&] { \
|
||||
using qk_t = scalar_t; \
|
||||
if (rope_is_neox) { \
|
||||
vllm::concat_and_cache_mla_rope_fused_kernel<qk_t, true, RAW_KV_T, \
|
||||
CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
|
||||
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
|
||||
rope_cos_sin_cache.data_ptr<qk_t>(), rot_dim, \
|
||||
q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \
|
||||
num_q_heads, reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
kv_cache_slot_mapping.data_ptr<int64_t>(), block_stride, \
|
||||
entry_stride, kv_lora_rank, block_size, \
|
||||
kv_cache_quant_scale.data_ptr<float>()); \
|
||||
} else { \
|
||||
vllm::concat_and_cache_mla_rope_fused_kernel<qk_t, false, RAW_KV_T, \
|
||||
CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
positions.data_ptr<int64_t>(), q_pe.data_ptr<qk_t>(), \
|
||||
k_pe.data_ptr<qk_t>(), kv_c.data_ptr<qk_t>(), \
|
||||
rope_cos_sin_cache.data_ptr<qk_t>(), rot_dim, \
|
||||
q_pe_stride_token, q_pe_stride_head, k_pe_stride, kv_c_stride, \
|
||||
num_q_heads, reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
kv_cache_slot_mapping.data_ptr<int64_t>(), block_stride, \
|
||||
entry_stride, kv_lora_rank, block_size, \
|
||||
kv_cache_quant_scale.data_ptr<float>()); \
|
||||
} \
|
||||
}); \
|
||||
} while (false)
|
||||
|
||||
// Executes RoPE on q_pe and k_pe, then writes k_pe and kv_c in the kv cache.
|
||||
// q_pe and k_pe are modified in place.
|
||||
// Replaces DeepseekScalingRotaryEmbedding.self.rotary_emb and
|
||||
// concat_and_cache_mla.
|
||||
void concat_and_cache_mla_rope_fused(
|
||||
torch::Tensor& positions, // [num_tokens]
|
||||
torch::Tensor& q_pe, // [num_tokens, num_q_heads, rot_dim]
|
||||
torch::Tensor& k_pe, // [num_tokens, rot_dim]
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& rope_cos_sin_cache, // [max_position, rot_dim]
|
||||
bool rope_is_neox,
|
||||
torch::Tensor&
|
||||
kv_cache_slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
torch::Tensor&
|
||||
kv_cache, // [num_blocks, block_size, (kv_lora_rank + rot_dim)]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& kv_cache_quant_scale) {
|
||||
const int64_t num_tokens = q_pe.size(0);
|
||||
|
||||
const int num_q_heads = q_pe.size(1);
|
||||
const int rot_dim = q_pe.size(2);
|
||||
const int kv_lora_rank = kv_c.size(1);
|
||||
|
||||
TORCH_CHECK(positions.size(0) >=
|
||||
num_tokens); // CUDA Graphs might pad this for us
|
||||
TORCH_CHECK_EQ(positions.dim(), 1);
|
||||
TORCH_CHECK_EQ(positions.scalar_type(), c10::ScalarType::Long);
|
||||
|
||||
TORCH_CHECK_EQ(q_pe.size(0), num_tokens);
|
||||
TORCH_CHECK_EQ(q_pe.size(1), num_q_heads);
|
||||
TORCH_CHECK_EQ(q_pe.size(2), rot_dim);
|
||||
TORCH_CHECK_EQ(q_pe.dim(), 3);
|
||||
|
||||
TORCH_CHECK_EQ(k_pe.size(0), num_tokens);
|
||||
TORCH_CHECK_EQ(k_pe.size(1), rot_dim);
|
||||
TORCH_CHECK_EQ(k_pe.dim(), 2);
|
||||
TORCH_CHECK_EQ(k_pe.scalar_type(), q_pe.scalar_type());
|
||||
|
||||
TORCH_CHECK_EQ(kv_c.size(0), num_tokens);
|
||||
TORCH_CHECK_EQ(kv_c.size(1), kv_lora_rank);
|
||||
TORCH_CHECK_EQ(kv_c.dim(), 2);
|
||||
TORCH_CHECK_EQ(kv_c.scalar_type(), q_pe.scalar_type());
|
||||
TORCH_CHECK_EQ(kv_c.dtype(), q_pe.dtype());
|
||||
|
||||
TORCH_CHECK_EQ(rope_cos_sin_cache.size(1), rot_dim);
|
||||
TORCH_CHECK_EQ(rope_cos_sin_cache.scalar_type(), q_pe.scalar_type());
|
||||
|
||||
TORCH_CHECK_EQ(kv_cache_slot_mapping.size(0), num_tokens);
|
||||
TORCH_CHECK_EQ(kv_cache_slot_mapping.scalar_type(), c10::ScalarType::Long);
|
||||
|
||||
TORCH_CHECK_EQ(kv_cache.size(2), kv_lora_rank + rot_dim);
|
||||
TORCH_CHECK_EQ(kv_cache.dim(), 3);
|
||||
|
||||
TORCH_CHECK_EQ(kv_cache_quant_scale.numel(), 1);
|
||||
TORCH_CHECK_EQ(kv_cache_quant_scale.scalar_type(), c10::ScalarType::Float);
|
||||
|
||||
int64_t q_pe_stride_token = q_pe.stride(0);
|
||||
int64_t q_pe_stride_head = q_pe.stride(1);
|
||||
|
||||
int64_t k_pe_stride = k_pe.stride(0);
|
||||
int64_t kv_c_stride = kv_c.stride(0);
|
||||
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
int block_stride = kv_cache.stride(0);
|
||||
int entry_stride = kv_cache.stride(1);
|
||||
|
||||
int rope_block_size = std::min(num_q_heads * rot_dim / 2, 512);
|
||||
int mla_block_size = kv_lora_rank;
|
||||
int thread_block_size =
|
||||
std::min(std::max(rope_block_size, mla_block_size), 512);
|
||||
|
||||
dim3 grid(num_tokens, 1, 1);
|
||||
dim3 block(thread_block_size, 1, 1);
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(positions));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA_ROPE_FUSED);
|
||||
}
|
||||
@@ -737,6 +737,22 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
|
||||
|
||||
// Rotate Q and K, then write to kv cache for MLA
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla_rope_fused("
|
||||
" Tensor positions,"
|
||||
" Tensor! q_pe,"
|
||||
" Tensor! k_pe,"
|
||||
" Tensor kv_c,"
|
||||
" Tensor cos_sin_cache,"
|
||||
" bool is_neox,"
|
||||
" Tensor slot_mapping,"
|
||||
" Tensor! kv_cache,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor kv_cache_scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla_rope_fused", torch::kCUDA,
|
||||
&concat_and_cache_mla_rope_fused);
|
||||
|
||||
// Convert the key and value cache to fp8 data type.
|
||||
cache_ops.def(
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||
|
||||
161
tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
Normal file
161
tests/kernels/core/test_rotary_embedding_mla_cache_fused.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tests for fused MLA KV-cache write and RoPE fused kernel
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from tests.kernels.utils import DEFAULT_OPCHECK_TEST_UTILS, opcheck
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16, torch.float])
|
||||
@pytest.mark.parametrize("is_neox_style", [False, True])
|
||||
@pytest.mark.parametrize("seq_len", [11, 42])
|
||||
@pytest.mark.parametrize("qk_rope_head_dim", [64, 128])
|
||||
@pytest.mark.parametrize("num_q_heads", [128])
|
||||
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
|
||||
@pytest.mark.parametrize("kv_lora_rank", [512])
|
||||
@pytest.mark.parametrize("num_blocks", [64])
|
||||
@pytest.mark.parametrize("block_size", [16, 64, 256])
|
||||
@pytest.mark.parametrize("seed", [0])
|
||||
@pytest.mark.parametrize(
|
||||
"device", [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
)
|
||||
@torch.inference_mode()
|
||||
def test_concat_and_cache_mla_rope_fused(
|
||||
dtype: torch.dtype,
|
||||
is_neox_style: bool,
|
||||
seq_len: int,
|
||||
qk_rope_head_dim: int,
|
||||
num_q_heads: int,
|
||||
kv_cache_dtype: str,
|
||||
kv_lora_rank: int,
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
seed: int,
|
||||
device: str,
|
||||
max_position: int = 8192,
|
||||
base: float = 10000,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
rope = RotaryEmbedding(
|
||||
qk_rope_head_dim,
|
||||
qk_rope_head_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
torch.float32,
|
||||
)
|
||||
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (seq_len,))
|
||||
|
||||
query = torch.randn(seq_len, num_q_heads, qk_rope_head_dim, dtype=dtype)
|
||||
key = torch.randn(seq_len, 1, qk_rope_head_dim + kv_lora_rank, dtype=dtype)
|
||||
|
||||
k_pe = torch.flatten(key[..., :qk_rope_head_dim], start_dim=1).to(device=device)
|
||||
kv_c = torch.flatten(key[..., qk_rope_head_dim:], start_dim=1).to(device=device)
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
# because the custom kernel is in-place.
|
||||
ref_q_pe, ref_k_pe = rope.forward_native(positions, query, k_pe)
|
||||
assert ref_k_pe is not None
|
||||
|
||||
ref_k_pe = torch.flatten(ref_k_pe, start_dim=1).to(device=device)
|
||||
ref_k_rope = ref_k_pe[..., :qk_rope_head_dim]
|
||||
|
||||
total_available_slots = num_blocks * block_size
|
||||
total_needed_slots = seq_len
|
||||
assert total_available_slots >= total_needed_slots, "Not enough kv slots!"
|
||||
|
||||
slot_mapping_lst = random.sample(range(total_available_slots), total_needed_slots)
|
||||
slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)
|
||||
|
||||
entry_size = kv_lora_rank + qk_rope_head_dim
|
||||
|
||||
kv_cache_scale = torch.tensor([0.1], dtype=torch.float32, device=device)
|
||||
|
||||
kv_cache = torch.zeros(
|
||||
num_blocks,
|
||||
block_size,
|
||||
entry_size,
|
||||
dtype=torch.uint8 if kv_cache_dtype == "fp8" else dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
ref_temp = torch.zeros(*kv_cache.shape, dtype=dtype, device=device)
|
||||
|
||||
for i in range(seq_len):
|
||||
slot = slot_mapping[i].item()
|
||||
block_idx = slot // block_size
|
||||
block_offset = slot % block_size
|
||||
ref_temp[block_idx, block_offset] = torch.cat((kv_c[i], ref_k_rope[i]), -1)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
ref_kv_cache = torch.empty_like(ref_temp, dtype=kv_cache.dtype)
|
||||
ops.convert_fp8(
|
||||
ref_kv_cache, ref_temp, kv_cache_scale.item(), kv_dtype=kv_cache_dtype
|
||||
)
|
||||
else:
|
||||
ref_kv_cache = ref_temp
|
||||
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused,
|
||||
(
|
||||
positions,
|
||||
query,
|
||||
k_pe,
|
||||
kv_c,
|
||||
rope.cos_sin_cache,
|
||||
is_neox_style,
|
||||
slot_mapping,
|
||||
kv_cache,
|
||||
kv_cache_dtype,
|
||||
kv_cache_scale,
|
||||
),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.concat_and_cache_mla_rope_fused(
|
||||
positions,
|
||||
query,
|
||||
k_pe,
|
||||
kv_c,
|
||||
rope.cos_sin_cache,
|
||||
is_neox_style,
|
||||
slot_mapping,
|
||||
kv_cache,
|
||||
kv_cache_dtype,
|
||||
kv_cache_scale,
|
||||
)
|
||||
|
||||
if kv_cache_dtype == "fp8":
|
||||
result_temp = torch.empty_like(kv_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
result_temp,
|
||||
kv_cache.contiguous(),
|
||||
kv_cache_scale.item(),
|
||||
kv_dtype=kv_cache_dtype,
|
||||
)
|
||||
expected_temp = torch.empty_like(ref_kv_cache, dtype=torch.float16)
|
||||
ops.convert_fp8(
|
||||
expected_temp, ref_kv_cache, kv_cache_scale.item(), kv_dtype=kv_cache_dtype
|
||||
)
|
||||
torch.testing.assert_close(result_temp, expected_temp, atol=0.001, rtol=0.1)
|
||||
else:
|
||||
torch.testing.assert_close(kv_cache, ref_kv_cache)
|
||||
|
||||
torch.testing.assert_close(
|
||||
query, ref_q_pe, atol=get_default_atol(query), rtol=get_default_rtol(query)
|
||||
)
|
||||
@@ -2418,6 +2418,32 @@ def concat_and_cache_mla(
|
||||
)
|
||||
|
||||
|
||||
def concat_and_cache_mla_rope_fused(
|
||||
positions: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c: torch.Tensor,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
kv_cache_scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla_rope_fused(
|
||||
positions,
|
||||
q_pe,
|
||||
k_pe,
|
||||
kv_c,
|
||||
cos_sin_cache,
|
||||
is_neox,
|
||||
slot_mapping,
|
||||
kv_cache,
|
||||
kv_cache_dtype,
|
||||
kv_cache_scale,
|
||||
)
|
||||
|
||||
|
||||
def swap_blocks(
|
||||
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user