[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (#38460)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Itay Etelis
2026-04-03 06:13:23 +03:00
committed by GitHub
parent 3bc2734dd0
commit 4a06e1246e
5 changed files with 118 additions and 15 deletions

View File

@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,

View File

@@ -24,6 +24,8 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include <cuda.h>
#endif
#if defined(__gfx942__)
@@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
const int64_t n = src_ptrs.size(0);
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
if (n == 0) return;
const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
const int64_t* size_data = sizes.data_ptr<int64_t>();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
// driver call, amortizing per-copy submission overhead.
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
// so we reinterpret_cast the tensor data directly to avoid copies.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
size_t attrs_idx = 0;
size_t fail_idx = 0;
CUresult result = cuMemcpyBatchAsync(
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
#else
// Fallback for CUDA < 12.8 and ROCm: individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
stream);
}
#endif
}
namespace vllm {
// Grid: (num_layers, num_pairs)

View File

@@ -508,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"