[Speculative decoding][Re-take] Enable TP>1 speculative decoding (#4840)

Co-authored-by: Cade Daniel <edacih@gmail.com>
Co-authored-by: Cade Daniel <cade@anyscale.com>
This commit is contained in:
Cody Yu
2024-05-16 00:53:51 -07:00
committed by GitHub
parent 30e754390c
commit 973617ae02
12 changed files with 295 additions and 180 deletions

View File

@@ -28,9 +28,6 @@ USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor):
def _init_executor(self) -> None:
assert (not self.speculative_config
), "Speculative decoding not yet supported for RayGPU backend."
assert self.parallel_config.distributed_executor_backend == "ray"
placement_group = self.parallel_config.placement_group
@@ -90,14 +87,22 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_id,
)
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
worker = ray.remote(
num_cpus=0,
num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs,
)(RayWorkerWrapper).remote(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
@@ -107,8 +112,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name="vllm.worker.worker",
worker_class_name="Worker",
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else: