diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index f39b78dd2..fc199f006 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -523,6 +523,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker): kv_cache_layout="HND", block_size=self.block_size, ssm_sizes=(0, 0), + attn_backend_name=self.backend_name, ), remote_tp_rank=remote_tp_rank, remote_tp_size=remote_tp_size, @@ -972,6 +973,7 @@ class TestNixlHandshake: kv_cache_layout=mismatched_layout, block_size=worker.block_size, ssm_sizes=(0, 0), + attn_backend_name=worker.backend_name, ) with pytest.raises(RuntimeError): @@ -1028,6 +1030,7 @@ class TestNixlHandshake: kv_cache_layout="HND", block_size=worker.block_size, ssm_sizes=(0, 0), + attn_backend_name=worker.backend_name, ) # We don't check layout for homogeneous TP and MLA for now, as the @@ -2347,6 +2350,7 @@ def test_compatibility_hash_validation( kv_cache_layout="HND", block_size=prefill_block_size, ssm_sizes=(0, 0), + attn_backend_name=decode_worker.backend_name, ) handshake_payload = NixlHandshakePayload( compatibility_hash=remote_hash, diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index c575043fb..54cf9805b 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -173,6 +173,7 @@ class NixlAgentMetadata: kv_cache_layout: str block_size: int ssm_sizes: tuple[int, int] + attn_backend_name: str @dataclass @@ -1116,6 +1117,7 @@ class NixlConnectorWorker: self.num_blocks = kv_cache_config.num_blocks self.enable_permute_local_kv = False + self.enable_heterogeneous_attn_post_process = False # KV Caches and nixl tracking data. 
self.device_type = current_platform.device_type @@ -1776,6 +1778,7 @@ class NixlConnectorWorker: else self.host_buffer_kv_cache_layout, block_size=self.block_size, ssm_sizes=self._mamba_ssm_size, + attn_backend_name=self.backend_name, ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -2369,6 +2372,21 @@ class NixlConnectorWorker: "Or enable experimental feature to use HND to NHD support by " "setting 'enable_permute_local_kv'=True in --kv-transfer-config." ) + # If the remote agent's attn backend is not the same as the local one, + # enable heterogeneous attn post-processing. + if ( + nixl_agent_meta.attn_backend_name != self.backend_name + and self.backend_name in ["CPU_ATTN"] + ): + if self._is_hma_required: + raise RuntimeError( + "heterogeneous attn post process is not supported with HMA" + ) + logger.info( + "[Experimental] CPU_ATTN backend is used, " + "hint heterogeneous attn post process" + ) + self.enable_heterogeneous_attn_post_process = True # Heterogeneous TP requires head-splitting, which only works with # HND layout. MLA and replicated-KV cases don't split on heads. @@ -2542,6 +2560,28 @@ class NixlConnectorWorker: cache, indices, block_size_ratio ) + def post_process_device_kv_on_receive_heterogeneous_attn( + self, block_ids: list[int] + ): + """ + Post process device kv cache after receiving from remote + for heterogeneous attention. + """ + assert self.enable_heterogeneous_attn_post_process + + indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long) + + for _, cache_or_caches in self.device_kv_caches.items(): + blocks_to_update = cache_or_caches.index_select(1, indices) + current_platform.pack_kv_cache( + key=blocks_to_update[0], + value=blocks_to_update[1], + key_cache=cache_or_caches[0], + value_cache=cache_or_caches[1], + block_ids=block_ids, + indices=indices, + ) + def get_finished(self) -> tuple[set[str], set[str]]: """ Get requests that are done sending or recving on this specific worker. 
@@ -2566,6 +2606,7 @@ class NixlConnectorWorker: ) block_ids_for_blocksize_post_process = defaultdict(list) + block_ids_for_heterogeneous_attn_post_process = list[list[int]]() for req_id in done_recving: # clean up metadata for completed requests meta = self._recving_metadata.pop(req_id, None) @@ -2585,12 +2626,20 @@ class NixlConnectorWorker: block_ids_for_blocksize_post_process[block_size_ratio].append( meta.local_physical_block_ids[0] ) + # post processing for heterogeneous attention + if self.enable_heterogeneous_attn_post_process: + block_ids_for_heterogeneous_attn_post_process.append( + meta.local_physical_block_ids[0] + ) for ( block_size_ratio, block_ids_list, ) in block_ids_for_blocksize_post_process.items(): self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list) + for block_ids in block_ids_for_heterogeneous_attn_post_process: + self.post_process_device_kv_on_receive_heterogeneous_attn(block_ids) + # Handle timeout to avoid stranding blocks on remote. now = time.perf_counter() while self._reqs_to_send: diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index cd6c093ce..b18185472 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -520,3 +520,43 @@ class CpuPlatform(Platform): import vllm._C # noqa: F401 except ImportError as e: logger.warning("Failed to import from vllm._C: %r", e) + + @classmethod + def pack_kv_cache( + cls, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_ids: list[int], + indices: torch.Tensor, + ) -> None: + """ + Pack key/value blocks into this platform's KV cache layout. + """ + # Import lazily: cpu_attn pulls in _custom_ops, which needs a fully + # initialized vllm.platforms (avoid circular import while CpuPlatform loads). 
+ from vllm._custom_ops import cpu_attn_reshape_and_cache + from vllm.v1.attention.backends.cpu_attn import _get_attn_isa + + dtype = key.dtype + # For CPU_ATTN, the shape is [N, num_kv_heads, block_size, head_size] + _, _, block_size, head_size = key_cache.shape + key = key.permute(0, 2, 1, 3).flatten(0, 1) + value = value.permute(0, 2, 1, 3).flatten(0, 1) + + isa = _get_attn_isa(dtype, block_size, head_size) + block_offsets = torch.arange(block_size, device="cpu", dtype=torch.long) + num_blocks = len(block_ids) + slot_mapping = ( + block_offsets.reshape(1, block_size) + + indices.reshape(num_blocks, 1) * block_size + ).flatten() + cpu_attn_reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + isa, + )