[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (#38460)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Itay Etelis
2026-04-03 06:13:23 +03:00
committed by GitHub
parent 3bc2734dd0
commit 4a06e1246e
5 changed files with 118 additions and 15 deletions

View File

@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);
void swap_blocks_batch(const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,

View File

@@ -24,6 +24,8 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include <cuda.h>
#endif
#if defined(__gfx942__)
@@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}
// Perform a batch of raw byte copies in a single submission when possible.
// Entry i copies sizes[i] bytes from address src_ptrs[i] to dst_ptrs[i].
// Preconditions: all three tensors are 1-D, contiguous, int64 CPU tensors of
// equal length. Pointers are assumed valid device/pinned-host addresses for
// the current CUDA device; copies are enqueued on the current CUDA stream
// (asynchronous — the caller is responsible for synchronization).
void swap_blocks_batch(const torch::Tensor& src_ptrs,
                       const torch::Tensor& dst_ptrs,
                       const torch::Tensor& sizes) {
  TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
  TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
  TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
  TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
  TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
  TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
  // The raw data_ptr()[i] indexing below assumes unit stride, so reject
  // non-contiguous views (e.g. a sliced tensor) up front.
  TORCH_CHECK(src_ptrs.is_contiguous(), "src_ptrs must be contiguous");
  TORCH_CHECK(dst_ptrs.is_contiguous(), "dst_ptrs must be contiguous");
  TORCH_CHECK(sizes.is_contiguous(), "sizes must be contiguous");

  const int64_t n = src_ptrs.size(0);
  TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
  TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
  if (n == 0) return;

  const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
  const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
  const int64_t* size_data = sizes.data_ptr<int64_t>();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
  // driver call, amortizing per-copy submission overhead.
  // int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
  // so we reinterpret_cast the tensor data directly to avoid copies.
  static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
  static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  CUmemcpyAttributes attr = {};
  // Stream order: copies honor prior work enqueued on `stream`.
  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  size_t attrs_idx = 0;  // single attribute entry applies to all copies
  size_t fail_idx = 0;   // set by the driver to the first failing copy
  CUresult result = cuMemcpyBatchAsync(
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
      reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
      static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
      static_cast<CUstream>(stream));
  TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
              fail_idx, " with error ", result);
#else
  // Fallback for CUDA < 12.8 and ROCm: individual async copies.
  // cudaMemcpyDefault lets the driver infer direction from pointer types.
  for (int64_t i = 0; i < n; i++) {
    const cudaError_t err = cudaMemcpyAsync(
        reinterpret_cast<void*>(dst_data[i]),
        reinterpret_cast<void*>(src_data[i]),
        static_cast<size_t>(size_data[i]), cudaMemcpyDefault, stream);
    // Surface enqueue failures (e.g. invalid pointer/size) instead of
    // silently dropping the copy, which would corrupt the KV cache.
    TORCH_CHECK(err == cudaSuccess, "cudaMemcpyAsync failed at index ", i,
                ": ", cudaGetErrorString(err));
  }
#endif
}
namespace vllm {
// Grid: (num_layers, num_pairs)

View File

@@ -508,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"

View File

@@ -2641,6 +2641,22 @@ def swap_blocks(
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
def swap_blocks_batch(
    src_ptrs: torch.Tensor,
    dst_ptrs: torch.Tensor,
    sizes: torch.Tensor,
) -> None:
    """Submit a batch of raw pointer-to-pointer copies in one call.

    Entry ``i`` describes a copy of ``sizes[i]`` bytes from the address
    ``src_ptrs[i]`` to the address ``dst_ptrs[i]``. All three arguments
    must be int64 tensors resident on the CPU.

    With CUDA 12.8+ the backing op batches everything through a single
    ``cuMemcpyBatchAsync`` driver call; otherwise it degrades to one
    ``cudaMemcpyAsync`` per entry.
    """
    torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
def convert_fp8(
output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
) -> None:

View File

@@ -149,6 +149,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
# list of CUDA events available for re-use
self._event_pool: list[torch.Event] = []
# Pre-compute base pointers and block sizes for batch copies.
self._src_base_ptrs = np.array(
[t.data_ptr() for t in self.src_tensors], dtype=np.int64
)
self._dst_base_ptrs = np.array(
[t.data_ptr() for t in self.dst_tensors], dtype=np.int64
)
self._block_size_in_bytes_arr = np.array(
self.tensor_block_size_in_bytes, dtype=np.int64
)
def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, BlockIDsLoadStoreSpec)
@@ -165,15 +176,35 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
src_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
dst_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
expand_block_ids(
src_blocks,
self.src_block_size_factor,
src_to_dst[:, 0],
src_block_ids,
skip_count=src_sub_blocks_to_skip,
)
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
src_to_dst_tensor = torch.from_numpy(src_to_dst)
expand_block_ids(dst_blocks, self.dst_block_size_factor, dst_block_ids)
# Build flat pointer arrays for all tensors × all block pairs.
num_pairs = dst_sub_block_count
num_tensors = len(self.src_tensors)
total = num_pairs * num_tensors
all_src = np.empty(total, dtype=np.int64)
all_dst = np.empty(total, dtype=np.int64)
all_sizes = np.empty(total, dtype=np.int64)
for t_idx, bsz in enumerate(self._block_size_in_bytes_arr):
start = t_idx * num_pairs
end = start + num_pairs
all_src[start:end] = self._src_base_ptrs[t_idx] + src_block_ids * bsz
all_dst[start:end] = self._dst_base_ptrs[t_idx] + dst_block_ids * bsz
all_sizes[start:end] = bsz
batch_src = torch.from_numpy(all_src)
batch_dst = torch.from_numpy(all_dst)
batch_sizes = torch.from_numpy(all_sizes)
stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
start_event = (
@@ -197,17 +228,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
stream.wait_event(last_event)
with torch.cuda.stream(stream):
start_event.record(stream)
for src_tensor, dst_tensor, block_size_in_bytes in zip(
self.src_tensors,
self.dst_tensors,
self.tensor_block_size_in_bytes,
):
ops.swap_blocks(
src_tensor,
dst_tensor,
block_size_in_bytes,
src_to_dst_tensor,
)
if total > 0:
ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
end_event.record(stream)
self._transfer_events[job_id] = end_event