[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (#38460)

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Itay Etelis <itay.etelis@ibm.com>
Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Itay Etelis
2026-04-03 06:13:23 +03:00
committed by GitHub
parent 3bc2734dd0
commit 4a06e1246e
5 changed files with 118 additions and 15 deletions

View File

@@ -508,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"