[PD][HeteroArch]Fix accuracy issue with CPU_ATTN as Decoder and Flash_ATTN as prefiller (#38935)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -523,6 +523,7 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
|
||||
kv_cache_layout="HND",
|
||||
block_size=self.block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=self.backend_name,
|
||||
),
|
||||
remote_tp_rank=remote_tp_rank,
|
||||
remote_tp_size=remote_tp_size,
|
||||
@@ -972,6 +973,7 @@ class TestNixlHandshake:
|
||||
kv_cache_layout=mismatched_layout,
|
||||
block_size=worker.block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=worker.backend_name,
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
@@ -1028,6 +1030,7 @@ class TestNixlHandshake:
|
||||
kv_cache_layout="HND",
|
||||
block_size=worker.block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=worker.backend_name,
|
||||
)
|
||||
|
||||
# We don't check layout for homogeneous TP and MLA for now, as the
|
||||
@@ -2347,6 +2350,7 @@ def test_compatibility_hash_validation(
|
||||
kv_cache_layout="HND",
|
||||
block_size=prefill_block_size,
|
||||
ssm_sizes=(0, 0),
|
||||
attn_backend_name=decode_worker.backend_name,
|
||||
)
|
||||
handshake_payload = NixlHandshakePayload(
|
||||
compatibility_hash=remote_hash,
|
||||
|
||||
@@ -173,6 +173,7 @@ class NixlAgentMetadata:
|
||||
kv_cache_layout: str
|
||||
block_size: int
|
||||
ssm_sizes: tuple[int, int]
|
||||
attn_backend_name: str
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1116,6 +1117,7 @@ class NixlConnectorWorker:
|
||||
|
||||
self.num_blocks = kv_cache_config.num_blocks
|
||||
self.enable_permute_local_kv = False
|
||||
self.enable_heterogeneous_attn_post_process = False
|
||||
|
||||
# KV Caches and nixl tracking data.
|
||||
self.device_type = current_platform.device_type
|
||||
@@ -1776,6 +1778,7 @@ class NixlConnectorWorker:
|
||||
else self.host_buffer_kv_cache_layout,
|
||||
block_size=self.block_size,
|
||||
ssm_sizes=self._mamba_ssm_size,
|
||||
attn_backend_name=self.backend_name,
|
||||
)
|
||||
# Wrap metadata in payload with hash for defensive decoding
|
||||
assert self.compat_hash is not None
|
||||
@@ -2369,6 +2372,21 @@ class NixlConnectorWorker:
|
||||
"Or enable experimental feature to use HND to NHD support by "
|
||||
"setting 'enable_permute_local_kv'=True in --kv-transfer-config."
|
||||
)
|
||||
# if remote_agent used attn is not same as local,
|
||||
# hint heterogenuous attn post process
|
||||
if (
|
||||
nixl_agent_meta.attn_backend_name != self.backend_name
|
||||
and self.backend_name in ["CPU_ATTN"]
|
||||
):
|
||||
if self._is_hma_required:
|
||||
raise RuntimeError(
|
||||
"heterogeneous attn post process is not supported with HMA"
|
||||
)
|
||||
logger.info(
|
||||
"[Experimental] CPU_ATTN backend is used, "
|
||||
"hint heterogeneous attn post process"
|
||||
)
|
||||
self.enable_heterogeneous_attn_post_process = True
|
||||
|
||||
# Heterogeneous TP requires head-splitting, which only works with
|
||||
# HND layout. MLA and replicated-KV cases don't split on heads.
|
||||
@@ -2542,6 +2560,28 @@ class NixlConnectorWorker:
|
||||
cache, indices, block_size_ratio
|
||||
)
|
||||
|
||||
def post_process_device_kv_on_receive_heterogeneous_attn(
|
||||
self, block_ids: list[int]
|
||||
):
|
||||
"""
|
||||
Post process device kv cache after receiving from remote
|
||||
for heterogeneous attention.
|
||||
"""
|
||||
assert self.enable_heterogeneous_attn_post_process
|
||||
|
||||
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
|
||||
|
||||
for _, cache_or_caches in self.device_kv_caches.items():
|
||||
blocks_to_update = cache_or_caches.index_select(1, indices)
|
||||
current_platform.pack_kv_cache(
|
||||
key=blocks_to_update[0],
|
||||
value=blocks_to_update[1],
|
||||
key_cache=cache_or_caches[0],
|
||||
value_cache=cache_or_caches[1],
|
||||
block_ids=block_ids,
|
||||
indices=indices,
|
||||
)
|
||||
|
||||
def get_finished(self) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Get requests that are done sending or recving on this specific worker.
|
||||
@@ -2566,6 +2606,7 @@ class NixlConnectorWorker:
|
||||
)
|
||||
|
||||
block_ids_for_blocksize_post_process = defaultdict(list)
|
||||
block_ids_for_heterogeneous_attn_post_process = list[list[int]]()
|
||||
for req_id in done_recving:
|
||||
# clean up metadata for completed requests
|
||||
meta = self._recving_metadata.pop(req_id, None)
|
||||
@@ -2585,12 +2626,20 @@ class NixlConnectorWorker:
|
||||
block_ids_for_blocksize_post_process[block_size_ratio].append(
|
||||
meta.local_physical_block_ids[0]
|
||||
)
|
||||
# post processing for heterogeneous attention
|
||||
if self.enable_heterogeneous_attn_post_process:
|
||||
block_ids_for_heterogeneous_attn_post_process.append(
|
||||
meta.local_physical_block_ids[0]
|
||||
)
|
||||
for (
|
||||
block_size_ratio,
|
||||
block_ids_list,
|
||||
) in block_ids_for_blocksize_post_process.items():
|
||||
self.post_process_device_kv_on_receive(block_size_ratio, block_ids_list)
|
||||
|
||||
for block_ids in block_ids_for_heterogeneous_attn_post_process:
|
||||
self.post_process_device_kv_on_receive_heterogeneous_attn(block_ids)
|
||||
|
||||
# Handle timeout to avoid stranding blocks on remote.
|
||||
now = time.perf_counter()
|
||||
while self._reqs_to_send:
|
||||
|
||||
@@ -520,3 +520,43 @@ class CpuPlatform(Platform):
|
||||
import vllm._C # noqa: F401
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C: %r", e)
|
||||
|
||||
@classmethod
|
||||
def pack_kv_cache(
|
||||
cls,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
block_ids: list[int],
|
||||
indices: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Rewrite the kv cache shape for the current platform.
|
||||
"""
|
||||
# Import lazily: cpu_attn pulls in _custom_ops, which needs a fully
|
||||
# initialized vllm.platforms (avoid circular import while CpuPlatform loads).
|
||||
from vllm._custom_ops import cpu_attn_reshape_and_cache
|
||||
from vllm.v1.attention.backends.cpu_attn import _get_attn_isa
|
||||
|
||||
dtype = key.dtype
|
||||
# For CPU_ATTN, the shape is [N, num_kv_heads, block_size, head_size]
|
||||
_, _, block_size, head_size = key_cache.shape
|
||||
key = key.permute(0, 2, 1, 3).flatten(0, 1)
|
||||
value = value.permute(0, 2, 1, 3).flatten(0, 1)
|
||||
|
||||
isa = _get_attn_isa(dtype, block_size, head_size)
|
||||
block_offsets = torch.arange(block_size, device="cpu", dtype=torch.long)
|
||||
num_blocks = len(block_ids)
|
||||
slot_mapping = (
|
||||
block_offsets.reshape(1, block_size)
|
||||
+ indices.reshape(num_blocks, 1) * block_size
|
||||
).flatten()
|
||||
cpu_attn_reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping,
|
||||
isa,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user