[NIXL] refine decoder side post process for heterogeneous BlockSize and kv_layout (#30275)
This commit is contained in:
@@ -203,6 +203,84 @@ def copy_kv_blocks(
|
||||
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
|
||||
|
||||
|
||||
def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio):
    """
    Repack received KV blocks from the remote block_size into the local one.

    (Only works for local blocksize > remote blocksize.)

    Example with local blocksize = 16 tokens, remote blocksize = 4 tokens:
    local block[0] holds remote blocks [0, 1, 2, 3], so after the raw copy
    the bytes inside a local block are laid out as

        |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...

    while the local physical layout expects

        |h0-b0..................|h1-b0..................|...

    The fix-up is: view each local block as `block_size_ratio` remote blocks
    of shape (H, remoteN, D), permute to (H, block_size_ratio, remoteN, D),
    then flatten back to the local (H, localN, D) shape.

    Args:
        cache: KV cache tensor; dim 0 indexes blocks, remaining dims are the
            virtual (NHD) per-block layout.
        indices: 1-D tensor of block indices to rewrite in place.
        block_size_ratio: local block_size // remote block_size (must be >= 1).
    """
    selected = cache.index_select(0, indices)
    # Switch to the physical (HND) ordering of each block before repacking.
    physical = selected.permute(0, 2, 1, 3)
    n_kv_heads, local_block_size, head_size = physical.shape[1:]
    remote_block_size = local_block_size // block_size_ratio
    # 1. view as block_size_ratio remote blocks of (H, remoteN, D)
    split = physical.reshape(
        -1, block_size_ratio, n_kv_heads, remote_block_size, head_size
    )
    # 2. permute => (H, block_size_ratio, remoteN, D); 3. flatten => (H, localN, D)
    repacked = split.permute(0, 2, 1, 3, 4).flatten(2, 3)
    # Return to the virtual (NHD) ordering and write back in place.
    cache.index_copy_(0, indices, repacked.permute(0, 2, 1, 3))
|
||||
|
||||
|
||||
def kv_postprocess_layout_on_receive(cache, indices):
    """Transforms the layout of received KV cache blocks to the local format.

    This method corrects layout mismatches from direct memory copies by
    permuting the tensor dimensions.

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Args:
        cache: KV cache tensor; dim 0 indexes blocks.
        indices: 1-D tensor of block indices to rewrite in place.

    Implementation:
    - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
    - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
    - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
    """
    # Derive the per-block shape from the cache itself; the original code
    # performed an extra index_select just to read these dimensions.
    target_shape = list(cache.shape)
    target_shape[0] = -1
    inv_order = [0, 2, 1, 3]
    src_shape = tuple(target_shape[i] for i in inv_order)
    # Single gather of the blocks to fix up (previously done twice).
    blocks_to_update = cache.index_select(0, indices)
    permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
    cache.index_copy_(0, indices, permuted_blocks)
|
||||
|
||||
|
||||
def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
    """
    Transforms the layout of received KV cache to the local block_size and HND.

    (Only works for local blocksize > remote blocksize.)

    The sender (prefill) uses HND layout with a smaller block_size; the local
    side (decode) uses NHD layout with a larger block_size. Each local block
    therefore arrives as `block_size_ratio` remote (H, remoteN, D) chunks that
    must be transposed to (remoteN, H, D) and concatenated along the token dim.

    Args:
        cache: KV cache tensor; dim 0 indexes blocks, remaining dims are the
            local (N, H, D) per-block layout.
        indices: 1-D tensor of block indices to rewrite in place.
        block_size_ratio: local block_size // remote block_size (must be >= 1).
    """
    selected = cache.index_select(0, indices)
    local_block_size, n_kv_heads, head_size = selected.shape[1:]
    remote_block_size = local_block_size // block_size_ratio
    # View each local block as block_size_ratio remote blocks of (H, remoteN, D).
    remote_chunks = selected.reshape(
        -1, block_size_ratio, n_kv_heads, remote_block_size, head_size
    )
    # Transpose each chunk to (remoteN, H, D), then merge the chunk and token
    # dims back into the local (localN, H, D) shape.
    nhd_blocks = remote_chunks.permute(0, 1, 3, 2, 4).flatten(1, 2)
    cache.index_copy_(0, indices, nhd_blocks)
|
||||
|
||||
|
||||
def yield_req_data(
|
||||
scheduler_output,
|
||||
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
|
||||
|
||||
@@ -24,6 +24,9 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
EngineId,
|
||||
TpKVTopology,
|
||||
kv_postprocess_blksize_and_layout_on_receive,
|
||||
kv_postprocess_blksize_on_receive,
|
||||
kv_postprocess_layout_on_receive,
|
||||
yield_req_data,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
@@ -1749,88 +1752,62 @@ class NixlConnectorWorker:
|
||||
"d2h",
|
||||
)
|
||||
|
||||
def permute_device_kv(self, block_ids: list[int]):
|
||||
"""Transforms the layout of received KV cache blocks to the local format.
|
||||
def post_process_device_kv_on_receive(
|
||||
self,
|
||||
block_size_ratio: int,
|
||||
block_ids_list: list[list[int]],
|
||||
):
|
||||
"""
|
||||
Post process device kv cache after receiving from remote.
|
||||
|
||||
This method corrects layout mismatches from direct memory copies by
|
||||
permuting the tensor dimensions.
|
||||
|
||||
- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
|
||||
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
|
||||
|
||||
Args:
|
||||
block_ids: A list of block IDs to update and permute.
|
||||
|
||||
Implementation:
|
||||
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
|
||||
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
|
||||
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
|
||||
3 types of post processing supported:
|
||||
* kv_cache_postprocess_layout => convert from HND to NHD
|
||||
* kv_cache_postprocess_blksize => convert from small block size
|
||||
to large block size
|
||||
* kv_cache_postprocess_blksize_and_layout => convert from small
|
||||
block size to large block size and convert from HND to NHD
|
||||
|
||||
"""
|
||||
split_k_and_v = self.kv_topo.split_k_and_v
|
||||
inv_order = [0, 2, 1, 3]
|
||||
sample_cache = list(self.device_kv_caches.values())[0][0]
|
||||
target_shape = list(sample_cache.shape)
|
||||
target_shape[0] = -1
|
||||
src_shape = tuple(target_shape[i] for i in inv_order)
|
||||
indices = torch.tensor(block_ids, device=sample_cache.device)
|
||||
|
||||
for _, cache_or_caches in self.device_kv_caches.items():
|
||||
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
|
||||
for cache in cache_list:
|
||||
blocks_to_update = cache.index_select(0, indices)
|
||||
permuted_blocks = blocks_to_update.reshape(src_shape).permute(
|
||||
*inv_order
|
||||
)
|
||||
cache.index_copy_(0, indices, permuted_blocks)
|
||||
|
||||
def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
|
||||
def _process_local_gt_remote(blocks_to_update, block_size_ratio):
|
||||
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
|
||||
remote_block_size = block_size // block_size_ratio
|
||||
n_blocks = block_size_ratio
|
||||
# actual permute is to convert
|
||||
# for local blocksize > remote blocksize
|
||||
# ex: local blocksize = 16 tokens, remote blocksize = 4 tokens
|
||||
# local block[0] = remote block[0, 1, 2, 3]
|
||||
# remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
|
||||
# local is |h0-b0..................|h1-b0..................|...
|
||||
# permute is to:
|
||||
# 1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
|
||||
# 2. permute => (H, nblocks, remoteN, D)
|
||||
# 3. flatten => (H, localN, D)
|
||||
permuted_blocks = (
|
||||
blocks_to_update.reshape(
|
||||
-1, n_blocks, n_kv_heads, remote_block_size, head_size
|
||||
)
|
||||
.permute(0, 2, 1, 3, 4)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
return permuted_blocks
|
||||
|
||||
if len(self.device_kv_caches) == 0:
|
||||
return
|
||||
assert block_size_ratio >= 1, "Only nP < nD supported currently."
|
||||
if self.enable_permute_local_kv and block_size_ratio > 1:
|
||||
logger.debug(
|
||||
"Post-processing device kv cache on receive by converting "
|
||||
"block_size with %sx bigger and permuting layout from HND"
|
||||
" to NHD.",
|
||||
block_size_ratio,
|
||||
)
|
||||
elif self.enable_permute_local_kv:
|
||||
logger.debug(
|
||||
"Post-processing device kv cache on receive by permuting layout"
|
||||
"from HND to NHD."
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Post-processing device kv cache on receive by converting "
|
||||
"block_size with %sx bigger.",
|
||||
block_size_ratio,
|
||||
)
|
||||
|
||||
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
|
||||
sample_cache = list(self.device_kv_caches.values())[0][0]
|
||||
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
|
||||
assert block_size_ratio > 1, "Only nP < nD supported currently."
|
||||
block_ids_list = [[item for sublist in block_ids_list for item in sublist]]
|
||||
|
||||
for block_ids in block_ids_list:
|
||||
indices = torch.tensor(block_ids, device=sample_cache.device)
|
||||
for block_ids in block_ids_list:
|
||||
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
|
||||
|
||||
for _, cache_or_caches in self.device_kv_caches.items():
|
||||
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
|
||||
for cache in cache_list:
|
||||
blocks_to_update = cache.index_select(0, indices)
|
||||
# because kv_cache is always using original layout NHD as
|
||||
# virtual shape while stride can be either HND / NHD at
|
||||
# initialization.
|
||||
# we need to firstly get physical view of the tensor
|
||||
permuted_blocks = _process_local_gt_remote(
|
||||
blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
|
||||
).permute(0, 2, 1, 3)
|
||||
cache.index_copy_(0, indices, permuted_blocks)
|
||||
for _, cache_or_caches in self.device_kv_caches.items():
|
||||
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
|
||||
for cache in cache_list:
|
||||
if self.enable_permute_local_kv and block_size_ratio > 1:
|
||||
kv_postprocess_blksize_and_layout_on_receive(
|
||||
cache, indices, block_size_ratio
|
||||
)
|
||||
elif self.enable_permute_local_kv:
|
||||
kv_postprocess_layout_on_receive(cache, indices)
|
||||
else:
|
||||
kv_postprocess_blksize_on_receive(
|
||||
cache, indices, block_size_ratio
|
||||
)
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
@@ -1854,7 +1831,6 @@ class NixlConnectorWorker:
|
||||
len(done_recving),
|
||||
)
|
||||
|
||||
block_ids_to_permute = []
|
||||
block_ids_for_blocksize_post_process = defaultdict(list)
|
||||
for req_id in done_recving:
|
||||
# clean up metadata for completed requests
|
||||
@@ -1863,24 +1839,22 @@ class NixlConnectorWorker:
|
||||
assert meta.remote is not None
|
||||
if self.use_host_buffer:
|
||||
self.sync_recved_kv_to_device(req_id, meta)
|
||||
if self.enable_permute_local_kv:
|
||||
block_ids_to_permute += meta.local_physical_block_ids
|
||||
|
||||
# post processing for heteroblocksize
|
||||
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
|
||||
meta.remote.engine_id
|
||||
)
|
||||
if (
|
||||
not self.use_mla
|
||||
and block_size_ratio > 1
|
||||
and self.kv_cache_layout == "HND"
|
||||
if not self.use_mla and (
|
||||
block_size_ratio > 1 or self.enable_permute_local_kv
|
||||
):
|
||||
block_ids_for_blocksize_post_process[block_size_ratio].append(
|
||||
meta.local_block_ids
|
||||
meta.local_physical_block_ids
|
||||
)
|
||||
self.blocksize_post_process(block_ids_for_blocksize_post_process)
|
||||
if len(block_ids_to_permute) > 0:
|
||||
self.permute_device_kv(block_ids_to_permute)
|
||||
for (
|
||||
block_size_ratio,
|
||||
block_ids_list,
|
||||
) in block_ids_for_blocksize_post_process.items():
|
||||
self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)
|
||||
|
||||
# Handle timeout to avoid stranding blocks on remote.
|
||||
now = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user