[Hardware][Intel GPU] Add Intel GPU pipeline parallel support. (#7810)
vllm/executor/multiproc_gpu_executor.py

@@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
     uses_ray: bool = False
 
     def _init_executor(self) -> None:
+        self._check_executor_parameters()
+
         # Create the parallel GPU workers.
         world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
 
-        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
-        if "CUDA_VISIBLE_DEVICES" not in os.environ:
-            update_environment_variables({
-                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
-            })
-
         # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
         os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
 
@@ -68,16 +64,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
         if world_size > 1:
             maybe_set_triton_cache_manager()
 
-        cuda_device_count = cuda_device_count_stateless()
-        # Check tensor_parallel_size first for the more common TP-only case.
-        assert tensor_parallel_size <= cuda_device_count, (
-            f"please set tensor_parallel_size ({tensor_parallel_size}) "
-            f"to less than max local gpu count ({cuda_device_count})")
-
-        assert world_size <= cuda_device_count, (
-            f"please ensure that world_size ({world_size}) "
-            f"is less than max local gpu count ({cuda_device_count})")
-
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
         # 127.0.0.1 for communication.
@@ -139,6 +125,26 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
                 max_concurrent_workers=self.parallel_config.
                 max_parallel_loading_workers)
 
+    def _check_executor_parameters(self):
+        world_size = self.parallel_config.world_size
+        tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+        # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers
+        if "CUDA_VISIBLE_DEVICES" not in os.environ:
+            update_environment_variables({
+                "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size))))
+            })
+
+        cuda_device_count = cuda_device_count_stateless()
+        # Check tensor_parallel_size first for the more common TP-only case.
+        assert tensor_parallel_size <= cuda_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to less than max local gpu count ({cuda_device_count})")
+
+        assert world_size <= cuda_device_count, (
+            f"please ensure that world_size ({world_size}) "
+            f"is less than max local gpu count ({cuda_device_count})")
+
     def shutdown(self):
         if (worker_monitor := getattr(self, "worker_monitor",
                                       None)) is not None:
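Taken together, these hunks are a small template-method refactor: _init_executor keeps the shared startup sequence, while the CUDA-specific environment setup and device-count checks move into an overridable _check_executor_parameters hook, which the new XPU executor below replaces with its own validation. A minimal sketch of the pattern, with illustrative class names that are not from this commit:

class ExecutorBase:
    """Shared startup sequence; platform validation lives in a hook."""

    def _init_executor(self) -> None:
        # Validate the platform setup before any workers are created.
        self._check_executor_parameters()
        self._create_workers()

    def _check_executor_parameters(self) -> None:
        raise NotImplementedError

    def _create_workers(self) -> None:
        print("creating workers")


class CudaLikeExecutor(ExecutorBase):
    def _check_executor_parameters(self) -> None:
        # e.g. CUDA_VISIBLE_DEVICES setup and device-count asserts,
        # as in the hunks above.
        pass


class XpuLikeExecutor(ExecutorBase):
    def _check_executor_parameters(self) -> None:
        # e.g. require the "spawn" start method, as in the new file below.
        pass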
vllm/executor/multiproc_xpu_executor.py (new file, +26 lines)
@@ -0,0 +1,26 @@
+import vllm.envs as envs
+from vllm.executor.multiproc_gpu_executor import (
+    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
+from vllm.executor.xpu_executor import XPUExecutor
+from vllm.logger import init_logger
+from vllm.utils import make_async
+
+logger = init_logger(__name__)
+
+
+class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor):
+    """Python multiprocessing-based multi-XPU executor"""
+
+    def _check_executor_parameters(self):
+        mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
+        if mp_method != "spawn":
+            raise RuntimeError(
+                "XPU multiprocess executor only supports spawn as mp method")
+
+
+class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor,
+                                      MultiprocessingGPUExecutorAsync):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
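The async variant exposes the driver's blocking execute_model as an awaitable via make_async from vllm.utils. A minimal sketch of what such a helper does, assuming a thread-pool offload through the running event loop (vLLM's actual implementation may differ in detail):

import asyncio
from functools import partial


def make_async(func):
    """Wrap a blocking callable so that awaiting the wrapper runs the
    call in the event loop's default thread-pool executor."""

    async def _wrapper(*args, **kwargs):
        loop = asyncio.get_running_loop()
        # Bind the arguments now; run_in_executor takes a no-arg callable.
        return await loop.run_in_executor(None, partial(func, *args, **kwargs))

    return _wrapper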
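For context, a hedged usage sketch of the path this commit enables: pipeline parallelism on Intel XPU through the multiprocessing backend. The model name is a placeholder and the keyword spellings are assumptions about the engine arguments, not something this commit defines; the one hard requirement shown is the spawn start method that _check_executor_parameters enforces above.

import os

# Must be set before workers are created; the XPU check above raises
# RuntimeError for any other multiprocessing start method.
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

from vllm import LLM  # assumes an XPU-enabled vLLM build

llm = LLM(
    model="facebook/opt-125m",          # placeholder model
    device="xpu",                       # Intel GPU backend (assumed flag)
    tensor_parallel_size=2,
    pipeline_parallel_size=2,           # the feature this PR adds
    distributed_executor_backend="mp",  # multiprocessing executor
)
print(llm.generate("Hello, my name is")[0].outputs[0].text)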