|
|
|
|
@@ -32,6 +32,7 @@ if TYPE_CHECKING:
|
|
|
|
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
|
|
|
|
from vllm.v1.request import Request
|
|
|
|
|
|
|
|
|
|
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
|
|
|
|
GET_META_MSG = b"get_meta_msg"
|
|
|
|
|
|
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
@@ -54,6 +55,8 @@ class NixlAgentMetadata(
|
|
|
|
|
agent_metadata: bytes
|
|
|
|
|
kv_caches_base_addr: list[int]
|
|
|
|
|
num_blocks: int
|
|
|
|
|
tp_size: int
|
|
|
|
|
block_len: int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
@@ -331,10 +334,14 @@ class NixlConnectorWorker:
|
|
|
|
|
logger.info("Initializing NIXL wrapper")
|
|
|
|
|
logger.info("Initializing NIXL worker %s", engine_id)
|
|
|
|
|
|
|
|
|
|
# Config.
|
|
|
|
|
self.vllm_config = vllm_config
|
|
|
|
|
self.block_size = vllm_config.cache_config.block_size
|
|
|
|
|
|
|
|
|
|
# Agent.
|
|
|
|
|
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
|
|
|
|
|
# Map of engine_id -> agent_name.
|
|
|
|
|
self._remote_agents: dict[str, str] = {}
|
|
|
|
|
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
|
|
|
|
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
|
|
|
|
|
|
|
|
|
|
# NIXL handshake port.
|
|
|
|
|
# NOTE(rob): Within a DP group, each DP rank gets its own
|
|
|
|
|
@@ -354,7 +361,8 @@ class NixlConnectorWorker:
|
|
|
|
|
# KV Caches and nixl tracking data.
|
|
|
|
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
|
|
|
|
|
|
|
|
|
# Map of engine_id -> kv_caches_base_addr
|
|
|
|
|
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
|
|
|
|
|
# rank will still only pull from a single remote TP worker.
|
|
|
|
|
self.kv_caches_base_addr: dict[str, list[int]] = {}
|
|
|
|
|
|
|
|
|
|
# Number of NIXL regions. Currently one region per cache
|
|
|
|
|
@@ -362,19 +370,19 @@ class NixlConnectorWorker:
|
|
|
|
|
self.num_regions = 0
|
|
|
|
|
self.num_layers = 0
|
|
|
|
|
|
|
|
|
|
# nixl_prepped_dlist_handle (int).
|
|
|
|
|
# nixl_prepped_dlist_handle.
|
|
|
|
|
self.src_xfer_side_handle: int = 0
|
|
|
|
|
# Map of engine_id -> nixl_prepped_dlist_handle (int).
|
|
|
|
|
self.dst_xfer_side_handles: dict[str, int] = {}
|
|
|
|
|
|
|
|
|
|
# Map of engine_id -> num_blocks.
|
|
|
|
|
# Map of engine_id -> num_blocks. All ranks in the same deployment will
|
|
|
|
|
# have the same number of blocks.
|
|
|
|
|
self.dst_num_blocks: dict[str, int] = {}
|
|
|
|
|
self._registered_descs: list[Any] = []
|
|
|
|
|
|
|
|
|
|
# In progress transfers.
|
|
|
|
|
# [req_id -> list[handle]]
|
|
|
|
|
self._recving_transfers: defaultdict[str, list[Any]] = defaultdict(
|
|
|
|
|
list[Any])
|
|
|
|
|
self._recving_transfers = defaultdict[str, list[Transfer]](list)
|
|
|
|
|
|
|
|
|
|
# Complete transfer tracker. Used by the rank 0 to track finished
|
|
|
|
|
# transactions on ranks 1 to N-1.
|
|
|
|
|
@@ -395,6 +403,11 @@ class NixlConnectorWorker:
|
|
|
|
|
# List of block window sizes for each layer for local attention
|
|
|
|
|
self.block_window_per_layer: list[Optional[int]] = []
|
|
|
|
|
|
|
|
|
|
self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
|
|
|
|
|
# With heterogeneous TP, P must wait for all assigned D TP workers to
|
|
|
|
|
# finish reading before safely freeing the blocks.
|
|
|
|
|
self.consumer_notification_counts_by_req = defaultdict[str, int](int)
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
|
|
|
|
|
ready_event: threading.Event, base_port: int,
|
|
|
|
|
@@ -426,27 +439,44 @@ class NixlConnectorWorker:
|
|
|
|
|
"""Do a NIXL handshake with a remote instance."""
|
|
|
|
|
|
|
|
|
|
start_time = time.perf_counter()
|
|
|
|
|
# NOTE(rob): we need each tp_rank to have a unique port.
|
|
|
|
|
# This is a hack to keep us moving. We will switch when
|
|
|
|
|
# we switch to HTTP-based NIXL metadata exchange.
|
|
|
|
|
path = make_zmq_path("tcp", host, port + self.tp_rank)
|
|
|
|
|
logger.debug("Querying metadata on path: %s", path)
|
|
|
|
|
with zmq_ctx(zmq.REQ, path) as sock:
|
|
|
|
|
|
|
|
|
|
# NOTE(rob): we need each rank to have a unique port. This is
|
|
|
|
|
# a hack to keep us moving. We will switch when moving to etcd
|
|
|
|
|
# or where we have a single ZMQ socket in the scheduler.
|
|
|
|
|
|
|
|
|
|
def handshake(path: str, rank: int) -> NixlAgentMetadata:
|
|
|
|
|
# Send query for the request.
|
|
|
|
|
sock.send(GET_META_MSG)
|
|
|
|
|
metadata_bytes = sock.recv()
|
|
|
|
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
|
|
|
|
metadata = decoder.decode(metadata_bytes)
|
|
|
|
|
got_metadata_time = time.perf_counter()
|
|
|
|
|
with zmq_ctx(zmq.REQ, path) as sock:
|
|
|
|
|
sock.send(GET_META_MSG)
|
|
|
|
|
metadata_bytes = sock.recv()
|
|
|
|
|
decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
|
|
|
|
|
metadata = decoder.decode(metadata_bytes)
|
|
|
|
|
got_metadata_time = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
# Register Remote agent.
|
|
|
|
|
self.add_remote_agent(metadata)
|
|
|
|
|
setup_agent_time = time.perf_counter()
|
|
|
|
|
# Register Remote agent.
|
|
|
|
|
self.add_remote_agent(metadata, rank)
|
|
|
|
|
setup_agent_time = time.perf_counter()
|
|
|
|
|
|
|
|
|
|
logger.debug("NIXL handshake: get metadata took: %s",
|
|
|
|
|
got_metadata_time - start_time)
|
|
|
|
|
logger.debug("NIXL handshake: add agent took: %s",
|
|
|
|
|
setup_agent_time - got_metadata_time)
|
|
|
|
|
logger.debug("NIXL handshake: get metadata took: %s",
|
|
|
|
|
got_metadata_time - start_time)
|
|
|
|
|
logger.debug("NIXL handshake: add agent took: %s",
|
|
|
|
|
setup_agent_time - got_metadata_time)
|
|
|
|
|
return metadata
|
|
|
|
|
|
|
|
|
|
# Handshake with remote agent-rank0 first to get the tp_size of remote
|
|
|
|
|
path = make_zmq_path("tcp", host, port)
|
|
|
|
|
logger.debug("Querying master rank metadata on path: %s", path)
|
|
|
|
|
metadata = handshake(path, 0)
|
|
|
|
|
|
|
|
|
|
# Handshake only with the remote TP rank the current local rank will
|
|
|
|
|
# pull from. With homogeneous TP it happens to be the same rank_i.
|
|
|
|
|
tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
|
|
|
|
|
p_remote_rank = self.tp_rank // tp_ratio
|
|
|
|
|
if p_remote_rank > 0:
|
|
|
|
|
path = make_zmq_path("tcp", host, port + p_remote_rank)
|
|
|
|
|
logger.debug("Querying metadata on path: %s at remote rank %s",
|
|
|
|
|
path, p_remote_rank)
|
|
|
|
|
_ = handshake(path, p_remote_rank)
|
|
|
|
|
|
|
|
|
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
|
|
|
|
"""Register the KV Cache data in nixl."""
|
|
|
|
|
@@ -455,24 +485,34 @@ class NixlConnectorWorker:
|
|
|
|
|
kv_elem_size = first_kv_cache.element_size()
|
|
|
|
|
|
|
|
|
|
# TODO(tms): Find a more robust way to detect and handle MLA
|
|
|
|
|
use_mla = len(first_kv_cache.shape) == 3
|
|
|
|
|
if use_mla:
|
|
|
|
|
self.use_mla = len(first_kv_cache.shape) == 3
|
|
|
|
|
# NOTE (NickLucche) To move blocks efficiently with NIXL, the expected
|
|
|
|
|
# KV memory layout is HND, as opposed to the default NHD. Note that it
|
|
|
|
|
# will only affect the strides. For MLA instead, we require no
|
|
|
|
|
# such thing and resort to the standard layout.
|
|
|
|
|
if self.use_mla:
|
|
|
|
|
# MLA case.
|
|
|
|
|
self.num_blocks = first_kv_cache.shape[0]
|
|
|
|
|
block_rank = 2 # [block_size, latent_dim]
|
|
|
|
|
block_shape = first_kv_cache.shape[-block_rank:]
|
|
|
|
|
block_size, kv_latent_dim = block_shape
|
|
|
|
|
self.slot_size_bytes = kv_elem_size * kv_latent_dim
|
|
|
|
|
else:
|
|
|
|
|
# [2 (k and v), num_blocks, ...]
|
|
|
|
|
# [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
|
|
|
|
|
self.num_blocks = first_kv_cache.shape[1]
|
|
|
|
|
block_rank = 3 # [block_size, kv_heads, head_dim]
|
|
|
|
|
block_shape = first_kv_cache.shape[-block_rank:]
|
|
|
|
|
|
|
|
|
|
block_size, n_kv_heads, head_dim = block_shape
|
|
|
|
|
# head size in bytes.
|
|
|
|
|
self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
|
|
|
|
|
assert block_size == self.block_size
|
|
|
|
|
# TODO(tms): self.block_len needs to be per-layer for sliding window,
|
|
|
|
|
# hybrid attn, etc
|
|
|
|
|
# block size in bytes
|
|
|
|
|
self.block_len = kv_elem_size * math.prod(block_shape)
|
|
|
|
|
|
|
|
|
|
logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
|
|
|
|
|
first_kv_cache.shape)
|
|
|
|
|
logger.debug("Registering KV_Caches. use_mla: %s, shape %s",
|
|
|
|
|
self.use_mla, first_kv_cache.shape)
|
|
|
|
|
logger.debug("num_blocks: %s, block_shape: %s", self.num_blocks,
|
|
|
|
|
block_shape)
|
|
|
|
|
logger.debug("Per layer kv cache size: %s", first_kv_cache.shape)
|
|
|
|
|
@@ -489,7 +529,7 @@ class NixlConnectorWorker:
|
|
|
|
|
# (roughly 8KB vs 5KB).
|
|
|
|
|
for cache_or_caches in kv_caches.values():
|
|
|
|
|
# Normalize to always be a list of caches
|
|
|
|
|
cache_list = [cache_or_caches] if use_mla else cache_or_caches
|
|
|
|
|
cache_list = [cache_or_caches] if self.use_mla else cache_or_caches
|
|
|
|
|
for cache in cache_list:
|
|
|
|
|
base_addr = cache.data_ptr()
|
|
|
|
|
region_len = self.num_blocks * self.block_len
|
|
|
|
|
@@ -524,16 +564,37 @@ class NixlConnectorWorker:
|
|
|
|
|
logger.debug("Registering descs: %s", caches_data)
|
|
|
|
|
self.nixl_wrapper.register_memory(descs)
|
|
|
|
|
logger.debug("Done registering descs")
|
|
|
|
|
|
|
|
|
|
self._registered_descs.append(descs)
|
|
|
|
|
|
|
|
|
|
# Register local/src descr for NIXL xfer.
|
|
|
|
|
blocks_data = []
|
|
|
|
|
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
|
|
|
|
# NOTE With heter-TP, more blocks are prepared than what are
|
|
|
|
|
# needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
|
|
|
|
|
# could create fewer, but then _get_block_descs_ids needs to
|
|
|
|
|
# select agent_meta.num_blocks instead of self.num_blocks for
|
|
|
|
|
# local descr, and that makes handling regular flow less clean.
|
|
|
|
|
for block_id in range(self.num_blocks):
|
|
|
|
|
block_offset = block_id * self.block_len
|
|
|
|
|
addr = base_addr + block_offset
|
|
|
|
|
# (addr, len, device id)
|
|
|
|
|
blocks_data.append((addr, self.block_len, self.tp_rank))
|
|
|
|
|
logger.debug("Created %s blocks for src engine %s and rank %s",
|
|
|
|
|
len(blocks_data), self.engine_id, self.tp_rank)
|
|
|
|
|
|
|
|
|
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
|
|
|
|
# NIXL_INIT_AGENT to be used for preparations of local descs.
|
|
|
|
|
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
|
|
|
|
"NIXL_INIT_AGENT", descs)
|
|
|
|
|
|
|
|
|
|
# After KV Caches registered, listen for new connections.
|
|
|
|
|
metadata = NixlAgentMetadata(
|
|
|
|
|
engine_id=self.engine_id,
|
|
|
|
|
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
|
|
|
|
|
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
|
|
|
|
|
num_blocks=self.num_blocks,
|
|
|
|
|
)
|
|
|
|
|
tp_size=self.world_size,
|
|
|
|
|
block_len=self.block_len)
|
|
|
|
|
ready_event = threading.Event()
|
|
|
|
|
self._nixl_handshake_listener_t = threading.Thread(
|
|
|
|
|
target=self._nixl_handshake_listener,
|
|
|
|
|
@@ -543,50 +604,123 @@ class NixlConnectorWorker:
|
|
|
|
|
self._nixl_handshake_listener_t.start()
|
|
|
|
|
ready_event.wait()
|
|
|
|
|
|
|
|
|
|
def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
|
|
|
|
|
def add_remote_agent(self,
|
|
|
|
|
nixl_agent_meta: NixlAgentMetadata,
|
|
|
|
|
remote_tp_rank: int = 0):
|
|
|
|
|
"""
|
|
|
|
|
Add the remote NIXL agent and prepare the descriptors for reading cache
|
|
|
|
|
blocks from remote.
|
|
|
|
|
|
|
|
|
|
In particular, handle both homogeneous and heterogeneous TP. The former
|
|
|
|
|
requires local rank_i to read from remote rank_i.
|
|
|
|
|
The latter, assuming D.world_size > P.world_size, requires that two or
|
|
|
|
|
more local TP workers share the xfer from a single remote TP worker.
|
|
|
|
|
|
|
|
|
|
Here's an example:
|
|
|
|
|
|
|
|
|
|
rank_offset p_remote_tp_rank
|
|
|
|
|
(kv split no)
|
|
|
|
|
--------------------------------
|
|
|
|
|
0 0 Worker0 ---- 1st half of KV ----> Worker0 [ KV Cache ]
|
|
|
|
|
/
|
|
|
|
|
1 0 Worker1 ---- 2nd half of KV -----/
|
|
|
|
|
|
|
|
|
|
0 1 Worker2 ---- 1st half of KV ----> Worker1 [ KV Cache ]
|
|
|
|
|
/
|
|
|
|
|
1 1 Worker3 ---- 2nd half of KV -----/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Decoder TP workers Prefill TP workers
|
|
|
|
|
(world_size=4) (world_size=2)
|
|
|
|
|
tp_ratio = 4 // 2 = 2
|
|
|
|
|
|
|
|
|
|
Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, kv_heads, block_size, head_dim]
|
|
|
|
|
then D-Worker_j has [2, num_blocksD, kv_heads//tp_ratio, block_size, head_dim]. Mind the "HND" layout format.
|
|
|
|
|
Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the kv_heads//tp_ratio
|
|
|
|
|
first heads from all the slots of all the blocks. D-Worker1 will do the same, but reading the second split
|
|
|
|
|
along the kv_heads dimension, and so forth until "tp_ratio" D TP workers have pulled from P-Worker0.
|
|
|
|
|
|
|
|
|
|
Note that the above will also hold true for the homogeneous TP case, where tp_ratio evaluates to 1.
|
|
|
|
|
|
|
|
|
|
Regarding MLA case, the cache is replicated across TP workers so the rank_offset will just always be 0
|
|
|
|
|
so that the whole cache is shared by "tp_ratio" D TP workers.
|
|
|
|
|
""" # noqa: E501
|
|
|
|
|
engine_id = nixl_agent_meta.engine_id
|
|
|
|
|
assert engine_id != self.engine_id, "Conflict engine id found!"
|
|
|
|
|
if engine_id in self._remote_agents:
|
|
|
|
|
# TODO re-evaluate refreshing for scaling/recovery
|
|
|
|
|
if remote_tp_rank in self._remote_agents.get(engine_id, ()):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
|
|
|
|
|
nixl_agent_meta.agent_metadata)
|
|
|
|
|
self.kv_caches_base_addr[
|
|
|
|
|
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
|
|
|
|
if engine_id in self._tp_size:
|
|
|
|
|
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
|
|
|
|
|
else:
|
|
|
|
|
self._tp_size[engine_id] = nixl_agent_meta.tp_size
|
|
|
|
|
self._remote_agents[engine_id][
|
|
|
|
|
remote_tp_rank] = self.nixl_wrapper.add_remote_agent(
|
|
|
|
|
nixl_agent_meta.agent_metadata)
|
|
|
|
|
|
|
|
|
|
# Number of D TP workers reading from a single P TP worker. This is
|
|
|
|
|
# 1 when P and D `--tensor-parallel-size` match.
|
|
|
|
|
assert self._tp_size[self.engine_id] % self._tp_size[engine_id] == 0, \
|
|
|
|
|
"Local TP size must be divisible by remote TP size."
|
|
|
|
|
tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
|
|
|
|
|
assert tp_ratio > 0, "Decode TP cannot be smaller than"
|
|
|
|
|
" prefill TP"
|
|
|
|
|
if self.use_mla:
|
|
|
|
|
# With MLA the only difference is in the number of blocks.
|
|
|
|
|
remote_block_size = nixl_agent_meta.block_len / (
|
|
|
|
|
self.slot_size_bytes)
|
|
|
|
|
assert self.block_len == nixl_agent_meta.block_len
|
|
|
|
|
else:
|
|
|
|
|
remote_block_size = nixl_agent_meta.block_len / (
|
|
|
|
|
self.slot_size_bytes * tp_ratio)
|
|
|
|
|
|
|
|
|
|
assert nixl_agent_meta.block_len == self.block_len * tp_ratio, \
|
|
|
|
|
"Remote P worker KV layer cache must be of shape [2, N, \
|
|
|
|
|
local_kv_heads*tp_ratio, block_size, head_dim] and same dtype."
|
|
|
|
|
|
|
|
|
|
assert self.block_size == remote_block_size, "Remote P worker with \
|
|
|
|
|
different block size is not supported"
|
|
|
|
|
|
|
|
|
|
assert self.num_blocks >= nixl_agent_meta.num_blocks
|
|
|
|
|
|
|
|
|
|
# Create dst descs and xfer side handles. TP workers have same #blocks.
|
|
|
|
|
if engine_id in self.dst_num_blocks:
|
|
|
|
|
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
|
|
|
|
|
else:
|
|
|
|
|
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
|
|
|
|
|
|
|
|
|
# Create src descs and xfer side handles.
|
|
|
|
|
blocks_data = []
|
|
|
|
|
for base_addr in self.kv_caches_base_addr[self.engine_id]:
|
|
|
|
|
for block_id in range(self.num_blocks):
|
|
|
|
|
block_offset = block_id * self.block_len
|
|
|
|
|
# (addr, len, device id)
|
|
|
|
|
blocks_data.append(
|
|
|
|
|
(base_addr + block_offset, self.block_len, self.tp_rank))
|
|
|
|
|
logger.debug("Created %s blocks for src engine %s and tp_rank %s",
|
|
|
|
|
len(blocks_data), self.engine_id, self.tp_rank)
|
|
|
|
|
# With homogeneous TP, D pulls the whole kv cache from corresponding
|
|
|
|
|
# rank. With heterogeneous TP, prepare the descriptors by splitting the
|
|
|
|
|
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
|
|
|
|
|
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
|
|
|
|
|
p_remote_tp_rank = self.tp_rank // tp_ratio
|
|
|
|
|
# Only register the remote's descriptors if current rank pulls from it.
|
|
|
|
|
if p_remote_tp_rank == remote_tp_rank:
|
|
|
|
|
self.kv_caches_base_addr[
|
|
|
|
|
engine_id] = nixl_agent_meta.kv_caches_base_addr
|
|
|
|
|
rank_offset = self.tp_rank % tp_ratio * self.block_len \
|
|
|
|
|
if not self.use_mla else 0
|
|
|
|
|
# Register all remote blocks, but only the corresponding kv heads.
|
|
|
|
|
for base_addr in nixl_agent_meta.kv_caches_base_addr:
|
|
|
|
|
for block_id in range(nixl_agent_meta.num_blocks):
|
|
|
|
|
block_offset = block_id * nixl_agent_meta.block_len
|
|
|
|
|
# For each block, grab the heads chunk belonging to rank_i
|
|
|
|
|
# of size remote_nheads // tp_ratio, which correspond to
|
|
|
|
|
# self.block_len == remote_block_len//tp_ratio bytes.
|
|
|
|
|
addr = base_addr + block_offset + rank_offset
|
|
|
|
|
# (addr, len, device id)
|
|
|
|
|
blocks_data.append((addr, self.block_len, remote_tp_rank))
|
|
|
|
|
logger.debug(
|
|
|
|
|
"Created %s blocks for dst engine %s with remote rank %s and " \
|
|
|
|
|
"local rank %s",
|
|
|
|
|
len(blocks_data), engine_id, remote_tp_rank, self.tp_rank)
|
|
|
|
|
|
|
|
|
|
# Register with NIXL.
|
|
|
|
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
|
|
|
|
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
|
|
|
|
|
"NIXL_INIT_AGENT", descs)
|
|
|
|
|
|
|
|
|
|
# Create dst descs and xfer side handles.
|
|
|
|
|
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
|
|
|
|
|
blocks_data = []
|
|
|
|
|
for base_addr in self.kv_caches_base_addr[engine_id]:
|
|
|
|
|
for block_id in range(nixl_agent_meta.num_blocks):
|
|
|
|
|
block_offset = block_id * self.block_len
|
|
|
|
|
# (addr, len, device id)
|
|
|
|
|
blocks_data.append(
|
|
|
|
|
(base_addr + block_offset, self.block_len, self.tp_rank))
|
|
|
|
|
logger.debug("Created %s blocks for dst engine %s and tp_rank %s",
|
|
|
|
|
len(blocks_data), engine_id, self.tp_rank)
|
|
|
|
|
|
|
|
|
|
# Register with NIXL.
|
|
|
|
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
|
|
|
|
self.dst_xfer_side_handles[
|
|
|
|
|
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
|
|
|
|
|
self._remote_agents[engine_id], descs)
|
|
|
|
|
# Register with NIXL.
|
|
|
|
|
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
|
|
|
|
self.dst_xfer_side_handles[
|
|
|
|
|
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
|
|
|
|
|
self._remote_agents[engine_id][remote_tp_rank], descs)
|
|
|
|
|
|
|
|
|
|
def get_finished(self) -> tuple[set[str], set[str]]:
|
|
|
|
|
"""
|
|
|
|
|
@@ -654,16 +788,25 @@ class NixlConnectorWorker:
|
|
|
|
|
return done_sending, done_recving
|
|
|
|
|
|
|
|
|
|
def _get_new_notifs(self) -> set[str]:
|
|
|
|
|
"""Get req_ids which got a remote xfer message."""
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
Get req_ids which got a remote xfer message. When multiple consumers
|
|
|
|
|
are reading from the same producer (heterogeneous TP scenario), wait
|
|
|
|
|
for all consumers to be done pulling.
|
|
|
|
|
"""
|
|
|
|
|
notified_req_ids: set[str] = set()
|
|
|
|
|
for req_ids in self.nixl_wrapper.get_new_notifs().values():
|
|
|
|
|
for req_id in req_ids:
|
|
|
|
|
assert req_id not in notified_req_ids
|
|
|
|
|
notified_req_ids.add(req_id.decode("utf-8"))
|
|
|
|
|
for notifs in self.nixl_wrapper.get_new_notifs().values():
|
|
|
|
|
for notif in notifs:
|
|
|
|
|
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
|
|
|
|
|
self.consumer_notification_counts_by_req[req_id] += 1
|
|
|
|
|
# Wait for all consumers (D) to finish reading before freeing.
|
|
|
|
|
if self.consumer_notification_counts_by_req[req_id] == int(
|
|
|
|
|
tp_ratio):
|
|
|
|
|
notified_req_ids.add(req_id)
|
|
|
|
|
del self.consumer_notification_counts_by_req[req_id]
|
|
|
|
|
return notified_req_ids
|
|
|
|
|
|
|
|
|
|
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
|
|
|
|
|
def _pop_done_transfers(
|
|
|
|
|
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
|
|
|
|
|
"""
|
|
|
|
|
Pop completed xfers by checking for DONE state.
|
|
|
|
|
Args:
|
|
|
|
|
@@ -673,23 +816,17 @@ class NixlConnectorWorker:
|
|
|
|
|
"""
|
|
|
|
|
done_req_ids: set[str] = set()
|
|
|
|
|
for req_id, handles in list(transfers.items()):
|
|
|
|
|
running_reqs = []
|
|
|
|
|
for handle in handles:
|
|
|
|
|
for handle, xfer_stime in handles:
|
|
|
|
|
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
|
|
|
|
if xfer_state == "DONE":
|
|
|
|
|
# TODO ptarasiewicz: why abort is throwing errors?
|
|
|
|
|
# self.nixl_wrapper.release_xfer_handle(handle)
|
|
|
|
|
self.nixl_wrapper.release_xfer_handle(handle)
|
|
|
|
|
done_req_ids.add(req_id)
|
|
|
|
|
del transfers[req_id]
|
|
|
|
|
elif xfer_state == "PROC":
|
|
|
|
|
continue
|
|
|
|
|
if xfer_state == "PROC":
|
|
|
|
|
running_reqs.append(handle)
|
|
|
|
|
else:
|
|
|
|
|
raise RuntimeError("Transfer failed with state %s",
|
|
|
|
|
xfer_state)
|
|
|
|
|
if len(running_reqs) == 0:
|
|
|
|
|
done_req_ids.add(req_id)
|
|
|
|
|
del transfers[req_id]
|
|
|
|
|
else:
|
|
|
|
|
transfers[req_id] = running_reqs
|
|
|
|
|
return done_req_ids
|
|
|
|
|
|
|
|
|
|
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
|
|
|
|
@@ -735,13 +872,19 @@ class NixlConnectorWorker:
|
|
|
|
|
# saturate IB with heterogeneous TP sizes. We should remove the staging
|
|
|
|
|
# blocks until we are ready.
|
|
|
|
|
|
|
|
|
|
# Number of D TP workers that will read from dst P. Propagate tp_ratio
|
|
|
|
|
# on notification so that dst worker can wait before freeing blocks.
|
|
|
|
|
tp_ratio = self._tp_size[
|
|
|
|
|
self.engine_id] // self._tp_size[dst_engine_id]
|
|
|
|
|
notif_id = f"{request_id}:{tp_ratio}".encode()
|
|
|
|
|
|
|
|
|
|
# Full prefix cache hit: do not need to read remote blocks,
|
|
|
|
|
# just notify P worker that we have the blocks we need.
|
|
|
|
|
num_local_blocks = len(local_block_ids)
|
|
|
|
|
if num_local_blocks == 0:
|
|
|
|
|
agent_name = self._remote_agents[dst_engine_id]
|
|
|
|
|
self.nixl_wrapper.send_notif(agent_name,
|
|
|
|
|
notif_msg=request_id.encode("utf-8"))
|
|
|
|
|
remote_rank = self.tp_rank // tp_ratio
|
|
|
|
|
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
|
|
|
|
self.nixl_wrapper.send_notif(agent_name, notif_msg=notif_id)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Partial prefix cache hit: just read uncomputed blocks.
|
|
|
|
|
@@ -754,6 +897,10 @@ class NixlConnectorWorker:
|
|
|
|
|
local_xfer_side_handle = self.src_xfer_side_handle
|
|
|
|
|
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
|
|
|
|
|
|
|
|
|
|
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
|
|
|
|
|
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
|
|
|
|
|
# workers will issue xfers to parts of the P worker remote kv caches.
|
|
|
|
|
|
|
|
|
|
# Get descs ids.
|
|
|
|
|
local_block_descs_ids: list[int] = []
|
|
|
|
|
remote_block_descs_ids: list[int] = []
|
|
|
|
|
@@ -797,14 +944,16 @@ class NixlConnectorWorker:
|
|
|
|
|
local_block_descs_ids,
|
|
|
|
|
remote_xfer_side_handle,
|
|
|
|
|
remote_block_descs_ids,
|
|
|
|
|
notif_msg=request_id.encode("utf-8"),
|
|
|
|
|
notif_msg=notif_id,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Begin async xfer.
|
|
|
|
|
self.nixl_wrapper.transfer(handle)
|
|
|
|
|
|
|
|
|
|
# Use handle to check completion in future step().
|
|
|
|
|
self._recving_transfers[request_id].append(handle)
|
|
|
|
|
# TODO (NickLucche) surface xfer elapsed time
|
|
|
|
|
self._recving_transfers[request_id].append(
|
|
|
|
|
(handle, time.perf_counter()))
|
|
|
|
|
|
|
|
|
|
def _get_block_descs_ids(self,
|
|
|
|
|
engine_id: str,
|
|
|
|
|
@@ -815,7 +964,6 @@ class NixlConnectorWorker:
|
|
|
|
|
If layer_idx is provided, we use the region_ids for the given layer.
|
|
|
|
|
Otherwise, we use all regions.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if layer_idx is None:
|
|
|
|
|
region_ids = range(self.num_regions)
|
|
|
|
|
else:
|
|
|
|
|
|