[Perf] Mem align KV caches for CUDA devices (MLA perf improvement) (#12676)
Signed-off-by: simon-mo <xmo@berkeley.edu> Signed-off-by: Lucas Wilkinson <lcwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com> Co-authored-by: simon-mo <xmo@berkeley.edu>
This commit is contained in:
@@ -15,6 +15,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
std::vector<torch::Tensor> const& value_caches,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
|
||||
Reference in New Issue
Block a user