def _remote_block_size(block_size, block_size_ratio):
    """Return the remote block size for a local block of ``block_size`` tokens.

    Raises:
        ValueError: if ``block_size_ratio`` is smaller than 1 or does not
            evenly divide ``block_size``. Without this check a ``reshape``
            with an inferred leading dimension could silently regroup the
            data across block boundaries.
    """
    if block_size_ratio < 1 or block_size % block_size_ratio != 0:
        raise ValueError(
            f"block_size_ratio ({block_size_ratio}) must be >= 1 and evenly "
            f"divide the local block size ({block_size})."
        )
    return block_size // block_size_ratio


def kv_postprocess_blksize_on_receive(cache, indices, block_size_ratio: int) -> None:
    """
    Transforms the layout of received KV cache blocks to the local block_size.
    (Only works for local blocksize > remote blocksize)

    The selected blocks of ``cache`` are rewritten in place.

    example:
    local blocksize = 16 tokens, remote blocksize = 4 tokens
    local block[0] = remote block[0, 1, 2, 3]
    remote is |h0-b0|h1-b0|h2-b0|h3-b0|h0-b1|h1-b1|h2-b1|h3-b1|...
    local is  |h0-b0..................|h1-b0..................|...
    permute is to:
    1. view => view remote as n_blocks * remote_shape(H,remoteN,D)
    2. permute => (H, nblocks, remoteN, D)
    3. flatten => (H, localN, D)

    Args:
        cache: 4-D KV cache tensor indexed by block along dim 0. Dims 1 and 2
            are swapped before regrouping, i.e. the tensor is treated as
            (blocks, tokens, heads, head_dim) whose physical order is
            (blocks, heads, tokens, head_dim).
        indices: 1-D integer tensor of block ids to rewrite.
        block_size_ratio: number of remote blocks packed into one local block.

    Raises:
        ValueError: if ``block_size_ratio`` does not evenly divide the local
            block size.
    """
    if indices.numel() == 0:
        # Nothing to rewrite; also avoids an ambiguous reshape on 0 elements.
        return

    # Swap dims 1 and 2 so the following reshape walks the data in physical
    # order (heads before tokens).
    physical_blocks = cache.index_select(0, indices).permute(0, 2, 1, 3)
    num_blocks, n_kv_heads, block_size, head_size = physical_blocks.shape
    remote_block_size = _remote_block_size(block_size, block_size_ratio)

    regrouped = (
        physical_blocks.reshape(
            num_blocks, block_size_ratio, n_kv_heads, remote_block_size, head_size
        )
        .permute(0, 2, 1, 3, 4)
        .flatten(2, 3)
    )
    # Swap dims 1 and 2 back to the shape the cache tensor exposes.
    cache.index_copy_(0, indices, regrouped.permute(0, 2, 1, 3))


def kv_postprocess_layout_on_receive(cache, indices) -> None:
    """Transforms the layout of received KV cache blocks to the local format.

    This method corrects layout mismatches from direct memory copies by
    permuting the tensor dimensions. The selected blocks of ``cache`` are
    rewritten in place.

    - **Source Layout:** `[num_blocks, n_kv_head, block_size, head_dim]`
    - **Target Layout:** `[num_blocks, block_size, n_kv_head, head_dim]`

    Implementation:
    - x = blocks_to_update.reshape(src_shape) # view local kv with sender layout
    - permuted_blocks = x.permute(*inv_order) # transpose n_kv_heads, block_size
    - cache.index_copy_(0, indices, permuted_blocks) # copy permuted kv back

    Args:
        cache: 4-D KV cache tensor in the target layout, indexed by block
            along dim 0.
        indices: 1-D integer tensor of block ids to rewrite.
    """
    if indices.numel() == 0:
        # Nothing to rewrite.
        return

    inv_order = [0, 2, 1, 3]
    # Select the blocks once; the sender's shape is the local shape with
    # dims 1 and 2 swapped.
    blocks_to_update = cache.index_select(0, indices)
    src_shape = tuple(blocks_to_update.shape[i] for i in inv_order)
    permuted_blocks = blocks_to_update.reshape(src_shape).permute(*inv_order)
    cache.index_copy_(0, indices, permuted_blocks)


def kv_postprocess_blksize_and_layout_on_receive(
    cache, indices, block_size_ratio: int
) -> None:
    """
    Transforms the layout of received KV cache to the local block_size and NHD.
    (Only works for local blocksize > remote blocksize)

    prefill is HND, smaller block_size
    decode(local) is NHD, larger block_size

    Each selected local block is read as ``block_size_ratio`` consecutive
    remote blocks of shape (H, remoteN, D); every remote block is transposed
    to (remoteN, H, D) and the results are concatenated along the token
    dimension. The selected blocks of ``cache`` are rewritten in place.

    Args:
        cache: 4-D KV cache tensor shaped (blocks, tokens, heads, head_dim).
        indices: 1-D integer tensor of block ids to rewrite.
        block_size_ratio: number of remote blocks packed into one local block.

    Raises:
        ValueError: if ``block_size_ratio`` does not evenly divide the local
            block size.
    """
    if indices.numel() == 0:
        # Nothing to rewrite; also avoids an ambiguous reshape on 0 elements.
        return

    blocks_to_update = cache.index_select(0, indices)
    num_blocks, block_size, n_kv_heads, head_size = blocks_to_update.shape
    remote_block_size = _remote_block_size(block_size, block_size_ratio)

    permuted_blocks = (
        blocks_to_update.reshape(
            num_blocks, block_size_ratio, n_kv_heads, remote_block_size, head_size
        )
        .permute(0, 1, 3, 2, 4)
        .flatten(1, 2)
    )
    cache.index_copy_(0, indices, permuted_blocks)
def post_process_device_kv_on_receive(
    self,
    block_size_ratio: int,
    block_ids_list: list[list[int]],
):
    """
    Post process device kv cache after receiving from remote.

    3 types of post processing supported:
        * kv_postprocess_layout_on_receive => convert from HND to NHD
        * kv_postprocess_blksize_on_receive => convert from small block size
          to large block size
        * kv_postprocess_blksize_and_layout_on_receive => convert from small
          block size to large block size and convert from HND to NHD

    Args:
        block_size_ratio: how many remote blocks make up one local block;
            must be >= 1.
        block_ids_list: one list of local block ids per finished request;
            each list is post-processed on every registered device kv cache.
    """
    if len(self.device_kv_caches) == 0:
        return
    assert block_size_ratio >= 1, "Only nP < nD supported currently."

    # Decide the post-processing mode once; it is identical for every cache.
    permute_layout = self.enable_permute_local_kv
    convert_blksize = block_size_ratio > 1
    if not (permute_layout or convert_blksize):
        # Same block size and same layout: the block-size conversion would
        # write back exactly what it read, so skip the work (and the log).
        return

    if permute_layout and convert_blksize:
        logger.debug(
            "Post-processing device kv cache on receive by converting "
            "block_size with %sx bigger and permuting layout from HND"
            " to NHD.",
            block_size_ratio,
        )
    elif permute_layout:
        logger.debug(
            "Post-processing device kv cache on receive by permuting layout "
            "from HND to NHD."
        )
    else:
        logger.debug(
            "Post-processing device kv cache on receive by converting "
            "block_size with %sx bigger.",
            block_size_ratio,
        )

    # When K and V are split, each entry of device_kv_caches is iterated as
    # a sequence of caches; otherwise the entry itself is the single cache.
    split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)

    for block_ids in block_ids_list:
        indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)

        for cache_or_caches in self.device_kv_caches.values():
            cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
            for cache in cache_list:
                if permute_layout and convert_blksize:
                    kv_postprocess_blksize_and_layout_on_receive(
                        cache, indices, block_size_ratio
                    )
                elif permute_layout:
                    kv_postprocess_layout_on_receive(cache, indices)
                else:
                    kv_postprocess_blksize_on_receive(
                        cache, indices, block_size_ratio
                    )
avoid stranding blocks on remote. now = time.perf_counter()