[v1] torchrun compatibility (#13642)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -3,6 +3,9 @@
|
||||
from concurrent.futures import Future
|
||||
from typing import List, Type, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.uniproc_executor import ( # noqa
|
||||
@@ -49,12 +52,14 @@ class Executor(ExecutorBase):
|
||||
f"{distributed_executor_backend}")
|
||||
return executor_class
|
||||
|
||||
def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
|
||||
def initialize_from_config(self,
|
||||
kv_cache_configs: List[KVCacheConfig]) -> None:
|
||||
"""
|
||||
Initialize the KV caches and begin the model execution loop of the
|
||||
underlying workers.
|
||||
"""
|
||||
self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
|
||||
self.collective_rpc("initialize_from_config",
|
||||
args=(kv_cache_configs, ))
|
||||
self.collective_rpc("compile_or_warm_up_model")
|
||||
|
||||
def determine_available_memory(self) -> int: # in bytes
|
||||
@@ -89,4 +94,13 @@ class UniProcExecutor(UniProcExecutorV0, Executor):
|
||||
|
||||
|
||||
class ExecutorWithExternalLauncher(ExecutorWithExternalLauncherV0, Executor):
|
||||
pass
|
||||
|
||||
def determine_available_memory(self) -> int: # in bytes
|
||||
# same as determine_num_available_blocks in v0,
|
||||
# we need to get the min across all ranks.
|
||||
memory = super().determine_available_memory()
|
||||
from vllm.distributed.parallel_state import get_world_group
|
||||
cpu_group = get_world_group().cpu_group
|
||||
memory_tensor = torch.tensor([memory], device="cpu", dtype=torch.int64)
|
||||
dist.all_reduce(memory_tensor, group=cpu_group, op=dist.ReduceOp.MIN)
|
||||
return memory_tensor.item()
|
||||
|
||||
@@ -216,9 +216,10 @@ class WorkerProc:
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
"distributed_init_method": distributed_init_method,
|
||||
"is_driver_worker": rank == 0,
|
||||
}
|
||||
wrapper.init_worker(all_kwargs)
|
||||
self.worker = wrapper.worker
|
||||
self.worker = wrapper
|
||||
|
||||
pid = os.getpid()
|
||||
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
|
||||
@@ -239,7 +240,7 @@ class WorkerProc:
|
||||
ready_socket.send_string(WorkerProc.READY_STR)
|
||||
ready_socket.send(payload)
|
||||
|
||||
wrapper.init_device()
|
||||
self.worker.init_device()
|
||||
self.worker.load_model()
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user