[core][distributed] use tcp store directly (#10275)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from collections import deque
|
||||
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.distributed.rendezvous import rendezvous
|
||||
from torch.distributed import TCPStore
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
@@ -97,7 +97,6 @@ class StatelessProcessGroup:
|
||||
group. Only use it to communicate metadata between processes.
|
||||
For data-plane communication, create NCCL-related objects.
|
||||
"""
|
||||
prefix: str
|
||||
rank: int
|
||||
world_size: int
|
||||
store: torch._C._distributed_c10d.Store
|
||||
@@ -127,7 +126,7 @@ class StatelessProcessGroup:
|
||||
def send_obj(self, obj: Any, dst: int):
|
||||
"""Send an object to a destination rank."""
|
||||
self.expire_data()
|
||||
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.send_dst_counter[dst] += 1
|
||||
self.entries.append((key, time.time()))
|
||||
@@ -147,8 +146,7 @@ class StatelessProcessGroup:
|
||||
"""Receive an object from a source rank."""
|
||||
obj = pickle.loads(
|
||||
self.store.get(
|
||||
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
|
||||
))
|
||||
f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
|
||||
self.recv_src_counter[src] += 1
|
||||
return obj
|
||||
|
||||
@@ -159,14 +157,14 @@ class StatelessProcessGroup:
|
||||
"""
|
||||
if self.rank == src:
|
||||
self.expire_data()
|
||||
key = (f"{self.prefix}/broadcast_from/{src}/"
|
||||
key = (f"broadcast_from/{src}/"
|
||||
f"{self.broadcast_send_counter}")
|
||||
self.store.set(key, pickle.dumps(obj))
|
||||
self.broadcast_send_counter += 1
|
||||
self.entries.append((key, time.time()))
|
||||
return obj
|
||||
else:
|
||||
key = (f"{self.prefix}/broadcast_from/{src}/"
|
||||
key = (f"broadcast_from/{src}/"
|
||||
f"{self.broadcast_recv_src_counter[src]}")
|
||||
recv_obj = pickle.loads(self.store.get(key))
|
||||
self.broadcast_recv_src_counter[src] += 1
|
||||
@@ -194,7 +192,8 @@ class StatelessProcessGroup:
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
init_method: str,
|
||||
host: str,
|
||||
port: int,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
data_expiration_seconds: int = 3600,
|
||||
@@ -214,15 +213,14 @@ class StatelessProcessGroup:
|
||||
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
|
||||
C, and D can call `StatelessProcessGroup.create` to form another group.
|
||||
""" # noqa
|
||||
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
|
||||
timeout = _DEFAULT_PG_TIMEOUT
|
||||
|
||||
store, rank, world_size = next(
|
||||
rendezvous(init_method, rank, world_size, timeout=timeout))
|
||||
store.set_timeout(timeout)
|
||||
store = TCPStore(
|
||||
host_name=host,
|
||||
port=port,
|
||||
world_size=world_size,
|
||||
is_master=(rank == 0),
|
||||
)
|
||||
|
||||
return StatelessProcessGroup(
|
||||
prefix=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
store=store,
|
||||
|
||||
Reference in New Issue
Block a user