[PD][HeteroArch]Fix accuracy issue with CPU_ATTN as Decoder and Flash_ATTN as prefiller (#38935)

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
Chendi.Xue
2026-04-08 22:19:07 -05:00
committed by GitHub
parent aec18492d0
commit ef5a226819
3 changed files with 93 additions and 0 deletions

View File

@@ -523,6 +523,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
kv_cache_layout="HND",
block_size=self.block_size,
ssm_sizes=(0, 0),
attn_backend_name=self.backend_name,
),
remote_tp_rank=remote_tp_rank,
remote_tp_size=remote_tp_size,
@@ -972,6 +973,7 @@ class TestNixlHandshake:
kv_cache_layout=mismatched_layout,
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
)
with pytest.raises(RuntimeError):
@@ -1028,6 +1030,7 @@ class TestNixlHandshake:
kv_cache_layout="HND",
block_size=worker.block_size,
ssm_sizes=(0, 0),
attn_backend_name=worker.backend_name,
)
# We don't check layout for homogeneous TP and MLA for now, as the
@@ -2347,6 +2350,7 @@ def test_compatibility_hash_validation(
kv_cache_layout="HND",
block_size=prefill_block_size,
ssm_sizes=(0, 0),
attn_backend_name=decode_worker.backend_name,
)
handshake_payload = NixlHandshakePayload(
compatibility_hash=remote_hash,