[V1][P/D]Enhance Performance and code readability for P2pNcclConnector (#20906)
Signed-off-by: Abatom <abzhonghua@gmail.com>
@@ -13,7 +13,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
 from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
     P2pNcclEngine)
 from vllm.distributed.parallel_state import get_world_group
-from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
@@ -238,32 +237,16 @@ class P2pNcclConnector(KVConnectorBase_V1):
 
         assert self.p2p_nccl_engine is not None
 
-        def extract_kv_from_layer(
-            layer: torch.Tensor,
-            slot_mapping: torch.Tensor,
-        ) -> torch.Tensor:
-            """Extract the KV cache from the layer.
-
-            Assume the shape of the layer is (2, num_pages, page_size, xxx)
-            if MLA is not used, and (num_pages, page_size, xxx) otherwise.
-            """
-            if isinstance(attn_metadata, MLACommonMetadata):
-                num_pages, page_size = layer.shape[0], layer.shape[1]
-                return layer.reshape(num_pages * page_size, -1)[slot_mapping,
-                                                                ...]
-            num_pages, page_size = layer.shape[1], layer.shape[2]
-            return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
-                                                               ...]
-
         connector_metadata = self._get_connector_metadata()
         assert isinstance(connector_metadata, P2pNcclConnectorMetadata)
         for request in connector_metadata.requests:
             request_id = request.request_id
             ip, port = self.parse_request_id(request_id, True)
             remote_address = ip + ":" + str(port + self._rank)
-            kv_cache = extract_kv_from_layer(kv_layer, request.slot_mapping)
-            self.p2p_nccl_engine.send_tensor(request_id + "#" + layer_name,
-                                             kv_cache, remote_address)
+            self.p2p_nccl_engine.send_tensor(
+                request_id + "#" + layer_name, kv_layer, remote_address,
+                request.slot_mapping,
+                isinstance(attn_metadata, MLACommonMetadata))
 
     def wait_for_save(self):
         if self.is_producer:
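
Note on the hunk above: previously the connector gathered the request's KV slots into a contiguous `kv_cache` inside `save_kv_layer`, on the critical forward path, and only then handed the copy to `send_tensor`. Now it forwards the raw `kv_layer` together with `request.slot_mapping` and an MLA flag, and the gather runs later on the engine's send stream (see `send_sync` below). A toy sketch of the two orderings, with made-up shapes:

```python
import torch

# Toy stand-ins: one layer's KV cache and one request's flat slot indices.
kv_layer = torch.randn(2, 4, 16, 8)       # (2, num_pages, page_size, head_dim)
slot_mapping = torch.tensor([3, 17, 42])  # slots used by this request

# Before: gather on the critical path, then enqueue the gathered copy.
kv_cache = kv_layer.reshape(2, 4 * 16, -1)[:, slot_mapping, ...]
assert kv_cache.shape == (2, 3, 8)

# After: enqueue (layer, slot_mapping, is_mla) and let the engine's send
# thread/stream perform the same gather off the critical path.
queued = (kv_layer, slot_mapping, False)
```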
@@ -286,9 +269,10 @@ class P2pNcclConnector(KVConnectorBase_V1):
 
         assert self.p2p_nccl_engine is not None
 
-        forward_context: ForwardContext = get_forward_context()
+        no_compile_layers = (
+            self._vllm_config.compilation_config.static_forward_context)
         return self.p2p_nccl_engine.get_finished(finished_req_ids,
-                                                 forward_context)
+                                                 no_compile_layers)
 
     # ==============================
     # Scheduler-side methods
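
`static_forward_context` is a mapping keyed by layer name, so passing it through as `no_compile_layers` gives the engine exactly what it needs (the layer names) without the engine importing `ForwardContext`. A minimal illustration of how `get_finished` consumes it; the layer-name strings here are made up:

```python
# Stand-in for vllm_config.compilation_config.static_forward_context:
# maps layer name -> attention module; only the keys matter to the engine.
no_compile_layers = {
    "model.layers.0.self_attn.attn": object(),
    "model.layers.1.self_attn.attn": object(),
}

request_id = "cmpl-123"
tensor_ids = [request_id + "#" + name for name in no_compile_layers]
# -> ["cmpl-123#model.layers.0.self_attn.attn", ...]
```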
@@ -418,14 +402,6 @@ class P2pNcclConnector(KVConnectorBase_V1):
                     block_ids=block_ids,
                     block_size=self._block_size)
 
-        # Requests loaded asynchronously are not in the scheduler_output.
-        # for request_id in self._requests_need_load:
-        #     request, block_ids = self._requests_need_load[request_id]
-        #     meta.add_request(request_id=request.request_id,
-        #                      token_ids=request.prompt_token_ids,
-        #                      block_ids=block_ids,
-        #                      block_size=self._block_size)
-
         self._requests_need_load.clear()
         return meta
 
@@ -8,7 +8,8 @@ import time
 import typing
 from collections import deque
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional
+from dataclasses import dataclass
+from typing import Any, Optional
 
 import msgpack
 import torch
@@ -21,9 +22,6 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import
     TensorMemoryPool)
 from vllm.utils import current_stream, get_ip
 
-if TYPE_CHECKING:
-    from vllm.forward_context import ForwardContext
-
 logger = logging.getLogger(__name__)
 
 DEFAULT_MEM_POOL_SIZE_GB = 32
@@ -59,6 +57,15 @@ def set_p2p_nccl_context(num_channels: str):
             os.environ.pop(var, None)
 
 
+@dataclass
+class SendQueueItem:
+    tensor_id: str
+    remote_address: str
+    tensor: torch.Tensor
+    slot_mapping: torch.Tensor
+    is_mla: bool
+
+
 class P2pNcclEngine:
 
     def __init__(self,
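
Replacing the untyped `list[Any]` queue entries with this dataclass makes the deferred-send path self-documenting: every field the send needs travels together, under a name. A quick usage sketch with dummy values, reusing the `SendQueueItem` defined in the hunk above:

```python
from collections import deque

import torch

queue: deque = deque()
queue.append(
    SendQueueItem(tensor_id="cmpl-1#model.layers.0.self_attn.attn",
                  remote_address="10.0.0.2:21001",
                  tensor=torch.zeros(2, 4),
                  slot_mapping=torch.tensor([0, 1]),
                  is_mla=False))

item = queue.popleft()
assert item.tensor_id.split("#")[0] == "cmpl-1"  # request id prefix
```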
@@ -112,24 +119,26 @@ class P2pNcclEngine:
         self.send_stream = torch.cuda.Stream()
         self.recv_stream = torch.cuda.Stream()
 
-        mem_pool_size_gb = self.config.get_from_extra_config(
-            "mem_pool_size_gb", DEFAULT_MEM_POOL_SIZE_GB)
-        self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb) *
-                                     1024**3)  # GB
+        mem_pool_size_gb = float(
+            self.config.get_from_extra_config("mem_pool_size_gb",
+                                              DEFAULT_MEM_POOL_SIZE_GB))
+        self.pool = TensorMemoryPool(max_block_size=int(mem_pool_size_gb *
+                                                        1024**3))  # GB
 
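
The rewritten pool sizing fixes an order-of-operations bug: casting to `int` before multiplying truncated fractional sizes (and `int("0.5")` on a string config value raises). Parsing with `float` first and truncating after scaling preserves sub-GB values:

```python
mem_pool_size_gb = 0.5

# Old order: truncate first, then scale -> a zero-byte pool.
assert int(mem_pool_size_gb) * 1024**3 == 0

# New order: parse as float, scale, then truncate -> 512 MiB as intended.
assert int(float(mem_pool_size_gb) * 1024**3) == 536870912

# float() also accepts "0.5" from the extra config; int("0.5") would raise.
```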
         # The sending type includes three mutually exclusive options:
         # PUT, GET, PUT_ASYNC.
-        self.send_type = self.config.get_from_extra_config("send_type", "PUT")
+        self.send_type = self.config.get_from_extra_config(
+            "send_type", "PUT_ASYNC")
         if self.send_type == "GET":
             # tensor_id: torch.Tensor
             self.send_store: dict[str, torch.Tensor] = {}
         else:
             # PUT or PUT_ASYNC
             # tensor_id: torch.Tensor
-            self.send_queue: deque[list[Any]] = deque()
+            self.send_queue: deque[SendQueueItem] = deque()
             self.send_request_id_to_tensor_ids: dict[str, set[str]] = {}
         if self.send_type == "PUT_ASYNC":
-            self._send_thread = threading.Thread(target=self._send_async,
+            self._send_thread = threading.Thread(target=self.send_async,
                                                  daemon=True)
             self._send_thread.start()
 
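
The default send mode moves from blocking PUT to PUT_ASYNC. All three engine knobs touched in this hunk come from the connector's extra config; a hedged example of the relevant block (key names match the `get_from_extra_config` calls in this diff, values are illustrative):

```python
kv_connector_extra_config = {
    "send_type": "PUT_ASYNC",   # now the default; "PUT" and "GET" still valid
    "mem_pool_size_gb": "32",   # parsed via float(), so "0.5" also works
    "nccl_num_channels": "8",
}
```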
@@ -146,13 +155,12 @@ class P2pNcclEngine:
             "nccl_num_channels", "8")
 
         self._listener_thread = threading.Thread(
-            target=self._listen_for_requests, daemon=True)
+            target=self.listen_for_requests, daemon=True)
         self._listener_thread.start()
 
         self._ping_thread = None
         if port_offset == 0 and self.proxy_address != "":
-            self._ping_thread = threading.Thread(target=self._ping,
-                                                 daemon=True)
+            self._ping_thread = threading.Thread(target=self.ping, daemon=True)
             self._ping_thread.start()
 
         logger.info(
@@ -162,7 +170,7 @@ class P2pNcclEngine:
             self.http_address, self.zmq_address, self.proxy_address,
             self.send_type, self.buffer_size_threshold, self.nccl_num_channels)
 
-    def _create_connect(self, remote_address: typing.Optional[str] = None):
+    def create_connect(self, remote_address: typing.Optional[str] = None):
         assert remote_address is not None
         if remote_address not in self.socks:
             sock = self.context.socket(zmq.DEALER)
@@ -184,7 +192,7 @@ class P2pNcclEngine:
                     comm: ncclComm_t = self.nccl.ncclCommInitRank(
                         2, unique_id, rank)
                 self.comms[remote_address] = (comm, rank)
-                logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank: %s",
+                logger.info("🤝ncclCommInitRank Success, %s👉%s, MyRank:%s",
                             self.zmq_address, remote_address, rank)
 
         return self.socks[remote_address], self.comms[remote_address]
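
For readers new to this engine: `create_connect` and `listen_for_requests` bootstrap a two-rank NCCL communicator over a ZMQ DEALER/ROUTER pair. The client sends a `NEW` command carrying the NCCL unique id and becomes rank 0; the server replies after running `ncclCommInitRank` as rank 1. Below is a minimal, NCCL-free sketch of just that handshake (the address string and the 128-byte unique-id placeholder are made up; the real NCCL calls are left as comments):

```python
import threading

import msgpack
import zmq

ctx = zmq.Context()
router = ctx.socket(zmq.ROUTER)
router.bind("inproc://p2p")

def server():
    identity, message = router.recv_multipart()
    data = msgpack.loads(message)
    assert data["cmd"] == "NEW"
    # real code: comm = nccl.ncclCommInitRank(2, unique_id, rank=1)
    router.send_multipart([identity, b"0"])   # ack the handshake

threading.Thread(target=server, daemon=True).start()

dealer = ctx.socket(zmq.DEALER)
dealer.setsockopt_string(zmq.IDENTITY, "10.0.0.1:21001")
dealer.connect("inproc://p2p")
dealer.send(msgpack.dumps({"cmd": "NEW", "unique_id": b"\x00" * 128}))
assert dealer.recv() == b"0"
# real code: comm = nccl.ncclCommInitRank(2, unique_id, rank=0) on this side
```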
@@ -194,44 +202,54 @@ class P2pNcclEngine:
         tensor_id: str,
         tensor: torch.Tensor,
         remote_address: typing.Optional[str] = None,
+        slot_mapping: torch.Tensor = None,
+        is_mla: bool = False,
     ) -> bool:
         if remote_address is None:
             with self.recv_store_cv:
                 self.recv_store[tensor_id] = tensor
                 self.recv_store_cv.notify()
             return True
-        else:
-            if self.send_type == "PUT":
-                return self._send_sync(tensor_id, tensor, remote_address)
-            elif self.send_type == "PUT_ASYNC":
-                with self.send_queue_cv:
-                    self.send_queue.append([tensor_id, remote_address, tensor])
-                    self.send_queue_cv.notify()
-            else:  # GET
-                with self.send_store_cv:
-                    tensor_size = tensor.element_size() * tensor.numel()
-                    while (self.buffer_size + tensor_size
-                           > self.buffer_size_threshold):
-                        oldest_tenser_id = next(iter(self.send_store))
-                        oldest_tenser = self.send_store.pop(oldest_tenser_id)
-                        oldest_tenser_size = oldest_tenser.element_size(
-                        ) * oldest_tenser.numel()
-                        self.buffer_size -= oldest_tenser_size
-                        logger.info(
-                            "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
-                            " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
-                            remote_address, tensor_id, tensor_size,
-                            self.buffer_size, oldest_tenser_size, self.rank)
-
-                    self.send_store[tensor_id] = tensor
-                    self.buffer_size += tensor_size
-                    logger.debug(
-                        "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
-                        "shape:%s, rank:%d, buffer_size:%d(%.2f%%)",
-                        remote_address, tensor_id, tensor_size, tensor.shape,
-                        self.rank, self.buffer_size,
-                        self.buffer_size / self.buffer_size_threshold * 100)
+
+        item = SendQueueItem(tensor_id=tensor_id,
+                             remote_address=remote_address,
+                             tensor=tensor,
+                             slot_mapping=slot_mapping,
+                             is_mla=is_mla)
+
+        if self.send_type == "PUT":
+            return self.send_sync(item)
+
+        if self.send_type == "PUT_ASYNC":
+            with self.send_queue_cv:
+                self.send_queue.append(item)
+                self.send_queue_cv.notify()
+            return True
+
+        # GET
+        with self.send_store_cv:
+            tensor_size = tensor.element_size() * tensor.numel()
+            while (self.buffer_size + tensor_size
+                   > self.buffer_size_threshold):
+                oldest_tenser_id = next(iter(self.send_store))
+                oldest_tenser = self.send_store.pop(oldest_tenser_id)
+                oldest_tenser_size = oldest_tenser.element_size(
+                ) * oldest_tenser.numel()
+                self.buffer_size -= oldest_tenser_size
+                logger.info(
+                    "⛔[GET]Send to %s, tensor_id:%s, tensor_size:%d,"
+                    " buffer_size:%d, oldest_tenser_size:%d, rank:%d",
+                    remote_address, tensor_id, tensor_size, self.buffer_size,
+                    oldest_tenser_size, self.rank)
+
+            self.send_store[tensor_id] = tensor
+            self.buffer_size += tensor_size
+            logger.debug(
+                "🔵[GET]Send to %s, tensor_id:%s, tensor_size:%d, "
+                "shape:%s, rank:%d, buffer_size:%d(%.2f%%)", remote_address,
+                tensor_id, tensor_size, tensor.shape, self.rank,
+                self.buffer_size,
+                self.buffer_size / self.buffer_size_threshold * 100)
+        return True
 
     def recv_tensor(
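
The rewrite flattens the old `else:` ladder into early returns, and the GET branch keeps its dict-as-LRU store: Python dicts iterate in insertion order, so `next(iter(store))` is the least recently used entry as long as every hit pops and reinserts its value (the `# LRU` comment in `listen_for_requests` below). A self-contained demo of that eviction policy, with toy byte counts:

```python
from typing import Optional

store: dict[str, int] = {}  # tensor_id -> tensor size (toy stand-in)
buffer_size, threshold = 0, 100

def put(tensor_id: str, size: int) -> None:
    global buffer_size
    while buffer_size + size > threshold:
        oldest_id = next(iter(store))        # least recently used entry
        buffer_size -= store.pop(oldest_id)
    store[tensor_id] = size
    buffer_size += size

def get(tensor_id: str) -> Optional[int]:
    size = store.pop(tensor_id, None)
    if size is not None:
        store[tensor_id] = size              # reinsert at the back: now MRU
    return size

put("a", 40); put("b", 40)
get("a")                                     # touch "a"
put("c", 40)                                 # evicts "b", the true LRU
assert list(store) == ["a", "c"]
```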
@@ -267,7 +285,7 @@ class P2pNcclEngine:
             return None
 
         if remote_address not in self.socks:
-            self._create_connect(remote_address)
+            self.create_connect(remote_address)
 
         sock = self.socks[remote_address]
         comm, rank = self.comms[remote_address]
@@ -282,121 +300,121 @@ class P2pNcclEngine:
                            remote_address, tensor_id, data["ret"])
             return None
 
-        tensor = torch.empty(data["shape"],
-                             dtype=getattr(torch, data["dtype"]),
-                             device=self.device)
+        with torch.cuda.stream(self.recv_stream):
+            tensor = torch.empty(data["shape"],
+                                 dtype=getattr(torch, data["dtype"]),
+                                 device=self.device)
 
-        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
+        self.recv(comm, tensor, rank ^ 1, self.recv_stream)
 
         return tensor
 
-    def _listen_for_requests(self):
+    def listen_for_requests(self):
         while True:
             socks = dict(self.poller.poll())
-            if self.router_socket in socks:
-                remote_address, message = self.router_socket.recv_multipart()
-                data = msgpack.loads(message)
-                if data["cmd"] == "NEW":
-                    unique_id = self.nccl.unique_id_from_bytes(
-                        bytes(data["unique_id"]))
-                    with torch.cuda.device(self.device):
-                        rank = 1
-                        with set_p2p_nccl_context(self.nccl_num_channels):
-                            comm: ncclComm_t = self.nccl.ncclCommInitRank(
-                                2, unique_id, rank)
-                        self.comms[remote_address.decode()] = (comm, rank)
-                        logger.info(
-                            "🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
-                            self.zmq_address, remote_address.decode(), rank)
-                elif data["cmd"] == "PUT":
-                    tensor_id = data["tensor_id"]
-                    try:
-                        with torch.cuda.stream(self.recv_stream):
-                            tensor = torch.empty(data["shape"],
-                                                 dtype=getattr(
-                                                     torch, data["dtype"]),
-                                                 device=self.device)
-                        self.router_socket.send_multipart(
-                            [remote_address, b"0"])
-                        comm, rank = self.comms[remote_address.decode()]
-                        self._recv(comm, tensor, rank ^ 1, self.recv_stream)
-                        tensor_size = tensor.element_size() * tensor.numel()
-                        if (self.buffer_size + tensor_size
-                                > self.buffer_size_threshold):
-                            # Store Tensor in memory pool
-                            addr = self.pool.store_tensor(tensor)
-                            tensor = (addr, tensor.dtype, tensor.shape)
-                            logger.warning(
-                                "🔴[PUT]Recv Tensor, Out Of Threshold, "
-                                "%s👈%s, data:%s, addr:%d", self.zmq_address,
-                                remote_address.decode(), data, addr)
-                        else:
-                            self.buffer_size += tensor_size
-                    except torch.cuda.OutOfMemoryError:
-                        self.router_socket.send_multipart(
-                            [remote_address, b"1"])
-                        tensor = None
-                        logger.warning(
-                            "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
-                            "data:%s", self.zmq_address,
-                            remote_address.decode(), data)
-
-                    with self.recv_store_cv:
-                        self.recv_store[tensor_id] = tensor
-                        self._have_received_tensor_id(tensor_id)
-                        self.recv_store_cv.notify()
-
-                elif data["cmd"] == "GET":
-                    tensor_id = data["tensor_id"]
-                    with self.send_store_cv:
-                        tensor = self.send_store.pop(tensor_id, None)
-                        if tensor is not None:
-                            data = {
-                                "ret": 0,
-                                "shape": tensor.shape,
-                                "dtype":
-                                str(tensor.dtype).replace("torch.", "")
-                            }
-                            # LRU
-                            self.send_store[tensor_id] = tensor
-                            self._have_sent_tensor_id(tensor_id)
-                        else:
-                            data = {"ret": 1}
-
-                    self.router_socket.send_multipart(
-                        [remote_address, msgpack.dumps(data)])
-
-                    if data["ret"] == 0:
-                        comm, rank = self.comms[remote_address.decode()]
-                        self._send(comm, tensor.to(self.device), rank ^ 1,
-                                   self.send_stream)
-                else:
-                    logger.warning(
-                        "🚧Unexpected, Received message from %s, data:%s",
-                        remote_address, data)
+            if self.router_socket not in socks:
+                continue
+
+            remote_address, message = self.router_socket.recv_multipart()
+            data = msgpack.loads(message)
+            if data["cmd"] == "NEW":
+                unique_id = self.nccl.unique_id_from_bytes(
+                    bytes(data["unique_id"]))
+                with torch.cuda.device(self.device):
+                    rank = 1
+                    with set_p2p_nccl_context(self.nccl_num_channels):
+                        comm: ncclComm_t = self.nccl.ncclCommInitRank(
+                            2, unique_id, rank)
+                    self.comms[remote_address.decode()] = (comm, rank)
+                    logger.info("🤝ncclCommInitRank Success, %s👈%s, MyRank:%s",
+                                self.zmq_address, remote_address.decode(),
+                                rank)
+            elif data["cmd"] == "PUT":
+                tensor_id = data["tensor_id"]
+                try:
+                    with torch.cuda.stream(self.recv_stream):
+                        tensor = torch.empty(data["shape"],
+                                             dtype=getattr(
+                                                 torch, data["dtype"]),
+                                             device=self.device)
+                    self.router_socket.send_multipart([remote_address, b"0"])
+                    comm, rank = self.comms[remote_address.decode()]
+                    self.recv(comm, tensor, rank ^ 1, self.recv_stream)
+                    tensor_size = tensor.element_size() * tensor.numel()
+                    if (self.buffer_size + tensor_size
+                            > self.buffer_size_threshold):
+                        # Store Tensor in memory pool
+                        addr = self.pool.store_tensor(tensor)
+                        tensor = (addr, tensor.dtype, tensor.shape)
+                        logger.warning(
+                            "🔴[PUT]Recv Tensor, Out Of Threshold, "
+                            "%s👈%s, data:%s, addr:%d", self.zmq_address,
+                            remote_address.decode(), data, addr)
+                    else:
+                        self.buffer_size += tensor_size
+                except torch.cuda.OutOfMemoryError:
+                    self.router_socket.send_multipart([remote_address, b"1"])
+                    tensor = None
+                    logger.warning(
+                        "🔴[PUT]Recv Tensor, Out Of Memory, %s👈%s, "
+                        "data:%s", self.zmq_address, remote_address.decode(),
+                        data)
+
+                with self.recv_store_cv:
+                    self.recv_store[tensor_id] = tensor
+                    self.have_received_tensor_id(tensor_id)
+                    self.recv_store_cv.notify()
+
+            elif data["cmd"] == "GET":
+                tensor_id = data["tensor_id"]
+                with self.send_store_cv:
+                    tensor = self.send_store.pop(tensor_id, None)
+                    if tensor is not None:
+                        data = {
+                            "ret": 0,
+                            "shape": tensor.shape,
+                            "dtype": str(tensor.dtype).replace("torch.", "")
+                        }
+                        # LRU
+                        self.send_store[tensor_id] = tensor
+                        self.have_sent_tensor_id(tensor_id)
+                    else:
+                        data = {"ret": 1}
+
+                self.router_socket.send_multipart(
+                    [remote_address, msgpack.dumps(data)])
+
+                if data["ret"] == 0:
+                    comm, rank = self.comms[remote_address.decode()]
+                    self.send(comm, tensor.to(self.device), rank ^ 1,
+                              self.send_stream)
+            else:
+                logger.warning(
+                    "🚧Unexpected, Received message from %s, data:%s",
+                    remote_address, data)
 
-    def _have_sent_tensor_id(self, tensor_id: str):
+    def have_sent_tensor_id(self, tensor_id: str):
         request_id = tensor_id.split('#')[0]
         if request_id not in self.send_request_id_to_tensor_ids:
             self.send_request_id_to_tensor_ids[request_id] = set()
         self.send_request_id_to_tensor_ids[request_id].add(tensor_id)
 
-    def _have_received_tensor_id(self, tensor_id: str):
+    def have_received_tensor_id(self, tensor_id: str):
         request_id = tensor_id.split('#')[0]
         if request_id not in self.recv_request_id_to_tensor_ids:
             self.recv_request_id_to_tensor_ids[request_id] = set()
         self.recv_request_id_to_tensor_ids[request_id].add(tensor_id)
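
One subtlety in the `PUT` handler above: when a received tensor would push `buffer_size` past the threshold, the engine parks it in the `TensorMemoryPool` and stores an `(addr, dtype, shape)` tuple in `recv_store` in place of the tensor itself; `get_finished` later frees the address (see the `isinstance(tensor, tuple)` check further down). A toy stand-in showing the tagging pattern; `PoolStandIn` is invented for this sketch and only its `store_tensor`/`free` surface mirrors the real pool:

```python
import torch

class PoolStandIn:
    def __init__(self):
        self._blocks: dict[int, torch.Tensor] = {}
        self._next = 1

    def store_tensor(self, t: torch.Tensor) -> int:
        addr, self._next = self._next, self._next + 1
        self._blocks[addr] = t.cpu().clone()   # park off-GPU, return "address"
        return addr

    def free(self, addr: int) -> None:
        self._blocks.pop(addr, None)

pool = PoolStandIn()
tensor = torch.ones(4)
entry = (pool.store_tensor(tensor), tensor.dtype, tensor.shape)

if isinstance(entry, tuple):        # mirrors the check in get_finished
    addr, _, _ = entry
    pool.free(addr)
```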
-    def _send_async(self):
+    def send_async(self):
         while True:
             with self.send_queue_cv:
                 while not self.send_queue:
                     self.send_queue_cv.wait()
-                tensor_id, remote_address, tensor = self.send_queue.popleft()
+                item = self.send_queue.popleft()
                 if not self.send_queue:
                     self.send_queue_cv.notify()
-            self._send_sync(tensor_id, tensor, remote_address)
+            self.send_sync(item)
 
     def wait_for_sent(self):
         if self.send_type == "PUT_ASYNC":
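
`send_async` is a classic condition-variable consumer with one twist: it notifies the CV when the queue drains, which is what `wait_for_sent` blocks on. A self-contained skeleton of that protocol (strings stand in for `SendQueueItem`s; note the signal fires on dequeue, not on send completion, matching the code above):

```python
import threading
from collections import deque

queue: deque = deque()
cv = threading.Condition()
sent: list[str] = []

def send_async() -> None:
    while True:
        with cv:
            while not queue:
                cv.wait()
            item = queue.popleft()
            if not queue:
                cv.notify()       # wake the wait_for_sent equivalent below
        sent.append(item)         # stands in for send_sync(item)

threading.Thread(target=send_async, daemon=True).start()

with cv:                          # producer: enqueue and signal
    queue.append("cmpl-1#model.layers.0.self_attn.attn")
    cv.notify()
with cv:                          # wait_for_sent equivalent
    while queue:
        cv.wait()
```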
@@ -409,22 +427,21 @@ class P2pNcclEngine:
                 "🚧[PUT_ASYNC]It took %.3fms to wait for the send_queue"
                 " to be empty, rank:%d", duration * 1000, self.rank)
 
-    def _send_sync(
-        self,
-        tensor_id: str,
-        tensor: torch.Tensor,
-        remote_address: typing.Optional[str] = None,
-    ) -> bool:
-        if remote_address is None:
+    def send_sync(self, item: SendQueueItem) -> bool:
+        if item.remote_address is None:
             return False
-        if remote_address not in self.socks:
-            self._create_connect(remote_address)
+        if item.remote_address not in self.socks:
+            self.create_connect(item.remote_address)
 
-        sock = self.socks[remote_address]
-        comm, rank = self.comms[remote_address]
+        with self.send_stream:
+            tensor = self.extract_kv_from_layer(item.is_mla, item.tensor,
+                                                item.slot_mapping)
+
+        sock = self.socks[item.remote_address]
+        comm, rank = self.comms[item.remote_address]
         data = {
             "cmd": "PUT",
-            "tensor_id": tensor_id,
+            "tensor_id": item.tensor_id,
             "shape": tensor.shape,
             "dtype": str(tensor.dtype).replace("torch.", "")
         }
@@ -435,20 +452,21 @@ class P2pNcclEngine:
             logger.error(
                 "🔴Send Tensor, Peer Out Of Memory/Threshold, %s 👉 %s, "
                 "MyRank:%s, data:%s, tensor:%s, size:%fGB, response:%s",
-                self.zmq_address, remote_address, rank, data, tensor.shape,
+                self.zmq_address, item.remote_address, rank, data,
+                tensor.shape,
                 tensor.element_size() * tensor.numel() / 1024**3,
                 response.decode())
             return False
 
-        self._send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
+        self.send(comm, tensor.to(self.device), rank ^ 1, self.send_stream)
 
         if self.send_type == "PUT_ASYNC":
-            self._have_sent_tensor_id(tensor_id)
+            self.have_sent_tensor_id(item.tensor_id)
 
         return True
 
     def get_finished(
-        self, finished_req_ids: set[str], forward_context: "ForwardContext"
+        self, finished_req_ids: set[str], no_compile_layers
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """
         Notifies worker-side connector ids of requests that have
@@ -463,7 +481,7 @@ class P2pNcclEngine:
 
         # Clear the buffer upon request completion.
         for request_id in finished_req_ids:
-            for layer_name in forward_context.no_compile_layers:
+            for layer_name in no_compile_layers:
                 tensor_id = request_id + "#" + layer_name
                 if tensor_id in self.recv_store:
                     with self.recv_store_cv:
@@ -472,7 +490,6 @@ class P2pNcclEngine:
                             request_id, None)
                         self.recv_request_id_to_tensor_ids.pop(
                             request_id, None)
-                    addr = 0
                     if isinstance(tensor, tuple):
                         addr, _, _ = tensor
                         self.pool.free(addr)
@@ -485,7 +502,7 @@ class P2pNcclEngine:
 
         return finished_sending or None, finished_recving or None
 
-    def _ping(self):
+    def ping(self):
         sock = self.context.socket(zmq.DEALER)
         sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
         logger.debug("ping start, zmq_address:%s", self.zmq_address)
@@ -499,7 +516,7 @@ class P2pNcclEngine:
             sock.send(msgpack.dumps(data))
             time.sleep(3)
 
-    def _send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
+    def send(self, comm, tensor: torch.Tensor, dst: int, stream=None):
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
@@ -512,7 +529,7 @@ class P2pNcclEngine:
                 comm, cudaStream_t(stream.cuda_stream))
             stream.synchronize()
 
-    def _recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
+    def recv(self, comm, tensor: torch.Tensor, src: int, stream=None):
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
            f"but the input tensor is on {tensor.device}")
@@ -531,3 +548,21 @@ class P2pNcclEngine:
             self._send_thread.join()
         if self._ping_thread is not None:
             self._ping_thread.join()
+
+    @staticmethod
+    def extract_kv_from_layer(
+        is_mla: bool,
+        layer: torch.Tensor,
+        slot_mapping: torch.Tensor,
+    ) -> torch.Tensor:
+        """Extract the KV cache from the layer.
+
+        Assume the shape of the layer is (2, num_pages, page_size, xxx)
+        if MLA is not used, and (num_pages, page_size, xxx) otherwise.
+        """
+        if is_mla:
+            num_pages, page_size = layer.shape[0], layer.shape[1]
+            return layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
+
+        num_pages, page_size = layer.shape[1], layer.shape[2]
+        return layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping,
+                                                           ...]
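
The relocated helper is now a static method keyed on an explicit `is_mla` flag instead of closing over `attn_metadata`. A shape walk-through with toy sizes:

```python
import torch

num_pages, page_size, head_dim = 4, 16, 8
slot_mapping = torch.tensor([3, 17, 42])  # flat indices into num_pages*page_size

# Standard attention: K and V stacked on dim 0.
layer = torch.randn(2, num_pages, page_size, head_dim)
kv = layer.reshape(2, num_pages * page_size, -1)[:, slot_mapping, ...]
assert kv.shape == (2, 3, head_dim)

# MLA: a single latent cache, no K/V split.
mla_layer = torch.randn(num_pages, page_size, head_dim)
latent = mla_layer.reshape(num_pages * page_size, -1)[slot_mapping, ...]
assert latent.shape == (3, head_dim)
```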