Remove hard dependencies of speculative decoding on CUDA workers (#10587)
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
This commit is contained in:
@@ -86,4 +86,10 @@ class CpuPlatform(Platform):
|
||||
parallel_config.distributed_executor_backend)
|
||||
parallel_config.distributed_executor_backend = "mp"
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
|
||||
if vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.cpu_worker.CPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker"
|
||||
|
||||
@@ -106,6 +106,8 @@ class CudaPlatformBase(Platform):
|
||||
elif vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
@@ -236,4 +238,4 @@ try:
|
||||
if not isinstance(pynvml, _MockModule):
|
||||
CudaPlatform.log_warnings()
|
||||
except ModuleNotFoundError:
|
||||
CudaPlatform.log_warnings()
|
||||
CudaPlatform.log_warnings()
|
||||
Reference in New Issue
Block a user