Revert "Enable Cross layers KV cache layout at NIXL Connector (#30207)" (#33241)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Kevin H. Luu <khluu000@gmail.com>
(cherry picked from commit 2e8de86777)
This commit is contained in:
Or Ozeri
2026-01-28 14:36:00 +02:00
committed by khluu
parent 5f7f9ea884
commit fe18ce4d3f
5 changed files with 88 additions and 307 deletions

View File

@@ -184,15 +184,6 @@ Support use case: Prefill with 'HND' and decode with 'NHD' with experimental con
--kv-transfer-config '{..., "enable_permute_local_kv":"True"}'
```
### Cross layers blocks
By default, this feature is disabled. On attention backends that support this feature, each logical block is contiguous in physical memory. This reduces the number of buffers that need to be transferred.
To enable this feature:
```bash
--kv-transfer-config '{..., "kv_connector_extra_config": {"enable_cross_layers_blocks": "True"}}'
```
## Example Scripts/Code
Refer to these example scripts in the vLLM repository:

View File

@@ -34,18 +34,11 @@ else
KV_CONFIG_HETERO_LAYOUT=''
fi
CROSS_LAYERS_BLOCKS=${CROSS_LAYERS_BLOCKS:-"False"} # Default to non cross layers
if [[ "$CROSS_LAYERS_BLOCKS" == "True" ]]; then
KV_EXTRA_CONFIG=',"kv_connector_extra_config":{"cross_layers_blocks": "True"}'
else
KV_EXTRA_CONFIG=''
fi
# Build the kv-transfer-config once
if [[ "$KV_BUFFER_DEVICE" == "cuda" ]]; then
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}'}'
KV_CONFIG='{"kv_connector":"NixlConnector","kv_role":"kv_both"'${KV_CONFIG_HETERO_LAYOUT}'}'
else
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}${KV_EXTRA_CONFIG}"}"
KV_CONFIG="{\"kv_connector\":\"NixlConnector\",\"kv_role\":\"kv_both\",\"kv_buffer_device\":\"$KV_BUFFER_DEVICE\""${KV_CONFIG_HETERO_LAYOUT}"}"
fi
# Models to run

View File

@@ -18,12 +18,8 @@ import ray
import torch
from vllm import LLM
from vllm.config import KVTransferConfig, set_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.utils import (
KVOutputAggregator,
TpKVTopology,
get_current_attn_backend,
)
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.kv_transfer.kv_connector.v1 import nixl_connector
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorStats
from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
@@ -52,11 +48,8 @@ from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheTensor
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import RequestStatus
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
from vllm.v1.worker.utils import AttentionGroup
from .utils import create_request, create_scheduler, create_vllm_config
@@ -373,7 +366,6 @@ def test_kv_transfer_handshake(dist_init):
# Decode connector will be able to create handshake with the prefill connector.
decode_connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
decode_connector.register_kv_caches(kv_caches)
# Here we are testing the retrieval of NIXLAgentMetadata.
# Knowing the implementation detail, we override the add_remote_agent
@@ -410,23 +402,6 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
self.kv_cache_layout = kv_cache_layout
# Mock register_kv_caches attribute needed for tests that do not call it.
self.src_xfer_handles_by_block_size = {self.block_size: 1}
test_shape = self.attn_backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=test_shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
def _nixl_handshake(
self, host: str, port: int, remote_tp_size: int, expected_engine_id: str
@@ -1395,7 +1370,6 @@ def _run_abort_timeout_test(llm: LLM, timeout: int):
),
),
"TRITON_ATTN",
"FLASHINFER",
],
)
def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
@@ -1412,11 +1386,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
vllm_config = create_vllm_config(attention_backend=attn_backend)
# Enable cross layers blocks
vllm_config.kv_transfer_config.kv_connector_extra_config[
"enable_cross_layers_blocks"
] = True
# Import the appropriate backend based on the parameter
if attn_backend == "FLASH_ATTN":
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
@@ -1426,11 +1395,49 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
from vllm.v1.attention.backends.rocm_attn import RocmAttentionBackend
backend_cls = RocmAttentionBackend
else: # TRITON
else: # TRITON_ATTN
from vllm.v1.attention.backends.triton_attn import TritonAttentionBackend
backend_cls = TritonAttentionBackend
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = shared_tensor.element_size() * shared_tensor.numel()
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
nixl_module = "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
with (
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
@@ -1459,107 +1466,6 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
# Reassure the shutdown() check that the thread is terminated
mock_thread.return_value.is_alive.return_value = False
expected_tensor_size: int
expected_base_addrs: list[int]
expected_num_entries: int
kv_caches: dict[str, torch.Tensor]
if connector.prefer_cross_layer_blocks:
num_layers = 32
block_size = 16
num_blocks = 8
kv_cache_spec = AttentionSpec(
block_size=block_size,
num_kv_heads=4,
head_size=64,
dtype=torch.bfloat16,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[
KVCacheTensor(
size=kv_cache_spec.page_size_bytes * num_blocks,
shared_by=["dummy-layer"],
)
for i in range(num_layers)
],
# allocate_uniform_kv_caches does not use this
kv_cache_groups=[],
)
with set_current_vllm_config(vllm_config):
_, cross_layers_kv_cache, _ = (
KVConnectorModelRunnerMixin.allocate_uniform_kv_caches(
kv_cache_config=kv_cache_config,
attn_groups=[
[
AttentionGroup(
backend=backend_cls,
layer_names=[],
kv_cache_spec=kv_cache_spec,
kv_cache_group_id=0,
)
]
],
cache_dtype=torch.bfloat16,
device=torch.cuda.current_device(),
kernel_block_sizes=[block_size],
)
)
# Store tensor info for validation
expected_tensor_size = (
cross_layers_kv_cache.element_size() * cross_layers_kv_cache.numel()
)
expected_base_addrs = [
cross_layers_kv_cache.data_ptr(),
]
expected_num_entries = 1
expected_blocks_count = 8
kv_caches = {"all-layers": cross_layers_kv_cache}
else:
# Create test kv cache tensors using proper backend shape
kv_cache_shape = backend_cls.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
# Store tensor info for validation
test_shape = backend_cls.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
is_blocks_first = len(test_shape) == 5 and test_shape[0] == 1
if is_blocks_first:
expected_tensor_size = (
shared_tensor.element_size() * shared_tensor.numel()
)
expected_base_addrs = [
shared_tensor.data_ptr(),
unique_tensor.data_ptr(),
]
expected_num_entries = 2
else:
expected_tensor_size = (
shared_tensor[0].element_size() * shared_tensor[0].numel()
)
expected_base_addrs = [
shared_tensor[0].data_ptr(),
shared_tensor[1].data_ptr(),
unique_tensor[0].data_ptr(),
unique_tensor[1].data_ptr(),
]
expected_num_entries = 4
expected_blocks_count = 8
# Execute register_kv_caches
connector.register_kv_caches(kv_caches)
@@ -1583,19 +1489,16 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
blocks_data, _ = mock_wrapper_instance.get_xfer_descs.call_args[0]
# Validate blocks_data structure and size
expected_blocks_count = 8
assert len(blocks_data) == expected_blocks_count, (
f"Expected {expected_blocks_count} blocks, got {len(blocks_data)}"
)
if connector.prefer_cross_layer_blocks:
num_blocks = 8
expected_block_len = expected_tensor_size // num_blocks
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
num_blocks = 2
if is_blocks_first:
expected_block_len = expected_tensor_size // num_blocks // 2
else:
expected_block_len = expected_tensor_size // num_blocks
expected_block_len = expected_tensor_size // num_blocks
for i, block_entry in enumerate(blocks_data):
block_start_addr, block_len, tp_rank = block_entry
@@ -2146,17 +2049,6 @@ def test_compatibility_hash_validation(
)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
kv_cache_shape = decode_worker.attn_backend.get_kv_cache_shape(
num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
)
shared_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
unique_tensor = torch.zeros(*kv_cache_shape, dtype=torch.float16)
kv_caches = {
"layer0": shared_tensor,
"layer1": unique_tensor,
"layer2": shared_tensor,
}
decode_connector.register_kv_caches(kv_caches)
remote_config_params: dict[str, Any] = {
"model": "facebook/opt-125m",
@@ -2179,9 +2071,7 @@ def test_compatibility_hash_validation(
)
)
remote_hash = compute_nixl_compatibility_hash(
remote_vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
remote_vllm_config, decode_worker.backend_name
)
prefill_block_size = config_overrides.get("block_size", 16)
@@ -2260,27 +2150,6 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
decode_connector = NixlConnector(local_vllm_config, KVConnectorRole.WORKER)
decode_worker = decode_connector.connector_worker
backend = get_current_attn_backend(local_vllm_config)
test_shape = backend.get_kv_cache_shape(
num_blocks=1, block_size=16, num_kv_heads=1, head_size=1
)
decode_worker.kv_topo = TpKVTopology(
tp_rank=decode_worker.tp_rank,
engine_id=decode_worker.engine_id,
remote_tp_size=decode_worker._tp_size, # shared state
remote_block_size=decode_worker._block_size, # shared state
is_mla=decode_worker.use_mla,
total_num_kv_heads=decode_worker.model_config.get_total_num_kv_heads(),
attn_backend=backend,
tensor_shape=test_shape,
)
decode_worker.compat_hash = compute_nixl_compatibility_hash(
decode_worker.vllm_config,
decode_worker.backend_name,
decode_worker.kv_topo.cross_layers_blocks,
)
if error_scenario == "handshake_decode_error":
msg_bytes = b"this is not valid msgpack data"
elif error_scenario == "handshake_validation_error":

View File

@@ -316,7 +316,6 @@ class TpKVTopology:
attn_backend: type[AttentionBackend]
engine_id: EngineId
remote_block_size: dict[EngineId, int]
tensor_shape: torch.Size | None = None
def __post_init__(self):
# Figure out whether the first dimension of the cache is K/V
@@ -330,32 +329,6 @@ class TpKVTopology:
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
)
self._kv_heads_position: int | None = None
self._cross_layers_blocks = False
if self.tensor_shape is not None:
self._cross_layers_blocks = (
len(self.tensor_shape) == len(kv_cache_shape) + 1
)
if self._cross_layers_blocks:
# prepend layers dimension
kv_cache_shape = (80,) + kv_cache_shape
try:
kv_cache_stride_order = self.attn_backend.get_kv_cache_stride_order(
include_num_layers_dimension=self._cross_layers_blocks
)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(range(len(self.tensor_shape)))
# permute kv_cache_shape according to stride_order
kv_cache_shape = tuple(kv_cache_shape[i] for i in kv_cache_stride_order)
physical_block_size_position = kv_cache_shape.index(16)
assert physical_block_size_position is not None
self._physical_block_size_position = -(
len(kv_cache_shape) - physical_block_size_position
)
@property
def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first
@@ -363,9 +336,7 @@ class TpKVTopology:
@property
def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present).
return not (
self._cross_layers_blocks or self.is_mla or self.is_kv_layout_blocks_first
)
return not (self.is_mla or self.is_kv_layout_blocks_first)
@property
def tp_size(self) -> int:
@@ -375,14 +346,6 @@ class TpKVTopology:
def block_size(self) -> int:
return self.remote_block_size[self.engine_id]
@property
def cross_layers_blocks(self) -> bool:
return self._cross_layers_blocks
@property
def block_size_position(self) -> int:
return self._physical_block_size_position
def tp_ratio(
self,
remote_tp_size: int,

View File

@@ -54,7 +54,7 @@ from vllm.forward_context import ForwardContext
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
from vllm.v1.attention.backend import AttentionBackend, AttentionMetadata
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.utils import get_kv_cache_layout
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.block_table import BlockTable
@@ -173,7 +173,7 @@ class NixlHandshakePayload(KVConnectorHandshakeMetadata):
def compute_nixl_compatibility_hash(
vllm_config: VllmConfig, attn_backend_name: str, cross_layers_blocks: bool
vllm_config: VllmConfig, attn_backend_name: str
) -> str:
"""
Compute compatibility hash for NIXL KV transfer.
@@ -216,7 +216,6 @@ def compute_nixl_compatibility_hash(
# Attention backend and KV cache dtype affect memory layout
"attn_backend_name": attn_backend_name,
"cache_dtype": str(cache_config.cache_dtype),
"cross_layers_blocks": cross_layers_blocks,
}
compat_hash = hash_factors(factors)
@@ -299,20 +298,6 @@ class NixlConnectorMetadata(KVConnectorMetadata):
class NixlConnector(KVConnectorBase_V1):
@property
def prefer_cross_layer_blocks(self) -> bool:
backend = get_current_attn_backend(self._vllm_config)
if backend().get_name() not in (
"FLASH_ATTN",
"FLASHINFER",
):
# For now there is no benefit to run cross layers when backend
# does not support on HND
return False
extra_config = self.kv_transfer_config.kv_connector_extra_config
return bool(str(extra_config.get("enable_cross_layers_blocks", "False")))
def __init__(
self,
vllm_config: VllmConfig,
@@ -324,7 +309,6 @@ class NixlConnector(KVConnectorBase_V1):
assert vllm_config.kv_transfer_config is not None
assert vllm_config.kv_transfer_config.engine_id is not None
self.engine_id: EngineId = vllm_config.kv_transfer_config.engine_id
self.kv_transfer_config = vllm_config.kv_transfer_config
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: NixlConnectorScheduler | None = (
@@ -411,16 +395,6 @@ class NixlConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def register_cross_layers_kv_cache(
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
assert self.connector_worker is not None
cross_layer_name = "ALL_LAYERS"
kv_caches = {cross_layer_name: kv_cache}
self.connector_worker.register_kv_caches(kv_caches)
def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
assert self.connector_worker is not None
self.connector_worker.set_host_xfer_buffer_ops(copy_operation)
@@ -1002,17 +976,20 @@ class NixlConnectorWorker:
# Get the attention backend from the first layer
# NOTE (NickLucche) models with multiple backends are not supported yet
self.attn_backend = get_current_attn_backend(vllm_config)
backend = get_current_attn_backend(vllm_config)
self.backend_name = self.attn_backend.get_name()
self.backend_name = backend.get_name()
self.kv_cache_layout = get_kv_cache_layout()
self.host_buffer_kv_cache_layout = self.kv_cache_layout
logger.debug("Detected attention backend %s", self.backend_name)
logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
# lazy initialized in register_kv_caches
self.compat_hash: str | None = None
self.kv_topo: TpKVTopology | None = None
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name
)
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
)
self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
self._block_size: dict[EngineId, int] = {self.engine_id: self.block_size}
@@ -1021,11 +998,16 @@ class NixlConnectorWorker:
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
self.xfer_stats = NixlKVConnectorStats()
self._physical_blocks_per_logical_kv_block = 1
self.enforce_compat_hash = self.kv_transfer_config.get_from_extra_config(
"enforce_handshake_compat", True
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend,
)
self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake(
self,
@@ -1040,7 +1022,6 @@ class NixlConnectorWorker:
# Regardless, only handshake with the remote TP rank(s) that current
# local rank will read from. Note that With homogeneous TP,
# this happens to be the same single rank_i.
assert self.kv_topo is not None
p_remote_ranks = self.kv_topo.get_target_remote_ranks(remote_tp_size)
remote_rank_to_agent_name = {}
path = make_zmq_path("tcp", host, port)
@@ -1078,7 +1059,6 @@ class NixlConnectorWorker:
)
# Check compatibility hash BEFORE decoding agent metadata
assert self.compat_hash is not None
if (
self.enforce_compat_hash
and handshake_payload.compatibility_hash != self.compat_hash
@@ -1287,20 +1267,6 @@ class NixlConnectorWorker:
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
self.kv_topo = TpKVTopology(
tp_rank=self.tp_rank,
engine_id=self.engine_id,
remote_tp_size=self._tp_size, # shared state
remote_block_size=self._block_size, # shared state
is_mla=self.use_mla,
total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=self.attn_backend,
tensor_shape=next(iter(kv_caches.values())).shape,
)
self.compat_hash = compute_nixl_compatibility_hash(
self.vllm_config, self.backend_name, self.kv_topo.cross_layers_blocks
)
if self.use_host_buffer:
self.initialize_host_xfer_buffer(kv_caches=kv_caches)
assert len(self.host_xfer_buffers) == len(kv_caches), (
@@ -1335,21 +1301,29 @@ class NixlConnectorWorker:
# (roughly 8KB vs 5KB).
# Conversely for FlashInfer, K and V are registered in the same region
# to better exploit the memory layout (ie num_blocks is the first dim).
split_k_and_v = self.kv_topo.split_k_and_v
tensor_size_bytes = None
# TODO (NickLucche): Get kernel_block_size in a cleaner way
# NHD default "view" for non-MLA cache
if self.device_type == "cpu":
block_size_position = -2
else:
block_size_position = -2 if self.use_mla else -3
# Enable different block lengths for different layers when MLA is used.
self.block_len_per_layer = list[int]()
self.slot_size_per_layer = list[int]() # HD bytes in kv terms
for layer_name, cache_or_caches in xfer_buffers.items():
cache_list = (
cache_or_caches if self.kv_topo.split_k_and_v else [cache_or_caches]
)
cache_list = cache_or_caches if split_k_and_v else [cache_or_caches]
for cache in cache_list:
base_addr = cache.data_ptr()
if base_addr in seen_base_addresses:
continue
kernel_block_size = cache.shape[self.kv_topo.block_size_position]
kernel_block_size = cache.shape[block_size_position]
if self.block_size != kernel_block_size:
logger.info_once(
"User-specified logical block size (%s) does not match"
@@ -1411,7 +1385,6 @@ class NixlConnectorWorker:
self.device_kv_caches = kv_caches
self.dst_num_blocks[self.engine_id] = self.num_blocks
if self.kv_topo.is_kv_layout_blocks_first:
for i in range(len(self.slot_size_per_layer)):
assert self.slot_size_per_layer[i] % 2 == 0
@@ -1467,7 +1440,6 @@ class NixlConnectorWorker:
block_size=self.block_size,
)
# Wrap metadata in payload with hash for defensive decoding
assert self.compat_hash is not None
encoder = msgspec.msgpack.Encoder()
self.xfer_handshake_metadata = NixlHandshakePayload(
compatibility_hash=self.compat_hash,
@@ -1489,8 +1461,6 @@ class NixlConnectorWorker:
register another local_xfer_handler using remote block len to ensure
data copy correctness.
"""
assert self.kv_topo is not None
block_size_ratio = self.block_size // block_size
blocks_data = []
for i, base_addr in enumerate(self.seen_base_addresses):
@@ -1603,7 +1573,6 @@ class NixlConnectorWorker:
# remote: | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|
# local origin:| 0| 1| 8| 12|
# local mapped:| 0| 1| 2| 3| 4| 5| 6| 7| 8| 9|10|11|12|13|14|15|
assert self.kv_topo is not None
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(engine_id)
if engine_id not in self.dst_num_blocks:
@@ -1731,10 +1700,7 @@ class NixlConnectorWorker:
"""
remote_engine_id = nixl_agent_meta.engine_id
assert (
self._tp_size[remote_engine_id] == remote_tp_size
and self.kv_topo is not None
)
assert self._tp_size[remote_engine_id] == remote_tp_size
tp_ratio = self.kv_topo.tp_ratio_from_engine_id(remote_engine_id)
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(
@@ -1871,7 +1837,6 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0:
return
assert block_size_ratio >= 1, "Only nP < nD supported currently."
assert self.kv_topo is not None
if self.enable_permute_local_kv and block_size_ratio > 1:
logger.debug(
"Post-processing device kv cache on receive by converting "
@@ -1891,7 +1856,7 @@ class NixlConnectorWorker:
block_size_ratio,
)
split_k_and_v = self.kv_topo.split_k_and_v
split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
for block_ids in block_ids_list:
indices = torch.tensor(block_ids, device=self.device_type, dtype=torch.long)
@@ -1916,7 +1881,6 @@ class NixlConnectorWorker:
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
assert self.kv_topo is not None
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
@@ -1986,7 +1950,6 @@ class NixlConnectorWorker:
are reading from the same producer (heterogeneous TP scenario), wait
for all consumers to be done pulling.
"""
assert self.kv_topo is not None
notified_req_ids: set[str] = set()
for notifs in self.nixl_wrapper.get_new_notifs().values():
for notif in notifs:
@@ -2146,7 +2109,7 @@ class NixlConnectorWorker:
self._reqs_to_send[req_id] = expiration_time
def _read_blocks_for_req(self, req_id: str, meta: ReqMeta):
assert meta.remote is not None and self.kv_topo is not None
assert meta.remote is not None
remote_ranks = self.kv_topo.get_target_remote_ranks_from_engine_id(
meta.remote.engine_id
)
@@ -2215,7 +2178,10 @@ class NixlConnectorWorker:
local_xfer_side_handle: int,
remote_xfer_side_handle: int,
):
assert self.kv_topo is not None
"""
Post a READ point-to-point xfer request from a single local worker to
a single remote worker.
"""
block_size_ratio = self.kv_topo.block_size_ratio_from_engine_id(dst_engine_id)
if block_size_ratio > 1:
local_block_ids = self.get_mapped_blocks(
@@ -2448,7 +2414,6 @@ class NixlConnectorWorker:
For FlashInfer, this is half the length of the whole block, as K and V
share the same region.
"""
assert self.kv_topo is not None
if self.kv_topo.is_kv_layout_blocks_first:
# For indexing only half (either just the K or V part).
block_len = self.block_len_per_layer[layer_idx] // 2