[v1] torchrun compatibility (#13642)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao (committed by GitHub)
Date: 2025-02-23 22:47:24 +08:00
parent 9bebc9512f
commit eb24dc4a45
14 changed files with 67 additions and 24 deletions
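
For context: this change lets vLLM v1 run under an external launcher such as torchrun, with one process per GPU. A minimal usage sketch (hypothetical script, not part of this diff; it assumes the "external_launcher" distributed executor backend that this commit builds on):

# Launch with: torchrun --nproc-per-node=2 script.py
# torchrun spawns one process per rank and initializes the process group
# that the external-launcher executor reuses.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",
    tensor_parallel_size=2,  # must match --nproc-per-node
    distributed_executor_backend="external_launcher",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))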

vllm/v1/executor/abstract.py

@@ -3,6 +3,9 @@
 from concurrent.futures import Future
 from typing import List, Type, Union
 
+import torch
+import torch.distributed as dist
+
 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.uniproc_executor import (  # noqa
@@ -49,12 +52,14 @@ class Executor(ExecutorBase):
                              f"{distributed_executor_backend}")
         return executor_class
 
-    def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
+    def initialize_from_config(self,
+                               kv_cache_configs: List[KVCacheConfig]) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
+        self.collective_rpc("initialize_from_config",
+                            args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")
 
     def determine_available_memory(self) -> int:  # in bytes
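
Note on the rename above: collective_rpc dispatches by method name to every worker, so the string and the worker-side method must stay in sync. A rough sketch of the dispatch pattern (illustrative only, not the executor's actual implementation):

def collective_rpc(self, method: str, args: tuple = ()) -> list:
    # Invoke the named method on each worker and gather the results;
    # renaming the worker-side method therefore also means renaming the
    # string passed in here.
    return [getattr(worker, method)(*args) for worker in self.workers]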
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
 
 
 class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
-    pass
+
+    def determine_available_memory(self) -> int:  # in bytes
+        # same as determine_num_available_blocks in v0,
+        # we need to get the min across all ranks.
+        memory = super().determine_available_memory()
+        from vllm.distributed.parallel_state import get_world_group
+        cpu_group = get_world_group().cpu_group
+        memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
+        dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
+        return memory_tensor.item()
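
The reduction above runs over the gloo-backed CPU group, so it behaves the same under torchrun regardless of the device backend, and MIN guarantees every rank sizes its KV cache for the most memory-constrained GPU. A standalone sketch of the same pattern (assumes torch.distributed is already initialized, e.g. by torchrun):

import torch
import torch.distributed as dist

def min_across_ranks(local_bytes: int, cpu_group) -> int:
    # Each rank contributes its local measurement; after all_reduce with
    # MIN, every rank holds the same smallest value, so all ranks allocate
    # an identical number of KV cache blocks.
    t = torch.tensor([local_bytes], device="cpu", dtype=torch.int64)
    dist.all_reduce(t, group=cpu_group, op=dist.ReduceOp.MIN)
    return int(t.item())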

vllm/v1/executor/multiproc_executor.py

@@ -216,9 +216,10 @@ class WorkerProc:
             "local_rank": local_rank,
             "rank": rank,
             "distributed_init_method": distributed_init_method,
+            "is_driver_worker": rank == 0,
         }
         wrapper.init_worker(all_kwargs)
-        self.worker = wrapper.worker
+        self.worker = wrapper
 
         pid = os.getpid()
         _add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
@@ -239,7 +240,7 @@ class WorkerProc:
             ready_socket.send_string(WorkerProc.READY_STR)
             ready_socket.send(payload)
 
-        wrapper.init_device()
+        self.worker.init_device()
         self.worker.load_model()
 
     @staticmethod
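
Keeping the wrapper as self.worker works because the wrapper can forward attribute lookups to the worker it constructs, so calls like self.worker.load_model() still reach the real worker. A hedged sketch of that delegation pattern (illustrative; vLLM's actual WorkerWrapperBase may differ):

class Worker:
    def init_device(self):
        print("device ready")

    def load_model(self):
        print("model loaded")


class WorkerWrapper:
    def __init__(self):
        self.worker = None  # constructed later, once the kwargs are known

    def init_worker(self, all_kwargs: dict):
        self.worker = Worker()  # real code builds the worker from all_kwargs

    def __getattr__(self, attr):
        # Reached only when normal lookup fails, so wrapper-level methods
        # such as init_worker still shadow same-named worker attributes.
        return getattr(self.worker, attr)


w = WorkerWrapper()
w.init_worker({"rank": 0, "is_driver_worker": True})
w.init_device()  # delegated to the wrapped Worker
w.load_model()   # likewise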