OffloadingConnector: Support kernel_block_size != block_size (#30692)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include <vector>
|
||||
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
int64_t block_size_in_bytes,
|
||||
const torch::Tensor& block_mapping);
|
||||
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
|
||||
@@ -25,6 +25,7 @@ typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
int64_t block_size_in_bytes,
|
||||
const torch::Tensor& block_mapping) {
|
||||
torch::Device src_device = src.device();
|
||||
torch::Device dst_device = dst.device();
|
||||
@@ -49,10 +50,6 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
|
||||
char* src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char* dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
// We use the stride instead of numel in case the cache is padded for memory
|
||||
// alignment reasons, we assume the blocks data (inclusive of any padding)
|
||||
// is contiguous in memory
|
||||
const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(
|
||||
src_device.is_cuda() ? src_device : dst_device);
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
@@ -692,7 +692,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
// Cache ops
|
||||
// Swap in (out) the cache blocks from src to dst.
|
||||
cache_ops.def(
|
||||
"swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
|
||||
"swap_blocks(Tensor src, Tensor! dst,"
|
||||
" int block_size_in_bytes, Tensor block_mapping) -> ()");
|
||||
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
|
||||
|
||||
// Reshape the key and value tensors and cache them.
|
||||
|
||||
@@ -405,19 +405,41 @@ def test_swap_blocks(
|
||||
|
||||
# Call the swap_blocks kernel.
|
||||
do_opcheck = head_size == HEAD_SIZES[0]
|
||||
src_cache = src_key_caches[0]
|
||||
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_key_caches[0], dist_key_caches[0], block_mapping_tensor),
|
||||
(
|
||||
src_key_caches[0],
|
||||
dist_key_caches[0],
|
||||
block_size_in_bytes,
|
||||
block_mapping_tensor,
|
||||
),
|
||||
cond=do_opcheck,
|
||||
)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_value_caches[0], dist_value_caches[0], block_mapping_tensor),
|
||||
(
|
||||
src_value_caches[0],
|
||||
dist_value_caches[0],
|
||||
block_size_in_bytes,
|
||||
block_mapping_tensor,
|
||||
),
|
||||
cond=do_opcheck,
|
||||
)
|
||||
|
||||
ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping_tensor)
|
||||
ops.swap_blocks(src_value_caches[0], dist_value_caches[0], block_mapping_tensor)
|
||||
ops.swap_blocks(
|
||||
src_key_caches[0],
|
||||
dist_key_caches[0],
|
||||
block_size_in_bytes,
|
||||
block_mapping_tensor,
|
||||
)
|
||||
ops.swap_blocks(
|
||||
src_value_caches[0],
|
||||
dist_value_caches[0],
|
||||
block_size_in_bytes,
|
||||
block_mapping_tensor,
|
||||
)
|
||||
|
||||
for src, dst in block_mapping:
|
||||
torch.testing.assert_close(
|
||||
@@ -723,13 +745,14 @@ def test_swap_blocks_mla(
|
||||
block_mapping, dtype=torch.int64, device="cpu"
|
||||
).view(-1, 2)
|
||||
|
||||
block_size_in_bytes = src_cache.element_size() * src_cache.stride(0)
|
||||
opcheck(
|
||||
torch.ops._C_cache_ops.swap_blocks,
|
||||
(src_cache, dst_cache, block_mapping_tensor),
|
||||
(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor),
|
||||
test_utils=DEFAULT_OPCHECK_TEST_UTILS,
|
||||
)
|
||||
|
||||
ops.swap_blocks(src_cache, dst_cache, block_mapping_tensor)
|
||||
ops.swap_blocks(src_cache, dst_cache, block_size_in_bytes, block_mapping_tensor)
|
||||
|
||||
for src, dst in block_mapping:
|
||||
torch.testing.assert_close(
|
||||
|
||||
@@ -25,8 +25,9 @@ if not current_platform.is_rocm():
|
||||
|
||||
NUM_GPU_BLOCKS = [64]
|
||||
NUM_CPU_BLOCKS = [256]
|
||||
GPU_BLOCK_SIZES = [16]
|
||||
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
|
||||
KERNEL_BLOCK_SIZES = [16]
|
||||
LOGICAL_BLOCK_SIZES = [16, 32]
|
||||
LOGICAL_BLOCKS_PER_CPU_BLOCK = [1, 3]
|
||||
HEAD_SIZES = [64]
|
||||
NUM_HEADS = [8]
|
||||
NUM_LAYERS = [4]
|
||||
@@ -40,8 +41,9 @@ NUM_MAPPINGS = [3]
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
|
||||
@pytest.mark.parametrize("kernel_block_size", KERNEL_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("logical_block_size", LOGICAL_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("logical_blocks_per_cpu_block", LOGICAL_BLOCKS_PER_CPU_BLOCK)
|
||||
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
|
||||
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@@ -55,8 +57,9 @@ def test_transfer(
|
||||
num_mappings: int,
|
||||
head_size: int,
|
||||
num_heads: int,
|
||||
gpu_block_size: int,
|
||||
gpu_blocks_per_cpu_block: int,
|
||||
kernel_block_size: int,
|
||||
logical_block_size: int,
|
||||
logical_blocks_per_cpu_block: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
num_layers: int,
|
||||
@@ -69,6 +72,10 @@ def test_transfer(
|
||||
# create per-layer GPU KV caches based on available attn_backends
|
||||
attn_backends_list = BACKENDS_TO_TEST
|
||||
|
||||
assert logical_block_size % kernel_block_size == 0
|
||||
kernel_blocks_per_gpu_block = logical_block_size // kernel_block_size
|
||||
num_gpu_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
|
||||
|
||||
gpu_caches = {}
|
||||
attn_backends = {}
|
||||
for i in range(num_layers):
|
||||
@@ -78,15 +85,16 @@ def test_transfer(
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
gpu_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, gpu_block_size, num_heads, head_size
|
||||
num_gpu_kernel_blocks, kernel_block_size, num_heads, head_size
|
||||
)
|
||||
gpu_caches[layer_name] = torch.rand(gpu_cache_shape, dtype=dtype, device=device)
|
||||
|
||||
# create handler
|
||||
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
|
||||
cpu_block_size = logical_blocks_per_cpu_block * logical_block_size
|
||||
kernel_blocks_per_cpu_block = cpu_block_size // kernel_block_size
|
||||
handlers = CpuGpuOffloadingHandlers(
|
||||
attn_backends=attn_backends,
|
||||
gpu_block_size=gpu_block_size,
|
||||
gpu_block_size=logical_block_size,
|
||||
cpu_block_size=cpu_block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
gpu_caches=gpu_caches,
|
||||
@@ -94,22 +102,34 @@ def test_transfer(
|
||||
|
||||
# select block mappings
|
||||
gpu_blocks = random.sample(
|
||||
range(num_gpu_blocks), num_mappings * gpu_blocks_per_cpu_block
|
||||
range(num_gpu_blocks), num_mappings * logical_blocks_per_cpu_block
|
||||
)
|
||||
cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)
|
||||
|
||||
# convert cpu blocks to gpu block size
|
||||
cpu_blocks_in_gpu_block_size = []
|
||||
for cpu_block in cpu_blocks:
|
||||
base_block_id = cpu_block * gpu_blocks_per_cpu_block
|
||||
for i in range(gpu_blocks_per_cpu_block):
|
||||
cpu_blocks_in_gpu_block_size.append(i + base_block_id)
|
||||
# convert gpu blocks to kernel block size
|
||||
gpu_blocks_in_kernel_block_size = []
|
||||
for gpu_block in gpu_blocks:
|
||||
base_block_id = gpu_block * kernel_blocks_per_gpu_block
|
||||
for i in range(kernel_blocks_per_gpu_block):
|
||||
gpu_blocks_in_kernel_block_size.append(i + base_block_id)
|
||||
|
||||
# maybe skip a GPU block to test reading from the middle of a CPU block
|
||||
# convert cpu blocks to gpu block size
|
||||
cpu_blocks_in_kernel_block_size = []
|
||||
for cpu_block in cpu_blocks:
|
||||
base_block_id = cpu_block * kernel_blocks_per_cpu_block
|
||||
for i in range(kernel_blocks_per_cpu_block):
|
||||
cpu_blocks_in_kernel_block_size.append(i + base_block_id)
|
||||
|
||||
# maybe skip some GPU block to test reading from the middle of a CPU block
|
||||
if not gpu_to_cpu:
|
||||
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1 :]
|
||||
cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
|
||||
gpu_blocks_per_cpu_block - 1 :
|
||||
gpu_blocks_to_skip = logical_blocks_per_cpu_block - 1
|
||||
gpu_blocks = gpu_blocks[gpu_blocks_to_skip:]
|
||||
kernel_blocks_to_skip = gpu_blocks_to_skip * kernel_blocks_per_gpu_block
|
||||
gpu_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size[
|
||||
kernel_blocks_to_skip:
|
||||
]
|
||||
cpu_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size[
|
||||
kernel_blocks_to_skip:
|
||||
]
|
||||
|
||||
# set transfer direction
|
||||
@@ -119,23 +139,23 @@ def test_transfer(
|
||||
dst_spec_class = CPULoadStoreSpec
|
||||
src_blocks = gpu_blocks
|
||||
dst_blocks = cpu_blocks
|
||||
src_blocks_in_gpu_block_size = gpu_blocks
|
||||
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
|
||||
src_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
|
||||
dst_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
|
||||
dst_size_in_kernel_blocks = num_cpu_blocks * kernel_blocks_per_cpu_block
|
||||
else:
|
||||
handler = handlers.cpu_to_gpu_handler
|
||||
src_spec_class = CPULoadStoreSpec
|
||||
dst_spec_class = GPULoadStoreSpec
|
||||
src_blocks = cpu_blocks
|
||||
dst_blocks = gpu_blocks
|
||||
src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||
dst_blocks_in_gpu_block_size = gpu_blocks
|
||||
dst_size_in_gpu_blocks = num_gpu_blocks
|
||||
src_blocks_in_kernel_block_size = cpu_blocks_in_kernel_block_size
|
||||
dst_blocks_in_kernel_block_size = gpu_blocks_in_kernel_block_size
|
||||
dst_size_in_kernel_blocks = num_gpu_blocks * kernel_blocks_per_gpu_block
|
||||
|
||||
# build dst -> src mapping
|
||||
dst_to_src = {}
|
||||
for src_block, dst_block in zip(
|
||||
src_blocks_in_gpu_block_size, dst_blocks_in_gpu_block_size
|
||||
src_blocks_in_kernel_block_size, dst_blocks_in_kernel_block_size
|
||||
):
|
||||
dst_to_src[dst_block] = src_block
|
||||
|
||||
@@ -165,29 +185,15 @@ def test_transfer(
|
||||
assert torch.equal(orig_tensor, tensor)
|
||||
|
||||
# verify dst tensors
|
||||
for dst_block in range(dst_size_in_gpu_blocks):
|
||||
for dst_block in range(dst_size_in_kernel_blocks):
|
||||
src_block_candidate = dst_to_src.get(dst_block)
|
||||
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
|
||||
for src_cache, dst_cache, orig_dst_cache in zip(
|
||||
handler.src_tensors,
|
||||
handler.dst_tensors,
|
||||
orig_dst_caches,
|
||||
handler.kv_dim_before_num_blocks,
|
||||
):
|
||||
if kv_dim:
|
||||
# iterate over key, value
|
||||
for i in range(2):
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[i][src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[i][dst_block]
|
||||
torch.testing.assert_close(
|
||||
dst_cache[i][dst_block].cpu(), expected_value.cpu()
|
||||
)
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[src_block_candidate]
|
||||
else:
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[dst_block]
|
||||
torch.testing.assert_close(
|
||||
dst_cache[dst_block].cpu(), expected_value.cpu()
|
||||
)
|
||||
expected_value = orig_dst_cache[dst_block]
|
||||
torch.testing.assert_close(dst_cache[dst_block].cpu(), expected_value.cpu())
|
||||
|
||||
@@ -2455,9 +2455,32 @@ def concat_and_cache_mla_rope_fused(
|
||||
|
||||
|
||||
def swap_blocks(
|
||||
src: torch.Tensor, dst: torch.Tensor, block_mapping: torch.Tensor
|
||||
src: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
block_size_in_bytes: int,
|
||||
block_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.swap_blocks(src, dst, block_mapping)
|
||||
"""
|
||||
Copy specific blocks from one tensor to another.
|
||||
|
||||
This method assumes each of the two input tensors is composed of
|
||||
consecutive contiguous blocks, of size block_size_in_bytes.
|
||||
i.e. the memory layout for each tensor is:
|
||||
[block0] [block1] ... [block N]
|
||||
|
||||
block_mapping determines the subset of blocks to copy of the source tensor,
|
||||
and their matching destination block number on the destination tensor.
|
||||
block_mapping is expected to be a tensor of shape (num_blocks_to_copy, 2)
|
||||
where each block_mapping[i] represents a single copy operation, copying
|
||||
block #block_mapping[i][0] from the source tensor
|
||||
to block #block_mapping[i][1] on the destination tensor.
|
||||
block_mapping should have dtype int64.
|
||||
|
||||
The source and the destination tensors can be either on cpu or gpu,
|
||||
but not both on cpu.
|
||||
the block mapping tensor must on cpu.
|
||||
"""
|
||||
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
|
||||
|
||||
|
||||
def convert_fp8(
|
||||
|
||||
@@ -65,7 +65,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
self,
|
||||
src_tensors: list[torch.Tensor],
|
||||
dst_tensors: list[torch.Tensor],
|
||||
kv_dim_before_num_blocks: list[bool],
|
||||
src_block_size_factor: int,
|
||||
dst_block_size_factor: int,
|
||||
):
|
||||
@@ -76,22 +75,23 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
src_tensors: list of KV cache tensors to copy from.
|
||||
dst_tensors: list of KV cache tensors to copy to.
|
||||
Order should match src_tensors.
|
||||
kv_dim_before_num_blocks: list of bools, indicating
|
||||
whether the respective KV cache tensor has a KV
|
||||
dimension before its num_blocks dimension.
|
||||
e.g. (2, num_blocks, ...)
|
||||
src_block_size_factor: The number of kernel blocks
|
||||
per KV block in a source tensor.
|
||||
dst_block_size_factor: The number of kernel blocks
|
||||
per KV block in a destination tensor.
|
||||
"""
|
||||
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
|
||||
assert len(src_tensors) == len(dst_tensors)
|
||||
|
||||
self.src_tensors: list[torch.Tensor] = src_tensors
|
||||
self.dst_tensors: list[torch.Tensor] = dst_tensors
|
||||
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
|
||||
self.src_block_size_factor: int = src_block_size_factor
|
||||
self.dst_block_size_factor: int = dst_block_size_factor
|
||||
min_block_size_factor = min(src_block_size_factor, dst_block_size_factor)
|
||||
self.src_block_size_factor: int = src_block_size_factor // min_block_size_factor
|
||||
self.dst_block_size_factor: int = dst_block_size_factor // min_block_size_factor
|
||||
|
||||
self.block_size_in_bytes = [
|
||||
tensor.element_size() * tensor.stride(0) * min_block_size_factor
|
||||
for tensor in src_tensors
|
||||
]
|
||||
|
||||
assert len(src_tensors) > 0
|
||||
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
|
||||
@@ -142,16 +142,17 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
# assure job will start only after the previous one completes
|
||||
stream.wait_event(last_event)
|
||||
with torch.cuda.stream(stream):
|
||||
for src_tensor, dst_tensor, kv_dim in zip(
|
||||
self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks
|
||||
for src_tensor, dst_tensor, block_size_in_bytes in zip(
|
||||
self.src_tensors,
|
||||
self.dst_tensors,
|
||||
self.block_size_in_bytes,
|
||||
):
|
||||
if kv_dim:
|
||||
src_key_cache, src_value_cache = src_tensor
|
||||
dst_key_cache, dst_value_cache = dst_tensor
|
||||
ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
|
||||
ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
|
||||
else:
|
||||
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
|
||||
ops.swap_blocks(
|
||||
src_tensor,
|
||||
dst_tensor,
|
||||
block_size_in_bytes,
|
||||
src_to_dst_tensor,
|
||||
)
|
||||
event.record(stream)
|
||||
|
||||
self._transfer_events[job_id] = event
|
||||
@@ -188,19 +189,12 @@ class CpuGpuOffloadingHandlers:
|
||||
):
|
||||
assert gpu_caches
|
||||
assert cpu_block_size % gpu_block_size == 0
|
||||
block_size_factor = cpu_block_size // gpu_block_size
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
|
||||
# allocate cpu tensors
|
||||
logger.info("Allocating %d CPU tensors...", len(gpu_caches))
|
||||
gpu_tensors: list[torch.Tensor] = []
|
||||
cpu_tensors: list[torch.Tensor] = []
|
||||
kv_dim_before_num_blocks: list[bool] = []
|
||||
# find kernel block size and determine layout per each gpu tensor
|
||||
kernel_block_size: int | None = None
|
||||
# list of (gpu_tensor, split_k_and_v)
|
||||
parsed_gpu_tensors: list[tuple[torch.Tensor, bool]] = []
|
||||
for layer_name, gpu_tensor in gpu_caches.items():
|
||||
gpu_tensors.append(gpu_tensor)
|
||||
|
||||
gpu_shape = gpu_tensor.shape
|
||||
attn_backend = attn_backends[layer_name]
|
||||
test_shape = attn_backend.get_kv_cache_shape(
|
||||
@@ -208,28 +202,20 @@ class CpuGpuOffloadingHandlers:
|
||||
)
|
||||
|
||||
has_layers_dim = False
|
||||
split_k_and_v = False
|
||||
if len(gpu_shape) != len(test_shape):
|
||||
# cross-layers tensor
|
||||
# shape is (num_blocks, ...)
|
||||
assert len(gpu_shape) == len(test_shape) + 1
|
||||
num_blocks_idx = 0
|
||||
has_layers_dim = True
|
||||
kv_dim_before_num_blocks.append(False)
|
||||
|
||||
# prepend a dummy num_layers=80 to test_shape
|
||||
test_shape = (80,) + test_shape
|
||||
elif test_shape[0] == 1234:
|
||||
# shape is (num_blocks, ...)
|
||||
num_blocks_idx = 0
|
||||
kv_dim_before_num_blocks.append(False)
|
||||
else:
|
||||
elif test_shape[0] != 1234:
|
||||
# shape should be (2, num_blocks, ...)
|
||||
assert test_shape[0] == 2
|
||||
assert test_shape[1] == 1234
|
||||
assert gpu_shape[0] == 2
|
||||
|
||||
num_blocks_idx = 1
|
||||
kv_dim_before_num_blocks.append(True)
|
||||
split_k_and_v = True
|
||||
|
||||
try:
|
||||
kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
|
||||
@@ -250,30 +236,36 @@ class CpuGpuOffloadingHandlers:
|
||||
kernel_block_size = gpu_shape[block_size_idx]
|
||||
assert gpu_block_size % kernel_block_size == 0
|
||||
|
||||
cpu_shape = list(gpu_shape)
|
||||
cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor
|
||||
|
||||
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
|
||||
cpu_tensors.append(
|
||||
torch.zeros(
|
||||
cpu_shape,
|
||||
dtype=gpu_tensor.dtype,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
)
|
||||
parsed_gpu_tensors.append((gpu_tensor, split_k_and_v))
|
||||
|
||||
assert kernel_block_size is not None
|
||||
gpu_block_size_factor = gpu_block_size // kernel_block_size
|
||||
cpu_block_size_factor = cpu_block_size // kernel_block_size
|
||||
gpu_block_size_factor = gpu_block_size // kernel_block_size
|
||||
num_cpu_kernel_blocks = num_cpu_blocks * cpu_block_size_factor
|
||||
|
||||
# TODO (orozery): adapt swap_blocks to support gpu_block_size_factor
|
||||
assert gpu_block_size_factor == 1
|
||||
# allocate cpu tensors
|
||||
pin_memory = is_pin_memory_available()
|
||||
logger.info("Allocating %d CPU tensors...", len(parsed_gpu_tensors))
|
||||
gpu_tensors: list[torch.Tensor] = []
|
||||
cpu_tensors: list[torch.Tensor] = []
|
||||
for gpu_tensor, split_k_and_v in parsed_gpu_tensors:
|
||||
cpu_shape = list(gpu_tensor.shape)
|
||||
cpu_shape[1 if split_k_and_v else 0] = num_cpu_kernel_blocks
|
||||
|
||||
logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
|
||||
cpu_tensor = torch.zeros(
|
||||
cpu_shape,
|
||||
dtype=gpu_tensor.dtype,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
|
||||
gpu_tensors.extend(gpu_tensor.unbind(0) if split_k_and_v else [gpu_tensor])
|
||||
cpu_tensors.extend(cpu_tensor.unbind(0) if split_k_and_v else [cpu_tensor])
|
||||
|
||||
self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler(
|
||||
src_tensors=gpu_tensors,
|
||||
dst_tensors=cpu_tensors,
|
||||
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
|
||||
src_block_size_factor=gpu_block_size_factor,
|
||||
dst_block_size_factor=cpu_block_size_factor,
|
||||
)
|
||||
@@ -281,7 +273,6 @@ class CpuGpuOffloadingHandlers:
|
||||
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
|
||||
src_tensors=cpu_tensors,
|
||||
dst_tensors=gpu_tensors,
|
||||
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
|
||||
src_block_size_factor=cpu_block_size_factor,
|
||||
dst_block_size_factor=gpu_block_size_factor,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user