[V1] Multiprocessing Tensor Parallel Support for v1 (#9856)
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
committed by
GitHub
parent
bc192a2b09
commit
28b3a1c7e5
@@ -12,6 +12,7 @@ from typing_extensions import ParamSpec
|
||||
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
@@ -110,17 +111,28 @@ class CudaPlatformBase(Platform):
|
||||
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_worker.MultiStepWorker"
|
||||
elif vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.gpu_worker.Worker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.worker.Worker"
|
||||
|
||||
|
||||
# NVML utils
|
||||
@@ -249,4 +261,4 @@ try:
|
||||
if not isinstance(pynvml, _MockModule):
|
||||
CudaPlatform.log_warnings()
|
||||
except ModuleNotFoundError:
|
||||
CudaPlatform.log_warnings()
|
||||
CudaPlatform.log_warnings()
|
||||
|
||||
Reference in New Issue
Block a user