[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (#38460)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
int64_t block_size_in_bytes,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
// Batched raw-pointer copy: for each index i, copies sizes[i] bytes from the
// address stored in src_ptrs[i] to the address stored in dst_ptrs[i].
// The implementation requires all three tensors to be int64 CPU tensors of
// equal length.
void swap_blocks_batch(const torch::Tensor& src_ptrs,
|
||||
const torch::Tensor& dst_ptrs,
|
||||
const torch::Tensor& sizes);
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
|
||||
@@ -24,6 +24,8 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#endif
|
||||
|
||||
#if defined(__gfx942__)
|
||||
@@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueues a batch of raw memory copies on the current CUDA stream: for each
// index i, sizes[i] bytes are copied from address src_ptrs[i] to address
// dst_ptrs[i].
//
// Args:
//   src_ptrs: 1-D contiguous int64 CPU tensor of source addresses.
//   dst_ptrs: 1-D contiguous int64 CPU tensor of destination addresses.
//   sizes:    1-D contiguous int64 CPU tensor of byte counts.
// All three tensors must have the same length; an empty batch is a no-op.
//
// The copies are asynchronous with respect to the host. Fails via TORCH_CHECK
// if an argument is invalid or if any copy cannot be submitted.
void swap_blocks_batch(const torch::Tensor& src_ptrs,
                       const torch::Tensor& dst_ptrs,
                       const torch::Tensor& sizes) {
  TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
  TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
  TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
  TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
  TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
  TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");

  // The raw data pointers below are indexed as flat arrays of length n, which
  // is only valid for 1-D contiguous tensors.
  TORCH_CHECK(src_ptrs.dim() == 1 && src_ptrs.is_contiguous(),
              "src_ptrs must be a 1-D contiguous tensor");
  TORCH_CHECK(dst_ptrs.dim() == 1 && dst_ptrs.is_contiguous(),
              "dst_ptrs must be a 1-D contiguous tensor");
  TORCH_CHECK(sizes.dim() == 1 && sizes.is_contiguous(),
              "sizes must be a 1-D contiguous tensor");

  const int64_t n = src_ptrs.size(0);
  TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
  TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");

  if (n == 0) return;

  const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
  const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
  const int64_t* size_data = sizes.data_ptr<int64_t>();

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  // Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
  // driver call, amortizing per-copy submission overhead.
  // int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
  // so we reinterpret_cast the tensor data directly to avoid copies.
  // The asserts live in this branch because only this branch relies on them.
  static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
  static_assert(sizeof(size_t) == sizeof(int64_t));

  // NOTE(review): this call uses the 9-argument form that takes a failIdx
  // out-parameter. Newer CUDA major versions may expose a different
  // signature for this entry point -- verify against the targeted toolkit.
  CUmemcpyAttributes attr = {};
  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  // A single attribute entry that applies starting at copy index 0.
  size_t attrs_idx = 0;
  size_t fail_idx = 0;
  CUresult result = cuMemcpyBatchAsync(
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
      reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
      static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
      static_cast<CUstream>(stream));
  TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
              fail_idx, " with error ", result);
#else
  // Fallback for CUDA < 12.8 and ROCm: individual async copies.
  // cudaMemcpyDefault lets the driver infer direction from pointer types.
  for (int64_t i = 0; i < n; i++) {
    const cudaError_t err = cudaMemcpyAsync(
        reinterpret_cast<void*>(dst_data[i]),
        reinterpret_cast<void*>(src_data[i]),
        static_cast<size_t>(size_data[i]), cudaMemcpyDefault, stream);
    // Surface submission failures instead of silently dropping them, matching
    // the error reporting of the batched path above.
    TORCH_CHECK(err == cudaSuccess, "cudaMemcpyAsync failed at index ", i,
                " with error ", cudaGetErrorString(err));
  }
#endif
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Grid: (num_layers, num_pairs)
|
||||
|
||||
@@ -508,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" int block_size_in_bytes, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
|
||||
|
||||
// Batch swap: submit all block copies in a single driver call.
|
||||
cache_ops.def(
|
||||
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
|
||||
" Tensor sizes) -> ()");
|
||||
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
cache_ops.def(
|
||||
"reshape_and_cache(Tensor key, Tensor value,"
|
||||
|
||||
@@ -2641,6 +2641,22 @@ def swap_blocks(
|
||||
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
|
||||
|
||||
|
||||
def swap_blocks_batch(
    src_ptrs: torch.Tensor,
    dst_ptrs: torch.Tensor,
    sizes: torch.Tensor,
) -> None:
    """
    Copy many raw memory regions with one call into the native cache op.

    Entry ``i`` describes a single copy of ``sizes[i]`` bytes from the
    address ``src_ptrs[i]`` to the address ``dst_ptrs[i]``. The three
    arguments must all be int64 tensors that live on the CPU.

    The native op submits the whole batch through cuMemcpyBatchAsync when
    built against CUDA 12.8 or newer; otherwise it issues one
    cudaMemcpyAsync per entry.
    """
    batch_op = torch.ops._C_cache_ops.swap_blocks_batch
    batch_op(src_ptrs, dst_ptrs, sizes)
|
||||
|
||||
|
||||
def convert_fp8(
|
||||
output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
|
||||
) -> None:
|
||||
|
||||
@@ -149,6 +149,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
# list of CUDA events available for re-use
|
||||
self._event_pool: list[torch.Event] = []
|
||||
|
||||
# Pre-compute base pointers and block sizes for batch copies.
|
||||
self._src_base_ptrs = np.array(
|
||||
[t.data_ptr() for t in self.src_tensors], dtype=np.int64
|
||||
)
|
||||
self._dst_base_ptrs = np.array(
|
||||
[t.data_ptr() for t in self.dst_tensors], dtype=np.int64
|
||||
)
|
||||
self._block_size_in_bytes_arr = np.array(
|
||||
self.tensor_block_size_in_bytes, dtype=np.int64
|
||||
)
|
||||
|
||||
def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
|
||||
src_spec, dst_spec = transfer_spec
|
||||
assert isinstance(src_spec, BlockIDsLoadStoreSpec)
|
||||
@@ -165,15 +176,35 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
|
||||
assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
|
||||
|
||||
src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
|
||||
src_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
|
||||
dst_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
|
||||
expand_block_ids(
|
||||
src_blocks,
|
||||
self.src_block_size_factor,
|
||||
src_to_dst[:, 0],
|
||||
src_block_ids,
|
||||
skip_count=src_sub_blocks_to_skip,
|
||||
)
|
||||
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
|
||||
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
||||
expand_block_ids(dst_blocks, self.dst_block_size_factor, dst_block_ids)
|
||||
|
||||
# Build flat pointer arrays for all tensors × all block pairs.
|
||||
num_pairs = dst_sub_block_count
|
||||
num_tensors = len(self.src_tensors)
|
||||
total = num_pairs * num_tensors
|
||||
|
||||
all_src = np.empty(total, dtype=np.int64)
|
||||
all_dst = np.empty(total, dtype=np.int64)
|
||||
all_sizes = np.empty(total, dtype=np.int64)
|
||||
|
||||
for t_idx, bsz in enumerate(self._block_size_in_bytes_arr):
|
||||
start = t_idx * num_pairs
|
||||
end = start + num_pairs
|
||||
all_src[start:end] = self._src_base_ptrs[t_idx] + src_block_ids * bsz
|
||||
all_dst[start:end] = self._dst_base_ptrs[t_idx] + dst_block_ids * bsz
|
||||
all_sizes[start:end] = bsz
|
||||
|
||||
batch_src = torch.from_numpy(all_src)
|
||||
batch_dst = torch.from_numpy(all_dst)
|
||||
batch_sizes = torch.from_numpy(all_sizes)
|
||||
|
||||
stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
|
||||
start_event = (
|
||||
@@ -197,17 +228,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
stream.wait_event(last_event)
|
||||
with torch.cuda.stream(stream):
|
||||
start_event.record(stream)
|
||||
for src_tensor, dst_tensor, block_size_in_bytes in zip(
|
||||
self.src_tensors,
|
||||
self.dst_tensors,
|
||||
self.tensor_block_size_in_bytes,
|
||||
):
|
||||
ops.swap_blocks(
|
||||
src_tensor,
|
||||
dst_tensor,
|
||||
block_size_in_bytes,
|
||||
src_to_dst_tensor,
|
||||
)
|
||||
if total > 0:
|
||||
ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
|
||||
end_event.record(stream)
|
||||
|
||||
self._transfer_events[job_id] = end_event
|
||||
|
||||
Reference in New Issue
Block a user