Ray placement group support (#397)
This commit is contained in:
@@ -1,11 +1,12 @@
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional, TYPE_CHECKING
|
||||
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig)
|
||||
from vllm.core.scheduler import Scheduler
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
|
||||
from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -13,7 +14,13 @@ from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
if ray:
|
||||
from ray.air.util.torch_dist import init_torch_dist_process_group
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -54,7 +61,7 @@ class LLMEngine:
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
distributed_init_method: str,
|
||||
stage_devices: List[List[DeviceID]],
|
||||
placement_group: Optional["PlacementGroup"],
|
||||
log_stats: bool,
|
||||
) -> None:
|
||||
logger.info(
|
||||
@@ -85,31 +92,73 @@ class LLMEngine:
|
||||
self.seq_counter = Counter()
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
self.workers: List[Worker] = []
|
||||
assert len(stage_devices) == 1, "Only support one stage for now."
|
||||
for rank, node_resource, _ in stage_devices[0]:
|
||||
worker_cls = Worker
|
||||
if self.parallel_config.worker_use_ray:
|
||||
worker_cls = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=1,
|
||||
resources={node_resource: 1e-3},
|
||||
)(worker_cls).remote
|
||||
if self.parallel_config.worker_use_ray:
|
||||
self._init_workers_ray(placement_group)
|
||||
else:
|
||||
self._init_workers(distributed_init_method)
|
||||
|
||||
worker = worker_cls(
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
# Profile the memory usage and initialize the cache.
|
||||
self._init_cache()
|
||||
|
||||
# Create the scheduler.
|
||||
self.scheduler = Scheduler(scheduler_config, cache_config, log_stats)
|
||||
|
||||
def _init_workers(self, distributed_init_method: str):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||
|
||||
assert self.parallel_config.world_size == 1, (
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
worker = Worker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
0,
|
||||
distributed_init_method,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
self._run_workers(
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
)
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup"):
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
for bundle in placement_group.bundle_specs:
|
||||
if not bundle.get("GPU", 0):
|
||||
continue
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=1,
|
||||
scheduling_strategy=PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True),
|
||||
)(RayWorker).remote()
|
||||
self.workers.append(worker)
|
||||
|
||||
# Initialize torch distributed process group for the workers.
|
||||
init_torch_dist_process_group(self.workers, backend="nccl")
|
||||
self._run_workers("init_worker",
|
||||
get_all_outputs=True,
|
||||
worker_init_fn=lambda: Worker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
None,
|
||||
None,
|
||||
))
|
||||
self._run_workers(
|
||||
"init_model",
|
||||
get_all_outputs=True,
|
||||
)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
self.model_config.verify_with_parallel_config(self.parallel_config)
|
||||
self.cache_config.verify_with_parallel_config(self.parallel_config)
|
||||
@@ -152,11 +201,12 @@ class LLMEngine:
|
||||
engine_configs = engine_args.create_engine_configs()
|
||||
parallel_config = engine_configs[2]
|
||||
# Initialize the cluster.
|
||||
distributed_init_method, devices = initialize_cluster(parallel_config)
|
||||
distributed_init_method, placement_group = initialize_cluster(
|
||||
parallel_config)
|
||||
# Create the LLM engine.
|
||||
engine = cls(*engine_configs,
|
||||
distributed_init_method,
|
||||
devices,
|
||||
placement_group,
|
||||
log_stats=not engine_args.disable_log_stats)
|
||||
return engine
|
||||
|
||||
@@ -326,9 +376,10 @@ class LLMEngine:
|
||||
"""Runs the given method on all workers."""
|
||||
all_outputs = []
|
||||
for worker in self.workers:
|
||||
executor = getattr(worker, method)
|
||||
if self.parallel_config.worker_use_ray:
|
||||
executor = executor.remote
|
||||
executor = partial(worker.execute_method.remote, method)
|
||||
else:
|
||||
executor = getattr(worker, method)
|
||||
|
||||
output = executor(*args, **kwargs)
|
||||
all_outputs.append(output)
|
||||
|
||||
Reference in New Issue
Block a user