[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)
Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
@@ -450,6 +450,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
"Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks", torch::kCUDA, ©_blocks);
|
||||
|
||||
cache_ops.def(
|
||||
"copy_blocks_mla(Tensor(a!)[] kv_caches, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("copy_blocks_mla", torch::kCUDA, ©_blocks_mla);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
cache_ops.def(
|
||||
"reshape_and_cache(Tensor key, Tensor value,"
|
||||
|
||||
Reference in New Issue
Block a user