[NIXL] refine decoder side post process for heterogeneous BlockSize and kv_layout (#30275)

This commit is contained in:
Chendi.Xue
2026-01-09 15:22:19 -06:00
committed by GitHub
parent 2612ba9285
commit 94578127a4
2 changed files with 137 additions and 85 deletions

View File

@@ -203,6 +203,84 @@ def copy_kv_blocks(
copy_fn(src_tensor, dst_tensor, src_indices, dst_indices)
def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio):
    """Regroup received KV blocks from the remote block size to the local one.

    Only supports local block_size > remote block_size. Example with a local
    block size of 16 tokens and a remote block size of 4 tokens: local
    block[0] holds remote blocks [0, 1, 2, 3], i.e.

        remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
        local  is |h0-b0..................|h1-b0..................|...

    The regrouping is:
      1. view  => see the remote data as n_blocks * (H, remoteN, D)
      2. permute => (H, n_blocks, remoteN, D)
      3. flatten => (H, localN, D)
    """
    selected = cache.index_select(0, indices)
    # Switch to physical (head-major) order before regrouping tokens.
    selected = selected.permute(0, 2, 1, 3)
    num_heads, local_block_size, head_dim = selected.shape[1:]
    remote_block_size = local_block_size // block_size_ratio
    regrouped = selected.reshape(
        -1, block_size_ratio, num_heads, remote_block_size, head_dim
    )
    regrouped = regrouped.permute(0, 2, 1, 3, 4).flatten(2, 3)
    # Undo the head-major permute so the result matches the cache's shape.
    cache.index_copy_(0, indices, regrouped.permute(0, 2, 1, 3))
def kv_postprocess_layout_on_receive(cache, indices):
    """Transform received KV cache blocks from the sender's layout to the local one.

    A direct memory copy from a sender with a different kv layout leaves each
    block's raw bytes in the sender's order; this reinterprets and permutes them:

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Implementation:
    - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
    - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
    - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
    """
    blocks_to_update = cache.index_select(0, indices)
    target_shape = list(blocks_to_update.shape)
    target_shape[0] = -1
    inv_order = [0, 2, 1, 3]
    src_shape = tuple(target_shape[i] for i in inv_order)
    # NOTE: the selected blocks were already materialized above; the original
    # code redundantly re-ran index_select here, copying the blocks twice.
    permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
    cache.index_copy_(0, indices, permuted_blocks)
def kv_postprocess_blksize_and_layout_on_receive(cache, indices, block_size_ratio):
    """Convert received KV cache to the local block_size and layout in one pass.

    Only supports local block_size > remote block_size: the prefill side
    sends HND blocks with a smaller block size, while the local decode side
    stores NHD blocks with a larger block size.
    """
    selected = cache.index_select(0, indices)
    local_block_size, num_heads, head_dim = selected.shape[1:]
    remote_block_size = local_block_size // block_size_ratio
    # View each local block as block_size_ratio remote HND sub-blocks, swap
    # the head/token axes, then merge the sub-blocks into one NHD block.
    converted = (
        selected.reshape(
            -1, block_size_ratio, num_heads, remote_block_size, head_dim
        )
        .permute(0, 1, 3, 2, 4)
        .flatten(1, 2)
    )
    cache.index_copy_(0, indices, converted)
def yield_req_data(
scheduler_output,
) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:

View File

@@ -24,6 +24,9 @@ from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import (
EngineId,
TpKVTopology,
kv_postprocess_blksize_and_layout_on_receive,
kv_postprocess_blksize_on_receive,
kv_postprocess_layout_on_receive,
yield_req_data,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
@@ -1749,88 +1752,62 @@ class NixlConnectorWorker:
"d2h",
)
def permute_device_kv(self, block_ids: list[int]):
"""Transforms the layout of received KV cache blocks to the local format.
def post_process_device_kv_on_receive(
self,
block_size_ratio: int,
block_ids_list: list[list[int]],
):
"""
Post process device kv cache after receiving from remote.
This method corrects layout mismatches from direct memory copies by
permuting the tensor dimensions.
- **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
- **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`
Args:
block_ids: A list of block IDs to update and permute.
Implementation:
- x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
- permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
- cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back
3 types of post processing supported:
* kv_cache_postprocess_layout => convert from HND to NHD
* kv_cache_postprocess_blksize => convert from small block size
to large block size
* kv_cache_postprocess_blksize_and_layout => convert from small
block size to large block size and convert from HND to NHD
"""
split_k_and_v = self.kv_topo.split_k_and_v
inv_order = [0, 2, 1, 3]
sample_cache = list(self.device_kv_caches.values())[0][0]
target_shape = list(sample_cache.shape)
target_shape[0] = -1
src_shape = tuple(target_shape[i] for i in inv_order)
indices = torch.tensor(block_ids, device=sample_cache.device)
for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
blocks_to_update = cache.index_select(0, indices)
permuted_blocks = blocks_to_update.reshape(src_shape).permute(
*inv_order
)
cache.index_copy_(0, indices, permuted_blocks)
def blocksize_post_process(self, block_ids_per_ratio: dict[int, list[list[int]]]):
def _process_local_gt_remote(blocks_to_update, block_size_ratio):
n_kv_heads, block_size, head_size = blocks_to_update.shape[1:]
remote_block_size = block_size // block_size_ratio
n_blocks = block_size_ratio
# actual permute is to convert
# for local blocksize > remote blocksize
# ex: local blocksize = 16 tokens, remote blocksize = 4 tokens
# local block[0] = remote block[0, 1, 2, 3]
# remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
# local is |h0-b0..................|h1-b0..................|...
# permute is to:
# 1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
# 2. permute => (H, nblocks, remoteN, D)
# 3. flatten => (H, localN, D)
permuted_blocks = (
blocks_to_update.reshape(
-1, n_blocks, n_kv_heads, remote_block_size, head_size
)
.permute(0, 2, 1, 3, 4)
.flatten(2, 3)
)
return permuted_blocks
if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger and permuting layout from HND"
" to NHD.",
block_size_ratio,
)
elif self.enable_permute_local_kv:
logger.debug(
"Post-processing device kv cache on receive by permuting layout "
"from HND to NHD."
)
else:
logger.debug(
"Post-processing device kv cache on receive by converting "
"block_size with %sx bigger.",
block_size_ratio,
)
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
sample_cache = list(self.device_kv_caches.values())[0][0]
for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
assert block_size_ratio > 1, "Only nP < nD supported currently."
block_ids_list = [[item for sublist in block_ids_list for item in sublist]]
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=sample_cache.device)
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
blocks_to_update = cache.index_select(0, indices)
# because kv_cache is always using original layout NHD as
# virtual shape while stride can be either HND / NHD at
# initialization.
# we need to firstly get physical view of the tensor
permuted_blocks = _process_local_gt_remote(
blocks_to_update.permute(0, 2, 1, 3), block_size_ratio
).permute(0, 2, 1, 3)
cache.index_copy_(0, indices, permuted_blocks)
for _, cache_or_caches in self.device_kv_caches.items():
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
if self.enable_permute_local_kv and block_size_ratio > 1:
kv_postprocess_blksize_and_layout_on_receive(
cache, indices, block_size_ratio
)
elif self.enable_permute_local_kv:
kv_postprocess_layout_on_receive(cache, indices)
else:
kv_postprocess_blksize_on_receive(
cache, indices, block_size_ratio
)
def get_finished(self) -> tuple[set[str], set[str]]:
"""
@@ -1854,7 +1831,6 @@ class NixlConnectorWorker:
len(done_recving),
)
block_ids_to_permute = []
block_ids_for_blocksize_post_process = defaultdict(list)
for req_id in done_recving:
# clean up metadata for completed requests
@@ -1863,24 +1839,22 @@ class NixlConnectorWorker:
assert meta.remote is not None
if self.use_host_buffer:
self.sync_recved_kv_to_device(req_id, meta)
if self.enable_permute_local_kv:
block_ids_to_permute += meta.local_physical_block_ids
# post processing for heteroblocksize
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
meta.remote.engine_id
)
if (
not self.use_mla
and block_size_ratio > 1
and self.kv_cache_layout == "HND"
if not self.use_mla and (
block_size_ratio > 1 or self.enable_permute_local_kv
):
block_ids_for_blocksize_post_process[block_size_ratio].append(
meta.local_block_ids
meta.local_physical_block_ids
)
self.blocksize_post_process(block_ids_for_blocksize_post_process)
if len(block_ids_to_permute) > 0:
self.permute_device_kv(block_ids_to_permute)
for (
block_size_ratio,
block_ids_list,
) in block_ids_for_blocksize_post_process.items():
self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()