[PD][HeteroArch]Fix accuracy issue with CPU_ATTN as Decoder and Flash_ATTN as prefiller (#38935)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
Chendi.Xue
2026-04-08 22:19:07 -05:00
committed by GitHub
parent aec18492d0
commit ef5a226819
3 changed files with 93 additions and 0 deletions

View File

@@ -523,6 +523,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
kv_cache_layout="HND",
block_size=self.block_size,
ssm_sizes=(0, 0),
attn_backend_name=self.backend_name,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
@@ -972,6 +973,7 @@ class TestNixlHandshake:
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
)
with pytest.raises(RuntimeError):
@@ -1028,6 +1030,7 @@ class TestNixlHandshake:
kv_cache_layout="HND",
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
)
# We don't check layout for homogeneous TP and MLA for now, as the
@@ -2347,6 +2350,7 @@ def test_compatibility_hash_validation(
kv_cache_layout="HND",
block_size=prefill_block_size,
ssm_sizes=(0, 0),
attn_backend_name=decode_worker.backend_name,
)
handshake_payload = NixlHandshakePayload(
compatibility_hash=remote_hash,

View File

@@ -173,6 +173,7 @@ class NixlAgentMetadata:
kv_cache_layout: str
block_size: int
ssm_sizes: tuple[int, int]
attn_backend_name: str
@dataclass
@@ -1116,6 +1117,7 @@ class NixlConnectorWorker:
self.num_blocks = kv_cache_config.num_blocks
self.enable_permute_local_kv = False
self.enable_heterogeneous_attn_post_process = False
# KV Caches and nixl tracking data.
self.device_type = current_platform.device_type
@@ -1776,6 +1778,7 @@ class NixlConnectorWorker:
else self.host_buffer_kv_cache_layout,
block_size=self.block_size,
ssm_sizes=self._mamba_ssm_size,
attn_backend_name=self.backend_name,
)
# Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
@@ -2369,6 +2372,21 @@ class NixlConnectorWorker:
"Or enable experimental feature to use HND to NHD support by "
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
)
# If the remote agent's attention backend differs from the local one,
# enable heterogeneous attn post processing.
if (
nixl_agent_meta.attn_backend_name != self.backend_name
and self.backend_name in ["CPU_ATTN"]
):
if self._is_hma_required:
raise RuntimeError(
"heterogeneous attn post process is not supported with HMA"
)
logger.info(
"[Experimental] CPU_ATTN backend is used, "
"hint heterogeneous attn post process"
)
self.enable_heterogeneous_attn_post_process = True
# Heterogeneous TP requires head-splitting, which only works with
# HND layout. MLA and replicated-KV cases don't split on heads.
@@ -2542,6 +2560,28 @@ class NixlConnectorWorker:
cache, indices, block_size_ratio
)
def post_process_device_kv_on_receive_heterogeneous_attn(
    self, block_ids: list[int]
) -> None:
    """Repack device KV blocks received from a remote with a different
    attention backend.

    When the remote prefiller's backend lays KV data out differently from
    the local backend, each received block must be rewritten in place into
    the local layout via the platform-specific ``pack_kv_cache`` hook.

    Args:
        block_ids: Local physical block ids just filled by a remote
            transfer that need repacking.
    """
    assert self.enable_heterogeneous_attn_post_process
    # Index tensor for selecting the received blocks. index_select(1, ...)
    # plus the [0]/[1] accesses below assume each cache tensor stacks K and
    # V on dim 0 with blocks on dim 1 — TODO(review): confirm against the
    # KV-cache allocation code.
    indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
    # Keys of device_kv_caches (layer names) are unused; iterate values only.
    for cache_or_caches in self.device_kv_caches.values():
        blocks_to_update = cache_or_caches.index_select(1, indices)
        current_platform.pack_kv_cache(
            key=blocks_to_update[0],
            value=blocks_to_update[1],
            key_cache=cache_or_caches[0],
            value_cache=cache_or_caches[1],
            block_ids=block_ids,
            indices=indices,
        )
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving on this specific worker.
@@ -2566,6 +2606,7 @@ class NixlConnectorWorker:
)
block_ids_for_blocksize_post_process = defaultdict(list)
block_ids_for_heterogeneous_attn_post_process = list[list[int]]()
for req_id in done_recving:
# clean up metadata for completed requests
meta = self._recving_metadata.pop(req_id, None)
@@ -2585,12 +2626,20 @@ class NixlConnectorWorker:
block_ids_for_blocksize_post_process[block_size_ratio].append(
meta.local_physical_block_ids[0]
)
# post processing for heterogeneous attention
if self.enable_heterogeneous_attn_post_process:
block_ids_for_heterogeneous_attn_post_process.append(
meta.local_physical_block_ids[0]
)
for (
block_size_ratio,
block_ids_list,
) in block_ids_for_blocksize_post_process.items():
self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)
for block_ids in block_ids_for_heterogeneous_attn_post_process:
self.post_process_device_kv_on_receive_heterogeneous_attn(block_ids)
# Handle timeout to avoid stranding blocks on remote.
now = time.perf_counter()
while self._reqs_to_send:

View File

@@ -520,3 +520,43 @@ class CpuPlatform(Platform):
import vllm._C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._C: %r", e)
@classmethod
def pack_kv_cache(
    cls,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_ids: list[int],
    indices: torch.Tensor,
) -> None:
    """
    Rewrite the kv cache shape for the current platform.
    """
    # Import lazily: cpu_attn pulls in _custom_ops, which needs a fully
    # initialized vllm.platforms (avoid circular import while CpuPlatform loads).
    from vllm._custom_ops import cpu_attn_reshape_and_cache
    from vllm.v1.attention.backends.cpu_attn import _get_attn_isa

    # For CPU_ATTN, the cache shape is [N, num_kv_heads, block_size, head_size].
    block_size, head_size = key_cache.shape[2], key_cache.shape[3]
    # Move the head dim after the block dim, then merge blocks and slots into
    # one leading token axis (presumably key/value arrive as
    # [num_blocks, num_kv_heads, block_size, head_size] — matches the
    # [tokens, heads, head_size] layout reshape_and_cache expects).
    flat_key = key.permute(0, 2, 1, 3).flatten(0, 1)
    flat_value = value.permute(0, 2, 1, 3).flatten(0, 1)
    isa = _get_attn_isa(key.dtype, block_size, head_size)
    # Slot for position j of destination block indices[i] is
    # indices[i] * block_size + j; build all of them via broadcasting.
    num_blocks = len(block_ids)
    within_block = torch.arange(block_size, device="cpu", dtype=torch.long)
    slot_mapping = (
        indices.reshape(num_blocks, 1) * block_size
        + within_block.reshape(1, block_size)
    ).flatten()
    cpu_attn_reshape_and_cache(
        flat_key,
        flat_value,
        key_cache,
        value_cache,
        slot_mapping,
        isa,
    )