[ROCm] Fix KV copy methods and auto-select attention backend for ROCm (#36845)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -851,6 +851,30 @@ class RocmPlatform(Platform):
|
||||
"`dtype` flag in CLI, for example: --dtype=half."
|
||||
)
|
||||
|
||||
@classmethod
def insert_blocks_to_device(
    cls,
    src_cache: torch.Tensor,
    dst_cache: torch.Tensor,
    src_block_indices: torch.Tensor,
    dst_block_indices: torch.Tensor,
) -> None:
    """Copy selected blocks from ``src_cache`` into ``dst_cache``.

    Both caches are indexed along dimension 1 by block number; dimension 0
    is copied in full.  (NOTE(review): the meaning of dimension 0 -- e.g.
    a K/V pair axis -- is not visible here; confirm against the caller.)

    Args:
        src_cache: Tensor to read blocks from.
        dst_cache: Tensor written in place; the selected blocks are moved
            to ``dst_cache.device`` before assignment.
        src_block_indices: Positions along dimension 1 of ``src_cache``
            to read.
        dst_block_indices: Positions along dimension 1 of ``dst_cache``
            to overwrite, paired element-wise with ``src_block_indices``.

    Returns:
        None. ``dst_cache`` is modified in place.
    """
    # Pick the requested blocks first so that only those blocks, not the
    # whole source cache, are transferred to the destination device.
    selected_blocks = src_cache[:, src_block_indices]
    dst_cache[:, dst_block_indices] = selected_blocks.to(dst_cache.device)
|
||||
|
||||
@classmethod
def swap_out_blocks_to_host(
    cls,
    src_cache: torch.Tensor,
    dst_cache: torch.Tensor,
    src_block_indices: torch.Tensor,
    dst_block_indices: torch.Tensor,
) -> None:
    """Copy selected blocks from ``src_cache`` into a host-side ``dst_cache``.

    Both caches are indexed along dimension 1 by block number; dimension 0
    is copied in full.  The selected blocks are always moved to CPU memory
    before being assigned, so ``dst_cache`` is presumably a CPU tensor --
    verify against the caller.

    Args:
        src_cache: Tensor to read blocks from.
        dst_cache: Tensor written in place with the CPU copies.
        src_block_indices: Positions along dimension 1 of ``src_cache``
            to read.
        dst_block_indices: Positions along dimension 1 of ``dst_cache``
            to overwrite, paired element-wise with ``src_block_indices``.

    Returns:
        None. ``dst_cache`` is modified in place.
    """
    # Select before transferring so only the requested blocks leave the
    # source device.
    selected_blocks = src_cache[:, src_block_indices]
    dst_cache[:, dst_block_indices] = selected_blocks.cpu()
|
||||
|
||||
@classmethod
def support_hybrid_kv_cache(cls) -> bool:
    """Report that this platform supports a hybrid KV cache.

    Returns:
        Always ``True``.
    """
    return True
|
||||
|
||||
Reference in New Issue
Block a user