[V1] Support MP Executor for multi node distributed inference (#23691)
Signed-off-by: Lu Fang <fanglu@fb.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Signed-off-by: Lucia Fang <fanglu@fb.com> Signed-off-by: Lucia Fang <116399278+luccafong@users.noreply.github.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -210,6 +210,18 @@ class ParallelConfig:
|
||||
class is dynamically inherited by the worker class. This is used to inject
|
||||
new attributes and methods to the worker class for use in collective_rpc
|
||||
calls."""
|
||||
master_addr: str = "127.0.0.1"
|
||||
"""distributed master address for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
master_port: int = 29501
|
||||
"""distributed master port for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
node_rank: int = 0
|
||||
"""distributed node rank for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
nnodes: int = 1
|
||||
"""num of nodes for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
|
||||
world_size: int = Field(init=False)
|
||||
"""world_size is TPxPP, it affects the number of workers we create."""
|
||||
@@ -387,6 +399,23 @@ class ParallelConfig:
|
||||
and self.data_parallel_size > 1
|
||||
)
|
||||
|
||||
@property
|
||||
def node_rank_within_dp(self) -> int:
|
||||
return self.node_rank % self.nnodes_within_dp
|
||||
|
||||
@property
|
||||
def nnodes_within_dp(self) -> int:
|
||||
if self.nnodes == 1:
|
||||
return 1
|
||||
data_parallel_node_size = (
|
||||
self.data_parallel_size // self.data_parallel_size_local
|
||||
)
|
||||
return self.nnodes // data_parallel_node_size
|
||||
|
||||
@property
|
||||
def local_world_size(self) -> int:
|
||||
return self.world_size // self.nnodes_within_dp
|
||||
|
||||
@staticmethod
|
||||
def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool:
|
||||
tensor = torch.tensor([has_unfinished], dtype=torch.int32, device="cpu")
|
||||
@@ -528,6 +557,8 @@ class ParallelConfig:
|
||||
ray_found = ray_utils.ray_is_available()
|
||||
if current_platform.is_tpu() and envs.VLLM_XLA_USE_SPMD:
|
||||
backend = "uni"
|
||||
elif current_platform.is_cuda() and self.nnodes > 1:
|
||||
backend = "mp"
|
||||
elif (
|
||||
current_platform.is_cuda()
|
||||
and cuda_device_count_stateless() < self.world_size
|
||||
@@ -565,6 +596,10 @@ class ParallelConfig:
|
||||
"max_parallel_loading_workers is currently "
|
||||
"not supported and will be ignored."
|
||||
)
|
||||
if self.distributed_executor_backend != "mp" and self.nnodes > 1:
|
||||
raise ValueError(
|
||||
"nnodes > 1 can only be set when distributed exectuor backend is mp."
|
||||
)
|
||||
|
||||
@property
|
||||
def use_ray(self) -> bool:
|
||||
@@ -607,6 +642,11 @@ class ParallelConfig:
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
"supported on current platform."
|
||||
)
|
||||
if self.nnodes > 1:
|
||||
self.disable_custom_all_reduce = True
|
||||
logger.debug(
|
||||
"Disabled the custom all-reduce since we are running on multi-node."
|
||||
)
|
||||
if self.ray_workers_use_nsight and not self.use_ray:
|
||||
raise ValueError(
|
||||
"Unable to use nsight profiling unless workers run with Ray."
|
||||
|
||||
@@ -8,7 +8,7 @@ from dataclasses import dataclass, field
|
||||
from multiprocessing import shared_memory
|
||||
from pickle import PickleBuffer
|
||||
from threading import Event
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@@ -602,13 +602,87 @@ class MessageQueue:
|
||||
return obj
|
||||
return self.dequeue()
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group_single_reader(
|
||||
pg: ProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
reader_rank: int = 0,
|
||||
blocking: bool = False,
|
||||
) -> tuple["MessageQueue", list[Handle]]:
|
||||
"""
|
||||
Creates a MessageQueue for a process group with a single reader.
|
||||
|
||||
This method is designed for scenarios where only one process (the reader)
|
||||
will consume messages, and all other processes are writers. It sets up
|
||||
the shared memory buffer and communication handles accordingly, and
|
||||
gathers the handles from all processes to the reader.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup): The torch distributed process group.
|
||||
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
|
||||
max_chunks (int): Maximum number of chunks in the buffer.
|
||||
reader_rank (int, optional): The global rank that will act as the reader.
|
||||
Defaults to 0.
|
||||
blocking (bool, optional): If True, blocks until all processes are ready.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
tuple[MessageQueue, list[Handle]]:
|
||||
The MessageQueue instance for the calling process,
|
||||
and a list of handles (only non-empty for the reader process).
|
||||
"""
|
||||
local_size = torch.cuda.device_count()
|
||||
rank = dist.get_rank()
|
||||
same_node = rank // local_size == reader_rank // local_size
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=1,
|
||||
n_local_reader=1 if same_node else 0,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
|
||||
dist.gather_object(handle, handles, dst=reader_rank, group=pg)
|
||||
if blocking:
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io, cast(list[Handle], handles or [])
|
||||
|
||||
@staticmethod
|
||||
def create_from_process_group(
|
||||
pg: ProcessGroup | StatelessProcessGroup,
|
||||
max_chunk_bytes,
|
||||
max_chunks,
|
||||
writer_rank=0,
|
||||
writer_rank: int = 0,
|
||||
external_writer_handle=None,
|
||||
blocking: bool = True,
|
||||
) -> "MessageQueue":
|
||||
"""
|
||||
Creates a MessageQueue for a distributed process group with one writer and
|
||||
multiple readers.
|
||||
|
||||
This method is designed for scenarios where one process (the writer) sends
|
||||
messages, and all other processes (the readers) receive messages. It sets up
|
||||
the shared memory buffer and socket communication handles accordingly, and
|
||||
broadcasts the handle from the writer to all readers.
|
||||
|
||||
Args:
|
||||
pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
|
||||
group.
|
||||
max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
|
||||
max_chunks (int): Maximum number of chunks in the buffer.
|
||||
writer_rank (int, optional): The global rank that will act as the writer.
|
||||
Defaults to 0.
|
||||
external_writer_handle (Handle, optional): Used when there is a handle
|
||||
from an external Message Queue. If provided, use this handle to init
|
||||
PG writer message queue instead of creating a new one. Defaults to None.
|
||||
blocking (bool, optional): If True, blocks until all processes are ready.
|
||||
Defaults to True.
|
||||
|
||||
Returns:
|
||||
MessageQueue: The MessageQueue instance for the calling process.
|
||||
|
||||
"""
|
||||
if isinstance(pg, ProcessGroup):
|
||||
group_rank = dist.get_rank(pg)
|
||||
group_world_size = dist.get_world_size(pg)
|
||||
@@ -617,23 +691,26 @@ class MessageQueue:
|
||||
group_rank = pg.rank
|
||||
group_world_size = pg.world_size
|
||||
global_ranks = list(range(pg.world_size))
|
||||
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
|
||||
status = in_the_same_node_as(pg, source_rank=writer_rank)
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io: MessageQueue
|
||||
if group_rank == writer_rank:
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
if external_writer_handle is not None:
|
||||
buffer_io = MessageQueue.create_from_handle(
|
||||
external_writer_handle, group_rank
|
||||
)
|
||||
else:
|
||||
same_node_ranks = [i for i, s in enumerate(status) if s]
|
||||
n_reader = group_world_size - 1
|
||||
n_local_reader = len(same_node_ranks) - 1
|
||||
local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
|
||||
buffer_io = MessageQueue(
|
||||
n_reader=n_reader,
|
||||
n_local_reader=n_local_reader,
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
max_chunks=max_chunks,
|
||||
)
|
||||
handle = buffer_io.export_handle()
|
||||
if isinstance(pg, ProcessGroup):
|
||||
dist.broadcast_object_list(
|
||||
@@ -651,5 +728,6 @@ class MessageQueue:
|
||||
else:
|
||||
handle = pg.broadcast_obj(None, writer_rank)
|
||||
buffer_io = MessageQueue.create_from_handle(handle, group_rank)
|
||||
buffer_io.wait_until_ready()
|
||||
if blocking:
|
||||
buffer_io.wait_until_ready()
|
||||
return buffer_io
|
||||
|
||||
@@ -385,6 +385,33 @@ class GroupCoordinator:
|
||||
torch.ops._C, "init_shm_manager"
|
||||
)
|
||||
|
||||
def create_mq_broadcaster(
|
||||
self, writer_rank=0, external_writer_handle=None, blocking=True
|
||||
):
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
|
||||
return MessageQueue.create_from_process_group(
|
||||
self.cpu_group,
|
||||
1 << 22,
|
||||
6,
|
||||
writer_rank=writer_rank,
|
||||
external_writer_handle=external_writer_handle,
|
||||
blocking=blocking,
|
||||
)
|
||||
|
||||
def create_single_reader_mq_broadcasters(
|
||||
self, reader_rank_in_group=0, blocking=False
|
||||
):
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
|
||||
return MessageQueue.create_from_process_group_single_reader(
|
||||
self.cpu_group,
|
||||
1 << 22,
|
||||
6,
|
||||
reader_rank=self.ranks[reader_rank_in_group],
|
||||
blocking=blocking,
|
||||
)
|
||||
|
||||
@property
|
||||
def first_rank(self):
|
||||
"""Return the global rank of the first process in the group"""
|
||||
@@ -997,6 +1024,7 @@ class GroupCoordinator:
|
||||
|
||||
|
||||
_WORLD: GroupCoordinator | None = None
|
||||
_INNER_DP_WORLD: GroupCoordinator | None = None
|
||||
_NODE_COUNT: int | None = None
|
||||
|
||||
|
||||
@@ -1005,6 +1033,11 @@ def get_world_group() -> GroupCoordinator:
|
||||
return _WORLD
|
||||
|
||||
|
||||
def get_inner_dp_world_group() -> GroupCoordinator:
|
||||
assert _INNER_DP_WORLD is not None, "inner dp world group is not initialized"
|
||||
return _INNER_DP_WORLD
|
||||
|
||||
|
||||
def init_world_group(
|
||||
ranks: list[int], local_rank: int, backend: str
|
||||
) -> GroupCoordinator:
|
||||
@@ -1023,12 +1056,13 @@ def init_model_parallel_group(
|
||||
backend: str,
|
||||
use_message_queue_broadcaster: bool = False,
|
||||
group_name: str | None = None,
|
||||
use_device_communicator: bool = True,
|
||||
) -> GroupCoordinator:
|
||||
return GroupCoordinator(
|
||||
group_ranks=group_ranks,
|
||||
local_rank=local_rank,
|
||||
torch_distributed_backend=backend,
|
||||
use_device_communicator=True,
|
||||
use_device_communicator=use_device_communicator,
|
||||
use_message_queue_broadcaster=use_message_queue_broadcaster,
|
||||
group_name=group_name,
|
||||
)
|
||||
@@ -1143,7 +1177,14 @@ def init_distributed_environment(
|
||||
from vllm.config import get_current_vllm_config
|
||||
|
||||
config = get_current_vllm_config()
|
||||
if (
|
||||
if config is not None and config.parallel_config.nnodes > 1:
|
||||
parallel_config = config.parallel_config
|
||||
ip = parallel_config.master_addr
|
||||
rank = parallel_config.data_parallel_rank * world_size + rank
|
||||
world_size = parallel_config.world_size_across_dp
|
||||
port = parallel_config.master_port
|
||||
distributed_init_method = get_distributed_init_method(ip, port)
|
||||
elif (
|
||||
config is not None
|
||||
and config.parallel_config.data_parallel_size > 1
|
||||
and config.parallel_config.distributed_executor_backend != "external_launcher"
|
||||
@@ -1164,6 +1205,14 @@ def init_distributed_environment(
|
||||
distributed_init_method,
|
||||
)
|
||||
if not torch.distributed.is_initialized():
|
||||
logger.info(
|
||||
"world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
|
||||
world_size,
|
||||
rank,
|
||||
local_rank,
|
||||
distributed_init_method,
|
||||
backend,
|
||||
)
|
||||
assert distributed_init_method is not None, (
|
||||
"distributed_init_method must be provided when initializing "
|
||||
"distributed environment"
|
||||
@@ -1192,16 +1241,36 @@ def init_distributed_environment(
|
||||
# local rank not set, this usually happens in single-node
|
||||
# setting, where we can use rank as local rank
|
||||
local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
|
||||
global _WORLD, _NODE_COUNT
|
||||
global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
|
||||
if _WORLD is None:
|
||||
ranks = list(range(torch.distributed.get_world_size()))
|
||||
_WORLD = init_world_group(ranks, local_rank, backend)
|
||||
_NODE_COUNT = _node_count(_WORLD.cpu_group)
|
||||
if config.parallel_config.nnodes > 1:
|
||||
_NODE_COUNT = config.parallel_config.nnodes
|
||||
else:
|
||||
_NODE_COUNT = _node_count(_WORLD.cpu_group)
|
||||
logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
|
||||
else:
|
||||
assert _WORLD.world_size == torch.distributed.get_world_size(), (
|
||||
"world group already initialized with a different world size"
|
||||
)
|
||||
if config.parallel_config.nnodes_within_dp > 1:
|
||||
if parallel_config.data_parallel_size > 1:
|
||||
world_size_inner_dp = parallel_config.world_size
|
||||
group_ranks = [
|
||||
[dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
|
||||
for dp_rank in range(parallel_config.data_parallel_size)
|
||||
]
|
||||
_INNER_DP_WORLD = init_model_parallel_group(
|
||||
group_ranks,
|
||||
get_world_group().local_rank,
|
||||
backend,
|
||||
use_message_queue_broadcaster=True,
|
||||
group_name="inner_dp_world",
|
||||
use_device_communicator=False,
|
||||
)
|
||||
else:
|
||||
_INNER_DP_WORLD = _WORLD
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
|
||||
@@ -384,6 +384,10 @@ class EngineArgs:
|
||||
) = ParallelConfig.distributed_executor_backend
|
||||
# number of P/D disaggregation (or other disaggregation) workers
|
||||
pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
|
||||
master_addr: str = ParallelConfig.master_addr
|
||||
master_port: int = ParallelConfig.master_port
|
||||
nnodes: int = ParallelConfig.nnodes
|
||||
node_rank: int = ParallelConfig.node_rank
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||
dcp_kv_cache_interleave_size: int = ParallelConfig.dcp_kv_cache_interleave_size
|
||||
@@ -394,6 +398,7 @@ class EngineArgs:
|
||||
data_parallel_address: str | None = None
|
||||
data_parallel_rpc_port: int | None = None
|
||||
data_parallel_hybrid_lb: bool = False
|
||||
data_parallel_external_lb: bool = False
|
||||
data_parallel_backend: str = ParallelConfig.data_parallel_backend
|
||||
enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel
|
||||
all2all_backend: str | None = ParallelConfig.all2all_backend
|
||||
@@ -749,6 +754,10 @@ class EngineArgs:
|
||||
"-pp",
|
||||
**parallel_kwargs["pipeline_parallel_size"],
|
||||
)
|
||||
parallel_group.add_argument("--master-addr", **parallel_kwargs["master_addr"])
|
||||
parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
|
||||
parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
|
||||
parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
|
||||
parallel_group.add_argument(
|
||||
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
|
||||
)
|
||||
@@ -803,7 +812,14 @@ class EngineArgs:
|
||||
help='Backend for data parallel, either "mp" or "ray".',
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"]
|
||||
"--data-parallel-hybrid-lb",
|
||||
"-dph",
|
||||
**parallel_kwargs["data_parallel_hybrid_lb"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--data-parallel-external-lb",
|
||||
"-dpe",
|
||||
**parallel_kwargs["data_parallel_external_lb"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"]
|
||||
@@ -1428,12 +1444,56 @@ class EngineArgs:
|
||||
assert not headless or not self.data_parallel_hybrid_lb, (
|
||||
"data_parallel_hybrid_lb is not applicable in headless mode"
|
||||
)
|
||||
|
||||
data_parallel_external_lb = self.data_parallel_rank is not None
|
||||
assert not (self.data_parallel_hybrid_lb and self.data_parallel_external_lb), (
|
||||
"data_parallel_hybrid_lb and data_parallel_external_lb cannot both be True."
|
||||
)
|
||||
assert self.data_parallel_backend == "mp" or self.nnodes == 1, (
|
||||
"nnodes > 1 is only supported with data_parallel_backend=mp"
|
||||
)
|
||||
inferred_data_parallel_rank = 0
|
||||
if self.nnodes > 1:
|
||||
world_size = (
|
||||
self.data_parallel_size
|
||||
* self.pipeline_parallel_size
|
||||
* self.tensor_parallel_size
|
||||
)
|
||||
world_size_within_dp = (
|
||||
self.pipeline_parallel_size * self.tensor_parallel_size
|
||||
)
|
||||
local_world_size = world_size // self.nnodes
|
||||
assert world_size % self.nnodes == 0, (
|
||||
f"world_size={world_size} must be divisible by nnodes={self.nnodes}."
|
||||
)
|
||||
assert self.node_rank < self.nnodes, (
|
||||
f"node_rank={self.node_rank} must be less than nnodes={self.nnodes}."
|
||||
)
|
||||
inferred_data_parallel_rank = (
|
||||
self.node_rank * local_world_size
|
||||
) // world_size_within_dp
|
||||
if self.data_parallel_size > 1 and self.data_parallel_external_lb:
|
||||
self.data_parallel_rank = inferred_data_parallel_rank
|
||||
logger.info(
|
||||
"Inferred data_parallel_rank %d from node_rank %d for external lb",
|
||||
self.data_parallel_rank,
|
||||
self.node_rank,
|
||||
)
|
||||
elif self.data_parallel_size_local is None:
|
||||
# Infer data parallel size local for internal dplb:
|
||||
self.data_parallel_size_local = max(
|
||||
local_world_size // world_size_within_dp, 1
|
||||
)
|
||||
data_parallel_external_lb = (
|
||||
self.data_parallel_external_lb or self.data_parallel_rank is not None
|
||||
)
|
||||
# Local DP rank = 1, use pure-external LB.
|
||||
if data_parallel_external_lb:
|
||||
assert self.data_parallel_rank is not None, (
|
||||
"data_parallel_rank or node_rank must be spefified if "
|
||||
"data_parallel_external_lb is enable."
|
||||
)
|
||||
assert self.data_parallel_size_local in (1, None), (
|
||||
"data_parallel_size_local must be 1 when data_parallel_rank is set"
|
||||
"data_parallel_size_local must be 1 or None when data_parallel_rank "
|
||||
"is set"
|
||||
)
|
||||
data_parallel_size_local = 1
|
||||
# Use full external lb if we have local_size of 1.
|
||||
@@ -1447,6 +1507,11 @@ class EngineArgs:
|
||||
|
||||
if self.data_parallel_hybrid_lb and data_parallel_size_local == 1:
|
||||
# Use full external lb if we have local_size of 1.
|
||||
logger.warning(
|
||||
"data_parallel_hybrid_lb is not eligible when "
|
||||
"data_parallel_size_local = 1, autoswitch to "
|
||||
"data_parallel_external_lb."
|
||||
)
|
||||
data_parallel_external_lb = True
|
||||
self.data_parallel_hybrid_lb = False
|
||||
|
||||
@@ -1454,7 +1519,15 @@ class EngineArgs:
|
||||
# Disable hybrid LB mode if set for a single node
|
||||
self.data_parallel_hybrid_lb = False
|
||||
|
||||
self.data_parallel_rank = self.data_parallel_start_rank or 0
|
||||
self.data_parallel_rank = (
|
||||
self.data_parallel_start_rank or inferred_data_parallel_rank
|
||||
)
|
||||
if self.nnodes > 1:
|
||||
logger.info(
|
||||
"Inferred data_parallel_rank %d from node_rank %d",
|
||||
self.data_parallel_rank,
|
||||
self.node_rank,
|
||||
)
|
||||
else:
|
||||
assert not self.data_parallel_hybrid_lb, (
|
||||
"data_parallel_size_local must be set to use data_parallel_hybrid_lb."
|
||||
@@ -1484,7 +1557,9 @@ class EngineArgs:
|
||||
"data_parallel_backend can only be ray or mp, got %s",
|
||||
self.data_parallel_backend,
|
||||
)
|
||||
data_parallel_address = ParallelConfig.data_parallel_master_ip
|
||||
data_parallel_address = (
|
||||
self.master_addr or ParallelConfig.data_parallel_master_ip
|
||||
)
|
||||
else:
|
||||
data_parallel_address = self.data_parallel_address
|
||||
|
||||
@@ -1517,6 +1592,10 @@ class EngineArgs:
|
||||
data_parallel_rank=self.data_parallel_rank or 0,
|
||||
data_parallel_external_lb=data_parallel_external_lb,
|
||||
data_parallel_size_local=data_parallel_size_local,
|
||||
master_addr=self.master_addr,
|
||||
master_port=self.master_port,
|
||||
nnodes=self.nnodes,
|
||||
node_rank=self.node_rank,
|
||||
data_parallel_master_ip=data_parallel_address,
|
||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||
data_parallel_backend=self.data_parallel_backend,
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
from vllm.v1.metrics.prometheus import setup_multiprocess_prometheus
|
||||
from vllm.v1.utils import APIServerProcessManager, wait_for_completion_or_failure
|
||||
|
||||
@@ -97,18 +98,40 @@ def run_headless(args: argparse.Namespace):
|
||||
if local_engine_count <= 0:
|
||||
raise ValueError("data_parallel_size_local must be > 0 in headless mode")
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = engine_args.data_parallel_rpc_port # add to config too
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
shutdown_requested = False
|
||||
|
||||
# Catch SIGTERM and SIGINT to allow graceful shutdown.
|
||||
def signal_handler(signum, frame):
|
||||
nonlocal shutdown_requested
|
||||
logger.debug("Received %d signal.", signum)
|
||||
raise SystemExit
|
||||
if not shutdown_requested:
|
||||
shutdown_requested = True
|
||||
raise SystemExit
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
if parallel_config.node_rank_within_dp > 0:
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
# Run headless workers (for multi-node PP/TP).
|
||||
host = parallel_config.master_addr
|
||||
head_node_address = f"{host}:{parallel_config.master_port}"
|
||||
logger.info(
|
||||
"Launching vLLM (v%s) headless multiproc executor, "
|
||||
"with head node address %s for torch.distributed process group.",
|
||||
VLLM_VERSION,
|
||||
head_node_address,
|
||||
)
|
||||
|
||||
executor = MultiprocExecutor(vllm_config, monitor_workers=False)
|
||||
executor.start_worker_monitor(inline=True)
|
||||
return
|
||||
|
||||
host = parallel_config.data_parallel_master_ip
|
||||
port = parallel_config.data_parallel_rpc_port
|
||||
handshake_address = get_tcp_uri(host, port)
|
||||
|
||||
logger.info(
|
||||
"Launching %d data parallel engine(s) in headless mode, "
|
||||
"with head node address %s.",
|
||||
|
||||
@@ -183,15 +183,19 @@ def set_device_control_env_var(
|
||||
for engine subprocess.
|
||||
"""
|
||||
world_size = vllm_config.parallel_config.world_size
|
||||
local_world_size = vllm_config.parallel_config.local_world_size
|
||||
evar = current_platform.device_control_env_var
|
||||
|
||||
value = get_device_indices(evar, local_dp_rank, world_size)
|
||||
value = get_device_indices(evar, local_dp_rank, world_size, local_world_size)
|
||||
with patch.dict(os.environ, values=((evar, value),)):
|
||||
yield
|
||||
|
||||
|
||||
def get_device_indices(
|
||||
device_control_env_var: str, local_dp_rank: int, world_size: int
|
||||
device_control_env_var: str,
|
||||
local_dp_rank: int,
|
||||
world_size: int,
|
||||
local_world_size: int | None = None,
|
||||
):
|
||||
"""
|
||||
Returns a comma-separated string of device indices for the specified
|
||||
@@ -200,10 +204,15 @@ def get_device_indices(
|
||||
For example, if world_size=2 and local_dp_rank=1, and there are 4 devices,
|
||||
this will select devices 2 and 3 for local_dp_rank=1.
|
||||
"""
|
||||
if local_world_size is None:
|
||||
local_world_size = world_size
|
||||
try:
|
||||
value = ",".join(
|
||||
str(current_platform.device_id_to_physical_device_id(i))
|
||||
for i in range(local_dp_rank * world_size, (local_dp_rank + 1) * world_size)
|
||||
for i in range(
|
||||
local_dp_rank * world_size,
|
||||
local_dp_rank * world_size + local_world_size,
|
||||
)
|
||||
)
|
||||
except IndexError as e:
|
||||
raise Exception(
|
||||
|
||||
@@ -10,7 +10,7 @@ import time
|
||||
import traceback
|
||||
import weakref
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Sequence
|
||||
from concurrent.futures import Future, InvalidStateError
|
||||
from contextlib import suppress
|
||||
from dataclasses import dataclass
|
||||
@@ -34,6 +34,7 @@ from vllm.distributed.parallel_state import (
|
||||
get_dcp_group,
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_inner_dp_world_group,
|
||||
get_pp_group,
|
||||
get_tp_group,
|
||||
)
|
||||
@@ -90,6 +91,10 @@ class FutureWrapper(Future):
|
||||
class MultiprocExecutor(Executor):
|
||||
supports_pp: bool = True
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, monitor_workers: bool = True):
|
||||
self.monitor_workers = monitor_workers
|
||||
super().__init__(vllm_config)
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
# Call self.shutdown at exit to clean up
|
||||
# and ensure workers will be terminated.
|
||||
@@ -99,6 +104,12 @@ class MultiprocExecutor(Executor):
|
||||
self.failure_callback: FailureCallback | None = None
|
||||
|
||||
self.world_size = self.parallel_config.world_size
|
||||
assert self.world_size % self.parallel_config.nnodes_within_dp == 0, (
|
||||
f"global world_size ({self.parallel_config.world_size}) must be "
|
||||
f"divisible by nnodes_within_dp "
|
||||
f"({self.parallel_config.nnodes_within_dp}). "
|
||||
)
|
||||
self.local_world_size = self.parallel_config.local_world_size
|
||||
tensor_parallel_size = self.parallel_config.tensor_parallel_size
|
||||
pp_parallel_size = self.parallel_config.pipeline_parallel_size
|
||||
assert self.world_size == tensor_parallel_size * pp_parallel_size, (
|
||||
@@ -116,27 +127,37 @@ class MultiprocExecutor(Executor):
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_loopback_ip(), get_open_port()
|
||||
)
|
||||
|
||||
self.rpc_broadcast_mq: MessageQueue | None = None
|
||||
scheduler_output_handle: Handle | None = None
|
||||
# Initialize worker and set up message queues for SchedulerOutputs
|
||||
# and ModelRunnerOutputs
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
self.rpc_broadcast_mq = MessageQueue(
|
||||
self.world_size, self.world_size, max_chunk_bytes=max_chunk_bytes
|
||||
)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
|
||||
if self.parallel_config.node_rank_within_dp == 0:
|
||||
# For leader node within each dp rank,
|
||||
# each dp will have its own leader multiproc executor.
|
||||
max_chunk_bytes = envs.VLLM_MQ_MAX_CHUNK_BYTES_MB * 1024 * 1024
|
||||
self.rpc_broadcast_mq = MessageQueue(
|
||||
self.world_size,
|
||||
self.local_world_size,
|
||||
max_chunk_bytes=max_chunk_bytes,
|
||||
connect_ip=self.parallel_config.master_addr,
|
||||
)
|
||||
scheduler_output_handle = self.rpc_broadcast_mq.export_handle()
|
||||
# Create workers
|
||||
context = get_mp_context()
|
||||
shared_worker_lock = context.Lock()
|
||||
unready_workers: list[UnreadyWorkerProcHandle] = []
|
||||
success = False
|
||||
try:
|
||||
for rank in range(self.world_size):
|
||||
global_start_rank = (
|
||||
self.local_world_size * self.parallel_config.node_rank_within_dp
|
||||
)
|
||||
for local_rank in range(self.local_world_size):
|
||||
global_rank = global_start_rank + local_rank
|
||||
unready_workers.append(
|
||||
WorkerProc.make_worker_process(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=rank,
|
||||
rank=rank,
|
||||
local_rank=local_rank,
|
||||
rank=global_rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
input_shm_handle=scheduler_output_handle,
|
||||
shared_worker_lock=shared_worker_lock,
|
||||
@@ -145,15 +166,38 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
# Workers must be created before wait_for_ready to avoid
|
||||
# deadlock, since worker.init_device() does a device sync.
|
||||
|
||||
# Wait for all local workers to be ready.
|
||||
self.workers = WorkerProc.wait_for_ready(unready_workers)
|
||||
|
||||
# Start background thread to monitor worker health if not in headless mode.
|
||||
if self.monitor_workers:
|
||||
self.start_worker_monitor()
|
||||
|
||||
self.response_mqs = []
|
||||
# Only leader node have remote response mqs
|
||||
if self.parallel_config.node_rank_within_dp == 0:
|
||||
for rank in range(self.world_size):
|
||||
if rank < self.local_world_size:
|
||||
local_message_queue = self.workers[rank].worker_response_mq
|
||||
assert local_message_queue is not None
|
||||
self.response_mqs.append(local_message_queue)
|
||||
else:
|
||||
remote_message_queue = self.workers[0].peer_worker_response_mqs[
|
||||
rank
|
||||
]
|
||||
assert remote_message_queue is not None
|
||||
self.response_mqs.append(remote_message_queue)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered
|
||||
# Must be kept consistent with the WorkerProc.
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
for w in self.workers:
|
||||
w.worker_response_mq.wait_until_ready()
|
||||
|
||||
self.start_worker_monitor()
|
||||
# Wait for all input mqs to be ready.
|
||||
if self.rpc_broadcast_mq is not None:
|
||||
self.rpc_broadcast_mq.wait_until_ready()
|
||||
# Wait for all remote response mqs to be ready.
|
||||
for response_mq in self.response_mqs:
|
||||
response_mq.wait_until_ready()
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
@@ -168,7 +212,7 @@ class MultiprocExecutor(Executor):
|
||||
|
||||
self.output_rank = self._get_output_rank()
|
||||
|
||||
def start_worker_monitor(self):
|
||||
def start_worker_monitor(self, inline=False) -> None:
|
||||
workers = self.workers
|
||||
self_ref = weakref.ref(self)
|
||||
|
||||
@@ -192,9 +236,13 @@ class MultiprocExecutor(Executor):
|
||||
_self.failure_callback = None
|
||||
callback()
|
||||
|
||||
Thread(
|
||||
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
|
||||
).start()
|
||||
if not inline:
|
||||
Thread(
|
||||
target=monitor_workers, daemon=True, name="MultiprocWorkerMonitor"
|
||||
).start()
|
||||
return
|
||||
|
||||
monitor_workers()
|
||||
|
||||
def register_failure_callback(self, callback: FailureCallback):
|
||||
if self.is_failed:
|
||||
@@ -247,7 +295,9 @@ class MultiprocExecutor(Executor):
|
||||
) -> Any | list[Any] | Future[Any | list[Any]]:
|
||||
"""Returns single result if unique_reply_rank and/or kv_output_aggregator
|
||||
is provided, otherwise list."""
|
||||
|
||||
assert self.rpc_broadcast_mq is not None, (
|
||||
"collective_rpc should not be called on follower node"
|
||||
)
|
||||
if self.is_failed:
|
||||
raise RuntimeError("Executor failed.")
|
||||
|
||||
@@ -269,20 +319,20 @@ class MultiprocExecutor(Executor):
|
||||
send_method = cloudpickle.dumps(method, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.rpc_broadcast_mq.enqueue((send_method, args, kwargs, output_rank))
|
||||
|
||||
workers = (
|
||||
(self.workers[output_rank],) if output_rank is not None else self.workers
|
||||
)
|
||||
response_mqs: Sequence[MessageQueue] = self.response_mqs
|
||||
if output_rank is not None:
|
||||
response_mqs = (response_mqs[output_rank],)
|
||||
|
||||
shutdown_event = self.shutdown_event
|
||||
|
||||
def get_response():
|
||||
responses = []
|
||||
for w in workers:
|
||||
for mq in response_mqs:
|
||||
dequeue_timeout = (
|
||||
None if deadline is None else (deadline - time.monotonic())
|
||||
)
|
||||
try:
|
||||
status, result = w.worker_response_mq.dequeue(
|
||||
status, result = mq.dequeue(
|
||||
timeout=dequeue_timeout, cancel=shutdown_event
|
||||
)
|
||||
except TimeoutError as e:
|
||||
@@ -391,17 +441,26 @@ class UnreadyWorkerProcHandle:
|
||||
class WorkerProcHandle:
|
||||
proc: BaseProcess
|
||||
rank: int
|
||||
worker_response_mq: MessageQueue # The worker process writes to this MQ
|
||||
# The worker process writes to this MQ in single-node mode
|
||||
worker_response_mq: MessageQueue | None
|
||||
# This is only non empty on driver node,
|
||||
# the peer worker process i writes to MQ
|
||||
# `peer_worker_response_mqs[i]`
|
||||
peer_worker_response_mqs: list[MessageQueue | None]
|
||||
death_writer: Connection | None = None
|
||||
|
||||
@classmethod
|
||||
def from_unready_handle(
|
||||
cls, unready_handle: UnreadyWorkerProcHandle, worker_response_mq: MessageQueue
|
||||
cls,
|
||||
unready_handle: UnreadyWorkerProcHandle,
|
||||
worker_response_mq: MessageQueue | None,
|
||||
peer_worker_response_mqs: list[MessageQueue | None],
|
||||
) -> "WorkerProcHandle":
|
||||
return cls(
|
||||
proc=unready_handle.proc,
|
||||
rank=unready_handle.rank,
|
||||
worker_response_mq=worker_response_mq,
|
||||
peer_worker_response_mqs=peer_worker_response_mqs,
|
||||
death_writer=unready_handle.death_writer,
|
||||
)
|
||||
|
||||
@@ -411,6 +470,38 @@ class WorkerProc:
|
||||
|
||||
READY_STR = "READY"
|
||||
|
||||
def _init_message_queues(
|
||||
self, input_shm_handle: Handle, vllm_config: VllmConfig
|
||||
) -> None:
|
||||
if vllm_config.parallel_config.nnodes_within_dp == 1:
|
||||
# Initialize MessageQueue for receiving SchedulerOutput
|
||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||
input_shm_handle, self.worker.rank
|
||||
)
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq: MessageQueue = MessageQueue(1, 1)
|
||||
self.peer_response_handles = []
|
||||
else:
|
||||
# Initialize remote MessageQueue for receiving SchedulerOutput across nodes
|
||||
self.rpc_broadcast_mq = get_inner_dp_world_group().create_mq_broadcaster(
|
||||
external_writer_handle=input_shm_handle,
|
||||
# Since there is external_writer_handle from executor proc,
|
||||
# where the ready signal from actual writer is sent out of the
|
||||
# create_mq_broadcaster method and after this setup, we make it
|
||||
# non blocking. The handshake will be triggered when
|
||||
# worker.rpc_broadcast_mq.wait_until_ready() is called
|
||||
blocking=False,
|
||||
)
|
||||
# Initializes remote message queue for sending the model output to the
|
||||
# driver worker, exposing peer_response_handles for driver worker
|
||||
# that include handles for all ranks
|
||||
self.worker_response_mq, self.peer_response_handles = (
|
||||
get_inner_dp_world_group().create_single_reader_mq_broadcasters(
|
||||
reader_rank_in_group=0
|
||||
)
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
@@ -421,13 +512,15 @@ class WorkerProc:
|
||||
shared_worker_lock: LockType,
|
||||
):
|
||||
self.rank = rank
|
||||
wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
|
||||
wrapper = WorkerWrapperBase(
|
||||
vllm_config=vllm_config, rpc_rank=local_rank, global_rank=rank
|
||||
)
|
||||
# TODO: move `init_worker` to executor level as a collective rpc call
|
||||
all_kwargs: list[dict] = [
|
||||
{} for _ in range(vllm_config.parallel_config.world_size)
|
||||
]
|
||||
is_driver_worker = rank % vllm_config.parallel_config.tensor_parallel_size == 0
|
||||
all_kwargs[rank] = {
|
||||
all_kwargs[local_rank] = {
|
||||
"vllm_config": vllm_config,
|
||||
"local_rank": local_rank,
|
||||
"rank": rank,
|
||||
@@ -438,14 +531,6 @@ class WorkerProc:
|
||||
wrapper.init_worker(all_kwargs)
|
||||
self.worker = wrapper
|
||||
|
||||
# Initialize MessageQueue for receiving SchedulerOutput
|
||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||
input_shm_handle, self.worker.rank
|
||||
)
|
||||
|
||||
# Initializes a message queue for sending the model output
|
||||
self.worker_response_mq = MessageQueue(1, 1)
|
||||
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
self.use_async_scheduling = scheduler_config.async_scheduling
|
||||
if self.use_async_scheduling:
|
||||
@@ -466,6 +551,7 @@ class WorkerProc:
|
||||
)
|
||||
|
||||
# Load model
|
||||
self._init_message_queues(input_shm_handle, vllm_config)
|
||||
self.worker.load_model()
|
||||
|
||||
# Enable environment variable cache (e.g. assume no more
|
||||
@@ -512,6 +598,27 @@ class WorkerProc:
|
||||
# death_reader in child will get EOFError
|
||||
return UnreadyWorkerProcHandle(proc, rank, reader, death_writer)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_response_handle_ready(
|
||||
handles: dict[str, Any], proc_handle: UnreadyWorkerProcHandle
|
||||
) -> WorkerProcHandle:
|
||||
response_handle = handles["handle"]
|
||||
worker_response_mq: MessageQueue | None = None
|
||||
if len(response_handle.local_reader_ranks) > 0:
|
||||
worker_response_mq = MessageQueue.create_from_handle(response_handle, 0)
|
||||
peer_response_handles = handles["peer_response_handles"]
|
||||
peer_worker_response_mqs = [
|
||||
MessageQueue.create_from_handle(handle, -1)
|
||||
if handle.remote_subscribe_addr is not None
|
||||
else None
|
||||
for handle in peer_response_handles
|
||||
]
|
||||
return WorkerProcHandle.from_unready_handle(
|
||||
proc_handle,
|
||||
worker_response_mq,
|
||||
peer_worker_response_mqs=peer_worker_response_mqs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def wait_for_ready(
|
||||
unready_proc_handles: list[UnreadyWorkerProcHandle],
|
||||
@@ -537,16 +644,10 @@ class WorkerProc:
|
||||
if response["status"] != "READY":
|
||||
raise e
|
||||
|
||||
# Extract the message queue handle.
|
||||
worker_response_mq = MessageQueue.create_from_handle(
|
||||
response["handle"], 0
|
||||
idx = unready_proc_handle.rank % len(ready_proc_handles)
|
||||
ready_proc_handles[idx] = WorkerProc.wait_for_response_handle_ready(
|
||||
response, unready_proc_handle
|
||||
)
|
||||
ready_proc_handles[unready_proc_handle.rank] = (
|
||||
WorkerProcHandle.from_unready_handle(
|
||||
unready_proc_handle, worker_response_mq
|
||||
)
|
||||
)
|
||||
|
||||
except EOFError:
|
||||
e.__suppress_context__ = True
|
||||
raise e from None
|
||||
@@ -618,12 +719,14 @@ class WorkerProc:
|
||||
{
|
||||
"status": WorkerProc.READY_STR,
|
||||
"handle": worker.worker_response_mq.export_handle(),
|
||||
"peer_response_handles": worker.peer_response_handles,
|
||||
}
|
||||
)
|
||||
|
||||
# Ensure message queues are ready. Will deadlock if re-ordered.
|
||||
# Must be kept consistent with the Executor
|
||||
worker.rpc_broadcast_mq.wait_until_ready()
|
||||
if worker.rpc_broadcast_mq is not None:
|
||||
worker.rpc_broadcast_mq.wait_until_ready()
|
||||
worker.worker_response_mq.wait_until_ready()
|
||||
ready_writer.close()
|
||||
ready_writer = None
|
||||
|
||||
@@ -189,6 +189,7 @@ class Worker(WorkerBase):
|
||||
and self.parallel_config.distributed_executor_backend
|
||||
not in ["ray", "external_launcher"]
|
||||
and self.vllm_config.parallel_config.data_parallel_backend != "ray"
|
||||
and self.vllm_config.parallel_config.nnodes_within_dp == 1
|
||||
):
|
||||
# Use local DP rank if available, otherwise use global DP rank.
|
||||
dp_local_rank = self.parallel_config.data_parallel_rank_local
|
||||
@@ -205,7 +206,14 @@ class Worker(WorkerBase):
|
||||
assert self.local_rank < torch.cuda.device_count(), (
|
||||
f"DP adjusted local rank {self.local_rank} is out of bounds. "
|
||||
)
|
||||
|
||||
visible_device_count = (
|
||||
torch.cuda.device_count() if torch.cuda.is_available() else 0
|
||||
)
|
||||
assert self.parallel_config.local_world_size <= visible_device_count, (
|
||||
f"local_world_size ({self.parallel_config.local_world_size}) must be "
|
||||
f"less than or equal to the number of visible devices "
|
||||
f"({visible_device_count})."
|
||||
)
|
||||
self.device = torch.device(f"cuda:{self.local_rank}")
|
||||
current_platform.set_device(self.device)
|
||||
|
||||
|
||||
@@ -180,6 +180,7 @@ class WorkerWrapperBase:
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
rpc_rank: int = 0,
|
||||
global_rank: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the worker wrapper with the given vllm_config and rpc_rank.
|
||||
@@ -192,6 +193,7 @@ class WorkerWrapperBase:
|
||||
group.
|
||||
"""
|
||||
self.rpc_rank = rpc_rank
|
||||
self.global_rank = self.rpc_rank if global_rank is None else global_rank
|
||||
self.worker: WorkerBase | None = None
|
||||
|
||||
# do not store this `vllm_config`, `init_worker` will set the final
|
||||
@@ -312,7 +314,7 @@ class WorkerWrapperBase:
|
||||
assert self.worker is not None
|
||||
|
||||
def initialize_from_config(self, kv_cache_configs: list[Any]) -> None:
|
||||
kv_cache_config = kv_cache_configs[self.rpc_rank]
|
||||
kv_cache_config = kv_cache_configs[self.global_rank]
|
||||
with set_current_vllm_config(self.vllm_config):
|
||||
self.worker.initialize_from_config(kv_cache_config) # type: ignore
|
||||
|
||||
|
||||
Reference in New Issue
Block a user