[Core] Add MultiprocessingGPUExecutor (#4539)

Co-authored-by: SAHIL SUNEJA <suneja@us.ibm.com>
This commit is contained in:
Nick Hill
2024-05-14 10:38:59 -07:00
committed by GitHub
parent dc72402b57
commit 676a99982f
11 changed files with 225 additions and 39 deletions

View File

@@ -348,27 +348,31 @@ class AsyncLLMEngine:
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
distributed_executor_backend = (
engine_config.parallel_config.distributed_executor_backend)
if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "cpu":
assert not engine_config.parallel_config.worker_use_ray, (
"Ray is not supported with the CPU backend.")
assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.")
from vllm.executor.cpu_executor import CPUExecutorAsync
executor_class = CPUExecutorAsync
elif engine_config.parallel_config.worker_use_ray:
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
from vllm.executor.multiproc_gpu_executor import (
MultiprocessingGPUExecutorAsync)
executor_class = MultiprocessingGPUExecutorAsync
else:
assert engine_config.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(
engine_config.parallel_config.worker_use_ray,
distributed_executor_backend == "ray",
engine_args.engine_use_ray,
**engine_config.to_dict(),
executor_class=executor_class,