[Attention] MLA with chunked prefill (#12639)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Patrick Horn <patrick.horn@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Lucas Wilkinson
2025-02-21 18:30:12 -05:00
committed by GitHub
parent 900edbfa48
commit 288cc6c234
18 changed files with 1910 additions and 1275 deletions

View File

@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cuda_utils.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
@@ -570,3 +571,161 @@ void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype);
}
}
namespace vllm {

// Copies the cache entries of each sequence out of the block-structured
// `src_cache` into the densely packed `dst`.
//
// grid is launched with dimensions (batch, num_splits):
//   - blockIdx.x selects the sequence (batch index),
//   - blockIdx.y selects which contiguous slice of that sequence's cache
//     blocks this thread block copies.
// Within a thread block, threads stride over the `entry_size` elements of one
// entry at a time (step = blockDim.x), so any 1-D block size works.
//
// Sequence `bid` covers dst rows [cu_seq_lens[bid], cu_seq_lens[bid + 1]).
// Its cache blocks are looked up in row `bid` of `block_table`, starting at
// column (seq_starts[bid] / block_size) when `seq_starts` is non-null and at
// column 0 otherwise. The division is an integer division: any remainder of
// seq_starts[bid] inside its first block is ignored by this kernel.
//
// The elements of a single entry are read/written as `entry_size` consecutive
// elements starting at the entry pointer; only the block and entry strides
// are passed in explicitly.
template <typename scalar_t>
__global__ void gather_cache(
    const scalar_t* __restrict__ src_cache,   // [NUM_BLOCKS, BLOCK_SIZE,
                                              //  ENTRIES...]
    scalar_t* __restrict__ dst,               // [TOT_TOKENS, ENTRIES...]
    const int32_t* __restrict__ block_table,  // [BATCH, BLOCK_INDICES]
    const int32_t* __restrict__ cu_seq_lens,  // [BATCH+1]
    const int32_t block_size, const int32_t entry_size,
    const int64_t block_table_stride, const int64_t cache_block_stride,
    const int64_t cache_entry_stride, const int64_t dst_entry_stride,
    const int32_t* __restrict__ seq_starts) {  // Optional: starting offsets
                                               // per batch
  const int64_t bid = blockIdx.x;  // Batch ID
  const int32_t num_splits = gridDim.y;
  const int32_t split = blockIdx.y;

  const int32_t seq_start = cu_seq_lens[bid];
  const int32_t seq_end = cu_seq_lens[bid + 1];
  const int32_t seq_len = seq_end - seq_start;

  // Partition the sequence's cache blocks evenly across the splits.
  const int32_t tot_blocks = cuda_utils::ceil_div(seq_len, block_size);
  const int32_t split_blocks = cuda_utils::ceil_div(tot_blocks, num_splits);
  const int32_t split_start = split * split_blocks;
  const int32_t split_end = min((split + 1) * split_blocks, tot_blocks);

  // Splits that start past the last block (short or empty sequences) have
  // nothing to copy.
  if (split_start >= tot_blocks) return;

  // Locate this sequence's row in the block table. The offset is kept in
  // 64 bits: both factors are 64-bit and the product must not be narrowed.
  const int64_t batch_offset = bid * block_table_stride;
  int64_t offset = 0;
  if (seq_starts != nullptr) {
    offset = seq_starts[bid] / block_size;
  }
  const int32_t* batch_block_table = block_table + batch_offset + offset;

  // Advance dst to the first output row of this sequence.
  dst += seq_start * dst_entry_stride;

  // Only the very last block of the sequence can be partially filled; a
  // remainder of 0 means that block is full as well.
  const int32_t tail_entries = seq_len % block_size;
  const int32_t last_block = tot_blocks - 1;

  for (int32_t pid = split_start; pid < split_end; ++pid) {
    const int32_t num_entries =
        (pid == last_block && tail_entries != 0) ? tail_entries : block_size;

    const int64_t block_id = batch_block_table[pid];
    const scalar_t* block_src = src_cache + block_id * cache_block_stride;
    scalar_t* block_dst =
        dst + static_cast<int64_t>(pid) * block_size * dst_entry_stride;

    for (int32_t eid = 0; eid < num_entries; ++eid) {
      const scalar_t* __restrict__ entry_src =
          block_src + eid * cache_entry_stride;
      scalar_t* __restrict__ entry_dst = block_dst + eid * dst_entry_stride;
      // Threads of the block cooperatively copy one entry.
      for (int i = threadIdx.x; i < entry_size; i += blockDim.x) {
        entry_dst[i] = entry_src[i];
      }
    }
  }
}

}  // namespace vllm
// Macro to dispatch the kernel based on the data type.
//
// Launches vllm::gather_cache<CPY_DTYPE> with launch configuration
// <<<grid, block, 0, stream>>>. The storage of `src_cache` and `dst` is
// reinterpreted as CPY_DTYPE, so CPY_DTYPE only has to match the element
// width of the tensors, not their logical dtype.
//
// The expansion refers to these names, which must exist in the enclosing
// scope: grid, block, stream, src_cache, dst, block_table, cu_seq_lens,
// block_size, entry_size, block_table_stride, cache_block_stride,
// cache_entry_stride, dst_entry_stride, seq_starts_ptr.
//
// NOTE: the expansion already ends with a semicolon.
#define CALL_GATHER_CACHE(CPY_DTYPE) \
vllm::gather_cache<CPY_DTYPE><<<grid, block, 0, stream>>>( \
reinterpret_cast<CPY_DTYPE*>(src_cache.data_ptr()), \
reinterpret_cast<CPY_DTYPE*>(dst.data_ptr()), \
block_table.data_ptr<int32_t>(), cu_seq_lens.data_ptr<int32_t>(), \
block_size, entry_size, block_table_stride, cache_block_stride, \
cache_entry_stride, dst_entry_stride, seq_starts_ptr);
// Gather sequences from the cache into the destination tensor.
//  - cu_seq_lens contains the cumulative sequence lengths for each batch;
//    sequence `bid` is written to dst rows
//    [cu_seq_lens[bid], cu_seq_lens[bid + 1])
//  - block_table contains the cache block indices for each sequence
//  - Optionally, seq_starts (if provided) offsets the starting block index by
//    (seq_starts[bid] / block_size), using integer division
//
// block_table, cu_seq_lens and seq_starts must be int32; src_cache and dst
// must share a dtype whose element width is 8, 16 or 32 bits. All tensors
// must live on the same device as src_cache. The kernel is enqueued on the
// current CUDA stream of that device. A batch_size of 0 is a no-op.
//
// Raises (via TORCH_CHECK) on dtype/device mismatch, on a negative
// batch_size, on a cu_seq_lens that is too short for batch_size, and on an
// unsupported element width.
void gather_cache(
    torch::Tensor const& src_cache,    // [NUM_BLOCKS, BLOCK_SIZE, ENTRIES...]
    torch::Tensor const& dst,          // [TOT_TOKENS, ENTRIES...]
    torch::Tensor const& block_table,  // [BATCH, BLOCK_INDICES]
    torch::Tensor const& cu_seq_lens,  // [BATCH+1]
    int64_t batch_size,
    std::optional<torch::Tensor> seq_starts = std::nullopt) {
  at::cuda::OptionalCUDAGuard device_guard(src_cache.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int32_t block_size = src_cache.size(1);
  // Number of elements per cache entry (all trailing dims collapsed).
  int32_t entry_size = src_cache.flatten(2, -1).size(2);

  // ---- Argument validation (everything is checked before any launch). ----
  TORCH_CHECK(block_table.dtype() == torch::kInt32,
              "block_table must be int32");
  TORCH_CHECK(cu_seq_lens.dtype() == torch::kInt32,
              "cu_seq_lens must be int32");
  TORCH_CHECK(src_cache.device() == dst.device(),
              "src_cache and dst must be on the same device");
  TORCH_CHECK(src_cache.device() == block_table.device(),
              "src_cache and block_table must be on the same device");
  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
              "src_cache and cu_seq_lens must be on the same device");
  if (seq_starts.has_value()) {
    TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
                "seq_starts must be int32");
    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
                "src_cache and seq_starts must be on the same device");
  }
  TORCH_CHECK(src_cache.dtype() == dst.dtype(),
              "src_cache and dst must have the same dtype");
  TORCH_CHECK(batch_size >= 0, "batch_size must be non-negative");
  // The kernel reads cu_seq_lens[bid + 1] for every bid < batch_size.
  TORCH_CHECK(cu_seq_lens.numel() >= batch_size + 1,
              "cu_seq_lens must have at least batch_size + 1 elements");

  // A grid with a zero-sized dimension is an invalid launch configuration,
  // so an empty batch is handled here instead of reaching the launch.
  if (batch_size == 0) {
    return;
  }

  int64_t block_table_stride = block_table.stride(0);
  int64_t cache_block_stride = src_cache.stride(0);
  int64_t cache_entry_stride = src_cache.stride(1);
  int64_t dst_entry_stride = dst.stride(0);

  // Decide on the number of splits based on the batch size: fewer splits per
  // sequence as the batch grows.
  int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
  dim3 grid(batch_size, num_splits);
  dim3 block(1024);

  const int32_t* seq_starts_ptr =
      seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;

  // The kernel is a pure copy, so it is instantiated on an unsigned integer
  // type of the same width as the tensor element type.
  const int dtype_bits = src_cache.element_size() * 8;
  switch (dtype_bits) {
    case 32:
      CALL_GATHER_CACHE(uint32_t);
      break;
    case 16:
      CALL_GATHER_CACHE(uint16_t);
      break;
    case 8:
      CALL_GATHER_CACHE(uint8_t);
      break;
    default:
      TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits);
  }
}