[V1][P/D]Enhance Performance and code readability for P2pNcclConnector (#20906)

Signed-off-by: Abatom <abzhonghua@gmail.com>
This commit is contained in:
Zhonghua Deng
2025-07-17 13:13:00 +08:00
committed by GitHub
parent 76b494444f
commit 8a4e5c5f3c
4 changed files with 262 additions and 252 deletions

View File

@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
P2pNcclEngine)
from vllm.distributed.parallel_state import get_world_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None
def extract_kv_from_layer(
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]
connector_metadata = self._get_connector_metadata()
assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
for request in connector_metadata.requests:
request_id = request.request_id
ip, port = self.parse_request_id(request_id, True)
remote_address = ip + ":" + str(port + self._rank)
kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
kv_cache, remote_address)
self.p2p_nccl_engine.send_tensor(
request_id + "#" + layer_name, kv_layer, remote_address,
request.slot_mapping,
isinstance(attn_metadata, MLACommonMetadata))
def wait_for_save(self):
if self.is_producer:
@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
assert self.p2p_nccl_engine is not None
forward_context: ForwardContext = get_forward_context()
no_compile_layers = (
self._vllm_config.compilation_config.static_forward_context)
return self.p2p_nccl_engine.get_finished(finished_req_ids,
forward_context)
no_compile_layers)
# ==============================
# Scheduler-side methods
@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
block_ids=block_ids,
block_size=self._block_size)
# Requests loaded asynchronously are not in the scheduler_output.
# for request_id in self._requests_need_load:
# request, block_ids = self._requests_need_load[request_id]
# meta.add_request(request_id=request.request_id,
# token_ids=request.prompt_token_ids,
# block_ids=block_ids,
# block_size=self._block_size)
self._requests_need_load.clear()
return meta

View File

@@ -8,7 +8,8 @@ import time
import typing
from collections import deque
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Optional
from dataclasses import dataclass
from typing import Any, Optional
import msgpack
import torch
@@ -21,9 +22,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
TensorMemoryPool)
from vllm.utils import current_stream, get_ip
if TYPE_CHECKING:
from vllm.forward_context import ForwardContext
logger = logging.getLogger(__name__)
DEFAULT_MEM_POOL_SIZE_GB = 32
@@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str):
os.environ.pop(var, None)
@dataclass
class SendQueueItem:
tensor_id: str
remote_address: str
tensor: torch.Tensor
slot_mapping: torch.Tensor
is_mla: bool
class P2pNcclEngine:
def __init__(self,
@@ -112,24 +119,26 @@ class P2pNcclEngine:
self.send_stream = torch.cuda.Stream()
self.recv_stream = torch.cuda.Stream()
mem_pool_size_gb = self.config.get_from_extra_config(
"mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
1024**3) # GB
mem_pool_size_gb = float(
self.config.get_from_extra_config("mem_pool_size_gb",
DEFAULT_MEM_POOL_SIZE_GB))
self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
1024**3)) # GB
# The sending type includes tree mutually exclusive options:
# PUT, GET, PUT_ASYNC.
self.send_type = self.config.get_from_extra_config("send_type", "PUT")
self.send_type = self.config.get_from_extra_config(
"send_type", "PUT_ASYNC")
if self.send_type == "GET":
# tensor_id: torch.Tensor
self.send_store: dict[str, torch.Tensor] = {}
else:
# PUT or PUT_ASYNC
# tensor_id: torch.Tensor
self.send_queue: deque[list[Any]] = deque()
self.send_queue: deque[SendQueueItem] = deque()
self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
if self.send_type == "PUT_ASYNC":
self._send_thread = threading.Thread(target=self._send_async,
self._send_thread = threading.Thread(target=self.send_async,
daemon=True)
self._send_thread.start()
@@ -146,13 +155,12 @@ class P2pNcclEngine:
"nccl_num_channels", "8")
self._listener_thread = threading.Thread(
target=self._listen_for_requests, daemon=True)
target=self.listen_for_requests, daemon=True)
self._listener_thread.start()
self._ping_thread = None
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping,
daemon=True)
self._ping_thread = threading.Thread(target=self.ping, daemon=True)
self._ping_thread.start()
logger.info(
@@ -162,7 +170,7 @@ class P2pNcclEngine:
self.http_address, self.zmq_address, self.proxy_address,
self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
def _create_connect(self, remote_address: typing.Optional[str] = None):
def create_connect(self, remote_address: typing.Optional[str] = None):
assert remote_address is not None
if remote_address not in self.socks:
sock = self.context.socket(zmq.DEALER)
@@ -184,7 +192,7 @@ class P2pNcclEngine:
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
self.zmq_address, remote_address, rank)
return self.socks[remote_address], self.comms[remote_address]
@@ -194,44 +202,54 @@ class P2pNcclEngine:
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
slot_mapping: torch.Tensor = None,
is_mla: bool = False,
) -> bool:
if remote_address is None:
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.recv_store_cv.notify()
return True
else:
if self.send_type == "PUT":
return self._send_sync(tensor_id, tensor, remote_address)
elif self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append([tensor_id, remote_address, tensor])
self.send_queue_cv.notify()
else: # GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size,
self.buffer_size, oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
remote_address, tensor_id, tensor_size, tensor.shape,
self.rank, self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
item = SendQueueItem(tensor_id=tensor_id,
remote_address=remote_address,
tensor=tensor,
slot_mapping=slot_mapping,
is_mla=is_mla)
if self.send_type == "PUT":
return self.send_sync(item)
if self.send_type == "PUT_ASYNC":
with self.send_queue_cv:
self.send_queue.append(item)
self.send_queue_cv.notify()
return True
# GET
with self.send_store_cv:
tensor_size = tensor.element_size() * tensor.numel()
while (self.buffer_size + tensor_size
> self.buffer_size_threshold):
oldest_tenser_id = next(iter(self.send_store))
oldest_tenser = self.send_store.pop(oldest_tenser_id)
oldest_tenser_size = oldest_tenser.element_size(
) * oldest_tenser.numel()
self.buffer_size -= oldest_tenser_size
logger.info(
"⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
" buffer_size:%d, oldest_tenser_size:%d, rank:%d",
remote_address, tensor_id, tensor_size, self.buffer_size,
oldest_tenser_size, self.rank)
self.send_store[tensor_id] = tensor
self.buffer_size += tensor_size
logger.debug(
"🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
"shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
tensor_id, tensor_size, tensor.shape, self.rank,
self.buffer_size,
self.buffer_size / self.buffer_size_threshold * 100)
return True
def recv_tensor(
@@ -267,7 +285,7 @@ class P2pNcclEngine:
return None
if remote_address not in self.socks:
self._create_connect(remote_address)
self.create_connect(remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
@@ -282,121 +300,121 @@ class P2pNcclEngine:
remote_address, tensor_id, data["ret"])
return None
tensor = torch.empty(data["shape"],
dtype=getattr(torch, data["dtype"]),
device=self.device)
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(torch, data["dtype"]),
device=self.device)
self._recv(comm, tensor, rank ^ 1, self.recv_stream)
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
return tensor
def _listen_for_requests(self):
def listen_for_requests(self):
while True:
socks = dict(self.poller.poll())
if self.router_socket in socks:
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info(
"🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(), rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart(
[remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self._recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
if self.router_socket not in socks:
continue
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart(
[remote_address, b"1"])
tensor = None
remote_address, message = self.router_socket.recv_multipart()
data = msgpack.loads(message)
if data["cmd"] == "NEW":
unique_id = self.nccl.unique_id_from_bytes(
bytes(data["unique_id"]))
with torch.cuda.device(self.device):
rank = 1
with set_p2p_nccl_context(self.nccl_num_channels):
comm: ncclComm_t = self.nccl.ncclCommInitRank(
2, unique_id, rank)
self.comms[remote_address.decode()] = (comm, rank)
logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
self.zmq_address, remote_address.decode(),
rank)
elif data["cmd"] == "PUT":
tensor_id = data["tensor_id"]
try:
with torch.cuda.stream(self.recv_stream):
tensor = torch.empty(data["shape"],
dtype=getattr(
torch, data["dtype"]),
device=self.device)
self.router_socket.send_multipart([remote_address, b"0"])
comm, rank = self.comms[remote_address.decode()]
self.recv(comm, tensor, rank ^ 1, self.recv_stream)
tensor_size = tensor.element_size() * tensor.numel()
if (self.buffer_size + tensor_size
> self.buffer_size_threshold):
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address,
remote_address.decode(), data)
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self._have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype":
str(tensor.dtype).replace("torch.", "")
}
# LRU
self.send_store[tensor_id] = tensor
self._have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self._send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
except torch.cuda.OutOfMemoryError:
self.router_socket.send_multipart([remote_address, b"1"])
tensor = None
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
"🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
"data:%s", self.zmq_address, remote_address.decode(),
data)
def _have_sent_tensor_id(self, tensor_id: str):
with self.recv_store_cv:
self.recv_store[tensor_id] = tensor
self.have_received_tensor_id(tensor_id)
self.recv_store_cv.notify()
elif data["cmd"] == "GET":
tensor_id = data["tensor_id"]
with self.send_store_cv:
tensor = self.send_store.pop(tensor_id, None)
if tensor is not None:
data = {
"ret": 0,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
# LRU
self.send_store[tensor_id] = tensor
self.have_sent_tensor_id(tensor_id)
else:
data = {"ret": 1}
self.router_socket.send_multipart(
[remote_address, msgpack.dumps(data)])
if data["ret"] == 0:
comm, rank = self.comms[remote_address.decode()]
self.send(comm, tensor.to(self.device), rank ^ 1,
self.send_stream)
else:
logger.warning(
"🚧Unexpected, Received message from %s, data:%s",
remote_address, data)
def have_sent_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.send_request_id_to_tensor_ids:
self.send_request_id_to_tensor_ids[request_id] = set()
self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
def _have_received_tensor_id(self, tensor_id: str):
def have_received_tensor_id(self, tensor_id: str):
request_id = tensor_id.split('#')[0]
if request_id not in self.recv_request_id_to_tensor_ids:
self.recv_request_id_to_tensor_ids[request_id] = set()
self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
def _send_async(self):
def send_async(self):
while True:
with self.send_queue_cv:
while not self.send_queue:
self.send_queue_cv.wait()
tensor_id, remote_address, tensor = self.send_queue.popleft()
item = self.send_queue.popleft()
if not self.send_queue:
self.send_queue_cv.notify()
self._send_sync(tensor_id, tensor, remote_address)
self.send_sync(item)
def wait_for_sent(self):
if self.send_type == "PUT_ASYNC":
@@ -409,22 +427,21 @@ class P2pNcclEngine:
"🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
" to be empty, rank:%d", duration * 1000, self.rank)
def _send_sync(
self,
tensor_id: str,
tensor: torch.Tensor,
remote_address: typing.Optional[str] = None,
) -> bool:
if remote_address is None:
def send_sync(self, item: SendQueueItem) -> bool:
if item.remote_address is None:
return False
if remote_address not in self.socks:
self._create_connect(remote_address)
if item.remote_address not in self.socks:
self.create_connect(item.remote_address)
sock = self.socks[remote_address]
comm, rank = self.comms[remote_address]
with self.send_stream:
tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
item.slot_mapping)
sock = self.socks[item.remote_address]
comm, rank = self.comms[item.remote_address]
data = {
"cmd": "PUT",
"tensor_id": tensor_id,
"tensor_id": item.tensor_id,
"shape": tensor.shape,
"dtype": str(tensor.dtype).replace("torch.", "")
}
@@ -435,20 +452,21 @@ class P2pNcclEngine:
logger.error(
"🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
"MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
self.zmq_address, remote_address, rank, data, tensor.shape,
self.zmq_address, item.remote_address, rank, data,
tensor.shape,
tensor.element_size() * tensor.numel() / 1024**3,
response.decode())
return False
self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
if self.send_type == "PUT_ASYNC":
self._have_sent_tensor_id(tensor_id)
self.have_sent_tensor_id(item.tensor_id)
return True
def get_finished(
self, finished_req_ids: set[str], forward_context: "ForwardContext"
self, finished_req_ids: set[str], no_compile_layers
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
@@ -463,7 +481,7 @@ class P2pNcclEngine:
# Clear the buffer upon request completion.
for request_id in finished_req_ids:
for layer_name in forward_context.no_compile_layers:
for layer_name in no_compile_layers:
tensor_id = request_id + "#" + layer_name
if tensor_id in self.recv_store:
with self.recv_store_cv:
@@ -472,7 +490,6 @@ class P2pNcclEngine:
request_id, None)
self.recv_request_id_to_tensor_ids.pop(
request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.pool.free(addr)
@@ -485,7 +502,7 @@ class P2pNcclEngine:
return finished_sending or None, finished_recving or None
def _ping(self):
def ping(self):
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address)
@@ -499,7 +516,7 @@ class P2pNcclEngine:
sock.send(msgpack.dumps(data))
time.sleep(3)
def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
@@ -512,7 +529,7 @@ class P2pNcclEngine:
comm, cudaStream_t(stream.cuda_stream))
stream.synchronize()
def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
assert tensor.device == self.device, (
f"this nccl communicator is created to work on {self.device}, "
f"but the input tensor is on {tensor.device}")
@@ -531,3 +548,21 @@ class P2pNcclEngine:
self._send_thread.join()
if self._ping_thread is not None:
self._ping_thread.join()
@staticmethod
def extract_kv_from_layer(
is_mla: bool,
layer: torch.Tensor,
slot_mapping: torch.Tensor,
) -> torch.Tensor:
"""Extract the KV cache from the layer.
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if is_mla:
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
num_pages, page_size = layer.shape[1], layer.shape[2]
return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
...]