[Speculative decoding 2/9] Multi-step worker for draft model (#2424)
This commit is contained in:
@@ -18,7 +18,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupOutput, SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
|
||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
|
||||
|
||||
if ray:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
@@ -132,7 +132,8 @@ class LLMEngine:
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = Worker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
@@ -207,7 +208,8 @@ class LLMEngine:
|
||||
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
|
||||
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
|
||||
|
||||
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port())
|
||||
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
|
||||
@@ -65,10 +65,9 @@ def initialize_cluster(
|
||||
the default Ray cluster address.
|
||||
|
||||
Returns:
|
||||
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||
`distributed_init_method` is the address for initializing the
|
||||
distributed backend. `placement_group` includes the specification
|
||||
of the resources for each distributed worker.
|
||||
An optional `PlacementGroup`. It includes the specification
|
||||
of the resources for each distributed worker. None if Ray is
|
||||
not used.
|
||||
"""
|
||||
if parallel_config.worker_use_ray or engine_use_ray:
|
||||
if ray is None:
|
||||
|
||||
Reference in New Issue
Block a user