[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (#38460)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -508,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" int block_size_in_bytes, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
|
||||
|
||||
// Batch swap: submit all block copies in a single driver call.
|
||||
cache_ops.def(
|
||||
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
|
||||
" Tensor sizes) -> ()");
|
||||
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
cache_ops.def(
|
||||
"reshape_and_cache(Tensor key, Tensor value,"
|
||||
|
||||
Reference in New Issue
Block a user