diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index bc6246583..46eb366bc 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -980,8 +980,10 @@ def test_hybrid_block_table_initialization(): req_index = 0 block_table.append_row(kvcache_manager_blocks, req_index) # Get expected kernel blocks from the implementation for verification. - expected_kernel_blocks = block_table._map_to_kernel_blocks( - np.array(kvcache_manager_blocks) + expected_kernel_blocks = block_table.map_to_kernel_blocks( + np.array(kvcache_manager_blocks), + block_table.blocks_per_kv_block, + block_table._kernel_block_arange, ) # Verify block table state assert block_table.num_blocks_per_row[req_index] == len(expected_kernel_blocks) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index ff9770b72..f0a59cf35 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -48,6 +48,7 @@ from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backends.utils import get_kv_cache_layout from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -110,6 +111,8 @@ class NixlAgentMetadata(KVConnectorHandshakeMetadata): @dataclass class ReqMeta: local_block_ids: list[int] + # To be used when logical block size does not match the kernel block size + local_physical_block_ids: list[int] remote_block_ids: list[int] remote_host: str remote_port: int @@ -137,6 +140,7 @@ class NixlConnectorMetadata(KVConnectorMetadata): assert load_remote_cache ^ save_to_host _req = ReqMeta( local_block_ids=local_block_ids, + local_physical_block_ids=local_block_ids, remote_block_ids=kv_transfer_params["remote_block_ids"], remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], @@ -897,6 +901,8 @@ class NixlConnectorWorker: is_mla=self.use_mla, total_num_kv_heads=self.model_config.get_total_num_kv_heads(), ) + self._use_pallas = self.kv_topo._use_pallas + self._physical_blocks_per_logical_kv_block = 1 def _nixl_handshake( self, @@ -1092,6 +1098,22 @@ class NixlConnectorWorker: if base_addr in seen_base_addresses: continue + # TODO (NickLucche): Get kernel_block_size in a cleaner way + # NHD default "view" for non-MLA cache + kernel_block_size = cache.shape[-2] if self.use_mla else cache.shape[-3] + + if self.block_size != kernel_block_size: + logger.info_once( + "User-specified logical block size (%s) does not match" + " physical kernel block size (%s). Using the latter. ", + self.block_size, + kernel_block_size, + ) + self._physical_blocks_per_logical_kv_block = ( + self.block_size // kernel_block_size + ) + self.block_size = kernel_block_size + seen_base_addresses.append(base_addr) curr_tensor_size_bytes = cache.numel() * cache.element_size() @@ -1438,7 +1460,7 @@ class NixlConnectorWorker: assert self.use_host_buffer assert self.copy_blocks is not None - local_block_ids = meta.local_block_ids + local_block_ids = meta.local_physical_block_ids self.copy_blocks( self.host_xfer_buffers, self.device_kv_caches, @@ -1451,7 +1473,7 @@ class NixlConnectorWorker: "synced recved kv of request[%s] to device kv buffer," "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids)), + ",".join(map(str, local_block_ids)), ) def save_kv_to_host(self, metadata: NixlConnectorMetadata): @@ -1460,19 +1482,22 @@ class NixlConnectorWorker: assert self.copy_blocks is not None for req_id, meta in metadata.reqs_to_save.items(): + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids + ) if logger.isEnabledFor(logging.DEBUG): logger.debug( "save_load_kv for request[%s] to host xfer buffer." "local_block_ids: %s. ", req_id, - ",".join(map(str, meta.local_block_ids)), + ",".join(map(str, meta.local_physical_block_ids)), ) # blocking self.copy_blocks( self.device_kv_caches, self.host_xfer_buffers, - meta.local_block_ids, - meta.local_block_ids, + meta.local_physical_block_ids, + meta.local_physical_block_ids, "d2h", ) @@ -1541,7 +1566,7 @@ class NixlConnectorWorker: if self.use_host_buffer: self.sync_recved_kv_to_device(req_id, meta) if self.enable_permute_local_kv: - block_ids_to_permute += meta.local_block_ids + block_ids_to_permute += meta.local_physical_block_ids if len(block_ids_to_permute) > 0: self.permute_device_kv(block_ids_to_permute) @@ -1628,7 +1653,7 @@ class NixlConnectorWorker: req_id, xfer_state, ) - # mark all blocks for this request as invalid + # mark all (logical)blocks for this request as invalid if meta := self._recving_metadata.pop(req_id, None): self._invalid_block_ids.update(meta.local_block_ids) self._recving_metadata.pop(req_id, None) @@ -1645,13 +1670,19 @@ class NixlConnectorWorker: We check for these trnxs to complete in each step(). """ for req_id, meta in metadata.reqs_to_recv.items(): + meta.local_physical_block_ids = self._logical_to_kernel_block_ids( + meta.local_block_ids + ) + meta.remote_block_ids = self._logical_to_kernel_block_ids( + meta.remote_block_ids + ) remote_engine_id = meta.remote_engine_id logger.debug( "start_load_kv for request %s from remote engine %s. " "Num local_block_ids: %s. Num remote_block_ids: %s. ", req_id, remote_engine_id, - len(meta.local_block_ids), + len(meta.local_physical_block_ids), len(meta.remote_block_ids), ) # always store metadata for failure recovery @@ -1699,7 +1730,7 @@ class NixlConnectorWorker: self._read_blocks( request_id=req_id, dst_engine_id=meta.remote_engine_id, - local_block_ids=meta.local_block_ids, + local_block_ids=meta.local_physical_block_ids, remote_block_ids=meta.remote_block_ids, ) @@ -1826,7 +1857,7 @@ class NixlConnectorWorker: "Marking blocks as invalid.", request_id, ) - # mark all blocks for this request as invalid + # mark all (logical) blocks for this request as invalid if meta := self._recving_metadata.get(request_id): self._invalid_block_ids.update(meta.local_block_ids) self.xfer_stats.record_failed_transfer() @@ -1865,6 +1896,23 @@ class NixlConnectorWorker: descs_ids = region_ids * num_blocks + block_ids return descs_ids.flatten() + def _logical_to_kernel_block_ids(self, block_ids: list[int]) -> list[int]: + """ + Convert logical block ids to kernel physical block ids. + This is required when the logical block size (the one set by the user) + does not match the one required by the attn backend. + """ + if self._physical_blocks_per_logical_kv_block == 1: + # Noop when physical and logical block sizes are the same + return block_ids + block_ids_np = np.array(block_ids) + block_arange = np.arange(0, self._physical_blocks_per_logical_kv_block).reshape( + 1, -1 + ) + return BlockTable.map_to_kernel_blocks( + block_ids_np, self._physical_blocks_per_logical_kv_block, block_arange + ).tolist() + def get_backend_aware_kv_block_len(self, layer_idx: int): """ Get the block length for one K/V element (K and V have the same size). diff --git a/vllm/v1/worker/block_table.py b/vllm/v1/worker/block_table.py index c28bf542f..9f6c19e46 100644 --- a/vllm/v1/worker/block_table.py +++ b/vllm/v1/worker/block_table.py @@ -98,7 +98,9 @@ class BlockTable: return if self.use_hybrid_blocks: - block_ids = self._map_to_kernel_blocks(np.array(block_ids)) + block_ids = self.map_to_kernel_blocks( + np.array(block_ids), self.blocks_per_kv_block, self._kernel_block_arange + ) num_blocks = len(block_ids) start = self.num_blocks_per_row[row_idx] @@ -188,7 +190,12 @@ class BlockTable: self.block_table.gpu.fill_(0) self.block_table.cpu.fill_(0) - def _map_to_kernel_blocks(self, kv_manager_block_ids: np.ndarray) -> np.ndarray: + @staticmethod + def map_to_kernel_blocks( + kv_manager_block_ids: np.ndarray, + blocks_per_kv_block: int, + kernel_block_arange: np.ndarray, + ) -> np.ndarray: """Convert kv_manager_block_id IDs to kernel block IDs. Example: @@ -203,12 +210,12 @@ class BlockTable: # kv_manager_block_id 1 → kernel block id [2, 3] # kv_manager_block_id 2 → kernel block id [4, 5] """ - if not self.use_hybrid_blocks: + if blocks_per_kv_block == 1: return kv_manager_block_ids kernel_block_ids = ( - kv_manager_block_ids.reshape(-1, 1) * self.blocks_per_kv_block - + self._kernel_block_arange + kv_manager_block_ids.reshape(-1, 1) * blocks_per_kv_block + + kernel_block_arange ) return kernel_block_ids.reshape(-1)