[V1] Multiprocessing Tensor Parallel Support for v1 (#9856)

Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
Tyler Michael Smith
2024-12-10 01:28:14 -05:00
committed by GitHub
parent bc192a2b09
commit 28b3a1c7e5
21 changed files with 732 additions and 145 deletions

View File

@@ -3,25 +3,19 @@ import os
from functools import partial
from typing import Any, List, Optional
import torch
from vllm.executor.distributed_gpu_executor import ( # yapf: disable
DistributedGPUExecutor, DistributedGPUExecutorAsync)
from vllm.executor.gpu_executor import create_worker
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor)
from vllm.executor.multiproc_worker_utils import (
ProcessWorkerWrapper, ResultHandler, WorkerMonitor,
set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless,
cuda_is_initialized, get_distributed_init_method,
get_open_port, make_async,
get_distributed_init_method, get_open_port, make_async,
update_environment_variables)
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
logger = init_logger(__name__)
@@ -37,30 +31,8 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
world_size = self.parallel_config.world_size
tensor_parallel_size = self.parallel_config.tensor_parallel_size
# Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)
# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and world_size > 1:
maybe_set_triton_cache_manager()
# Set multiprocessing envs that are common to V0 and V1
set_multiprocessing_worker_envs(self.parallel_config)
# Multiprocessing-based executor does not support multi-node setting.
# Since it only works for single node, we can use the loopback address
@@ -122,13 +94,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
"CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
})
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
cuda_device_count = cuda_device_count_stateless()
# Use confusing message for more common TP-only case.
assert tensor_parallel_size <= cuda_device_count, (

View File

@@ -11,8 +11,15 @@ from multiprocessing.process import BaseProcess
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
TypeVar, Union)
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import cuda_is_initialized
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
logger = init_logger(__name__)
@@ -270,3 +277,38 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
def get_mp_context():
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
return multiprocessing.get_context(mp_method)
def set_multiprocessing_worker_envs(parallel_config):
""" Set up environment variables that should be used when there are workers
in a multiprocessing environment. This should be called by the parent
process before worker processes are created"""
if (cuda_is_initialized()
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
logger.warning("CUDA was previously initialized. We must use "
"the `spawn` multiprocessing start method. Setting "
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'.")
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
# Configure thread parallelism if OMP_NUM_THREADS isn't set
#
# Helps to avoid CPU contention. The default of spawning a thread per
# core combined with multiprocessing for each GPU can have a negative
# impact on performance. The contention is amplified when running in a
# container where CPU limits can cause throttling.
default_omp_num_threads = 1
if "OMP_NUM_THREADS" not in os.environ and (
current_parallelism :=
torch.get_num_threads()) > default_omp_num_threads:
logger.warning(
"Reducing Torch parallelism from %d threads to %d to avoid "
"unnecessary CPU contention. Set OMP_NUM_THREADS in the "
"external environment to tune this value as needed.",
current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads)
# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and parallel_config.world_size > 1:
maybe_set_triton_cache_manager()