[core][distributed] use tcp store directly (#10275)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-11-12 17:36:08 -08:00
committed by GitHub
parent 112fa0bbe5
commit 0d4ea3fb5c
2 changed files with 29 additions and 25 deletions

View File

@@ -9,7 +9,7 @@ from collections import deque
from typing import Any, Deque, Dict, Optional, Sequence, Tuple
import torch
from torch.distributed.rendezvous import rendezvous
from torch.distributed import TCPStore
import vllm.envs as envs
from vllm.logger import init_logger
@@ -97,7 +97,6 @@ class StatelessProcessGroup:
group. Only use it to communicate metadata between processes.
For data-plane communication, create NCCL-related objects.
"""
prefix: str
rank: int
world_size: int
store: torch._C._distributed_c10d.Store
@@ -127,7 +126,7 @@ class StatelessProcessGroup:
def send_obj(self, obj: Any, dst: int):
"""Send an object to a destination rank."""
self.expire_data()
key = f"{self.prefix}/send_to/{dst}/{self.send_dst_counter[dst]}"
key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
self.store.set(key, pickle.dumps(obj))
self.send_dst_counter[dst] += 1
self.entries.append((key, time.time()))
@@ -147,8 +146,7 @@ class StatelessProcessGroup:
"""Receive an object from a source rank."""
obj = pickle.loads(
self.store.get(
f"{self.prefix}/send_to/{self.rank}/{self.recv_src_counter[src]}"
))
f"send_to/{self.rank}/{self.recv_src_counter[src]}"))
self.recv_src_counter[src] += 1
return obj
@@ -159,14 +157,14 @@ class StatelessProcessGroup:
"""
if self.rank == src:
self.expire_data()
key = (f"{self.prefix}/broadcast_from/{src}/"
key = (f"broadcast_from/{src}/"
f"{self.broadcast_send_counter}")
self.store.set(key, pickle.dumps(obj))
self.broadcast_send_counter += 1
self.entries.append((key, time.time()))
return obj
else:
key = (f"{self.prefix}/broadcast_from/{src}/"
key = (f"broadcast_from/{src}/"
f"{self.broadcast_recv_src_counter[src]}")
recv_obj = pickle.loads(self.store.get(key))
self.broadcast_recv_src_counter[src] += 1
@@ -194,7 +192,8 @@ class StatelessProcessGroup:
@staticmethod
def create(
init_method: str,
host: str,
port: int,
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
@@ -214,15 +213,14 @@ class StatelessProcessGroup:
can call `StatelessProcessGroup.create` to form a group, and then process A, B,
C, and D can call `StatelessProcessGroup.create` to form another group.
""" # noqa
from torch._C._distributed_c10d import _DEFAULT_PG_TIMEOUT
timeout = _DEFAULT_PG_TIMEOUT
store, rank, world_size = next(
rendezvous(init_method, rank, world_size, timeout=timeout))
store.set_timeout(timeout)
store = TCPStore(
host_name=host,
port=port,
world_size=world_size,
is_master=(rank == 0),
)
return StatelessProcessGroup(
prefix=init_method,
rank=rank,
world_size=world_size,
store=store,