diff --git a/csrc/cache.h b/csrc/cache.h
index 0188a568e..821d5e719 100644
--- a/csrc/cache.h
+++ b/csrc/cache.h
@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                  int64_t block_size_in_bytes,
                  const torch::Tensor& block_mapping);
 
+void swap_blocks_batch(const torch::Tensor& src_ptrs,
+                       const torch::Tensor& dst_ptrs,
+                       const torch::Tensor& sizes);
+
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu
index 2b3906df9..c59da4379 100644
--- a/csrc/cache_kernels.cu
+++ b/csrc/cache_kernels.cu
@@ -24,6 +24,8 @@
 
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
 typedef __hip_bfloat16 __nv_bfloat16;
+#else
+  #include <cuda.h>
 #endif
 
 #if defined(__gfx942__)
@@ -73,6 +75,61 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
   }
 }
 
+void swap_blocks_batch(const torch::Tensor& src_ptrs,
+                       const torch::Tensor& dst_ptrs,
+                       const torch::Tensor& sizes) {
+  TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
+  TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
+  TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
+  TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
+  TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
+  TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");
+
+  const int64_t n = src_ptrs.size(0);
+  TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
+  TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
+
+  if (n == 0) return;
+
+  const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
+  const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
+  const int64_t* size_data = sizes.data_ptr<int64_t>();
+
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
+  // driver call, amortizing per-copy submission overhead.
+#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
+  // int64_t, CUdeviceptr and size_t are all 8 bytes on 64-bit platforms,
+  // so we reinterpret_cast the tensor data directly to avoid copies.
+  static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
+  static_assert(sizeof(size_t) == sizeof(int64_t));
+  CUmemcpyAttributes attr = {};
+  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
+  size_t attrs_idx = 0;
+  size_t fail_idx = 0;
+  CUresult result = cuMemcpyBatchAsync(
+      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
+      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
+      reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
+      static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
+      static_cast<CUstream>(stream));
+  TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
+              fail_idx, " with error ", result);
+#else
+  // Fallback for CUDA < 12.8 and ROCm: individual async copies.
+  // cudaMemcpyDefault lets the driver infer direction from pointer types.
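+  // Each copy is enqueued on the same stream as the batched path, so
+  // ordering semantics match; only the submission overhead differs.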
+  for (int64_t i = 0; i < n; i++) {
+    cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
+                    reinterpret_cast<const void*>(src_data[i]),
+                    static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
+                    stream);
+  }
+#endif
+}
+
 namespace vllm {
 
 // Grid: (num_layers, num_pairs)
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 1beab5257..3593f1d22 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -508,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       " int block_size_in_bytes, Tensor block_mapping) -> ()");
   cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
 
+  // Batch swap: submit all block copies in a single driver call.
+  cache_ops.def(
+      "swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
+      " Tensor sizes) -> ()");
+  cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
+
   // Reshape the key and value tensors and cache them.
   cache_ops.def(
       "reshape_and_cache(Tensor key, Tensor value,"
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 65eca3208..0c2a53ec0 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -2641,6 +2641,22 @@ def swap_blocks(
     torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
 
 
+def swap_blocks_batch(
+    src_ptrs: torch.Tensor,
+    dst_ptrs: torch.Tensor,
+    sizes: torch.Tensor,
+) -> None:
+    """
+    Batch version of swap_blocks: submit all copies in a single driver call.
+
+    Each entry specifies a raw pointer copy: src_ptrs[i] -> dst_ptrs[i]
+    of sizes[i] bytes. All three tensors must be int64 CPU tensors.
+    On CUDA 12.8+ this uses cuMemcpyBatchAsync for minimal submission
+    overhead; on older CUDA it falls back to a loop of cudaMemcpyAsync.
+    """
+    torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
+
+
 def convert_fp8(
     output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
 ) -> None:
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py
index eeabf0cda..cd0136d48 100644
--- a/vllm/v1/kv_offload/worker/cpu_gpu.py
+++ b/vllm/v1/kv_offload/worker/cpu_gpu.py
@@ -149,6 +149,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
         # list of CUDA events available for re-use
         self._event_pool: list[torch.Event] = []
 
+        # Pre-compute base pointers and block sizes for batch copies.
+        self._src_base_ptrs = np.array(
+            [t.data_ptr() for t in self.src_tensors], dtype=np.int64
+        )
+        self._dst_base_ptrs = np.array(
+            [t.data_ptr() for t in self.dst_tensors], dtype=np.int64
+        )
+        self._block_size_in_bytes_arr = np.array(
+            self.tensor_block_size_in_bytes, dtype=np.int64
+        )
+
     def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
         src_spec, dst_spec = transfer_spec
         assert isinstance(src_spec, BlockIDsLoadStoreSpec)
@@ -165,15 +176,37 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
 
         assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
 
-        src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
+        src_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
+        dst_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
         expand_block_ids(
             src_blocks,
             self.src_block_size_factor,
-            src_to_dst[:, 0],
+            src_block_ids,
             skip_count=src_sub_blocks_to_skip,
         )
-        expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
-        src_to_dst_tensor = torch.from_numpy(src_to_dst)
+        expand_block_ids(dst_blocks, self.dst_block_size_factor, dst_block_ids)
+
+        # Build flat pointer arrays for all tensors × all block pairs.
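+        # Entries are grouped per tensor: entry t_idx * num_pairs + i copies
+        # sub-block pair i of tensor t_idx.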
+        num_pairs = dst_sub_block_count
+        num_tensors = len(self.src_tensors)
+        total = num_pairs * num_tensors
+
+        all_src = np.empty(total, dtype=np.int64)
+        all_dst = np.empty(total, dtype=np.int64)
+        all_sizes = np.empty(total, dtype=np.int64)
+
+        for t_idx, bsz in enumerate(self._block_size_in_bytes_arr):
+            start = t_idx * num_pairs
+            end = start + num_pairs
+            all_src[start:end] = self._src_base_ptrs[t_idx] + src_block_ids * bsz
+            all_dst[start:end] = self._dst_base_ptrs[t_idx] + dst_block_ids * bsz
+            all_sizes[start:end] = bsz
+
+        batch_src = torch.from_numpy(all_src)
+        batch_dst = torch.from_numpy(all_dst)
+        batch_sizes = torch.from_numpy(all_sizes)
 
         stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
         start_event = (
@@ -197,17 +230,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
             stream.wait_event(last_event)
         with torch.cuda.stream(stream):
             start_event.record(stream)
-            for src_tensor, dst_tensor, block_size_in_bytes in zip(
-                self.src_tensors,
-                self.dst_tensors,
-                self.tensor_block_size_in_bytes,
-            ):
-                ops.swap_blocks(
-                    src_tensor,
-                    dst_tensor,
-                    block_size_in_bytes,
-                    src_to_dst_tensor,
-                )
+            if total > 0:
+                ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
             end_event.record(stream)
 
         self._transfer_events[job_id] = end_event
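
As a usage reference (not part of the diff): a minimal, self-contained sketch of driving the new op directly. The tensor names and block geometry below are invented for illustration, and a vLLM build containing this patch is assumed.

```python
import torch
from vllm import _custom_ops as ops

# Hypothetical geometry: 8 blocks of 4096 bytes each (1024 float32 elements).
block_size_in_bytes = 4096
num_blocks = 8

src = torch.randn(num_blocks, block_size_in_bytes // 4, device="cuda")
dst = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)

# Raw byte addresses of each block: base pointer + block_id * block size.
block_ids = torch.arange(num_blocks, dtype=torch.int64)
src_ptrs = src.data_ptr() + block_ids * block_size_in_bytes
dst_ptrs = dst.data_ptr() + block_ids * block_size_in_bytes
sizes = torch.full((num_blocks,), block_size_in_bytes, dtype=torch.int64)

# All three arguments are int64 CPU tensors, as the op requires.
ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)

# The copies run asynchronously on the current stream; synchronize before
# reading the destination.
torch.cuda.current_stream().synchronize()
assert torch.equal(src.cpu(), dst)
```

Because the op consumes raw addresses, the source and destination tensors must stay alive until the copies complete; in the offloading handler above this holds because the pre-registered src/dst tensors outlive each transfer and completion is tracked via the recorded end event.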