[PD][HeteroArch]Fix accuracy issue with CPU_ATTN as Decoder and Flash_ATTN as prefiller (#38935)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -523,6 +523,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
kv_cache_layout="HND",
|
||||
block_size=self.block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=self.backend_name,
|
||||
),
|
||||
remote_tp_rank=remote_tp_rank,
|
||||
remote_tp_size=remote_tp_size,
|
||||
@@ -972,6 +973,7 @@ class TestNixlHandshake:
|
||||
kv_cache_layout=mismatched_layout,
|
||||
block_size=worker.block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=worker.backend_name,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
@@ -1028,6 +1030,7 @@ class TestNixlHandshake:
|
||||
kv_cache_layout="HND",
|
||||
block_size=worker.block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=worker.backend_name,
|
||||
)
|
||||
|
||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||
@@ -2347,6 +2350,7 @@ def test_compatibility_hash_validation(
|
||||
kv_cache_layout="HND",
|
||||
block_size=prefill_block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=decode_worker.backend_name,
|
||||
)
|
||||
handshake_payload = NixlHandshakePayload(
|
||||
compatibility_hash=remote_hash,
|
||||
|
||||
Reference in New Issue
Block a user