41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
import asyncio
|
|
from typing import List, Optional
|
|
|
|
import ray
|
|
|
|
import vllm.envs as envs
|
|
from vllm.executor.ray_gpu_executor import RayGPUExecutor, RayGPUExecutorAsync
|
|
from vllm.executor.xpu_executor import XPUExecutor
|
|
from vllm.logger import init_logger
|
|
from vllm.utils import make_async
|
|
|
|
# Module-level logger, obtained from vLLM's init_logger under this module's name.
logger = init_logger(__name__)
|
|
|
|
|
|
class RayXPUExecutor(RayGPUExecutor, XPUExecutor):
    """Executor that inherits from both ``RayGPUExecutor`` and ``XPUExecutor``.

    The only behavior defined here is the payload of environment variables
    that is built for the driver and the workers.
    """

    def _get_env_vars_to_be_updated(self):
        """Build the environment-variable update arguments, one per worker.

        Every non-``None`` worker (the dummy driver worker first, then
        ``self.workers``) is asked for its node and GPU IDs; the answers are
        only used to size the result and to check that each one is a pair.

        Returns:
            A list with one 1-tuple per queried worker. Each tuple holds a
            dict mapping ``"VLLM_TRACE_FUNCTION"`` to
            ``str(envs.VLLM_TRACE_FUNCTION)``.
        """
        # driver_dummy_worker can be None when using ray spmd worker.
        active_workers = [
            worker for worker in [self.driver_dummy_worker] + self.workers
            if worker is not None
        ]

        # Get the set of GPU IDs used on each node. All remote calls are
        # submitted before blocking, and a single ray.get resolves them, so
        # the workers answer concurrently instead of one after another.
        # ray.get on a list returns the results in submission order.
        pending_refs = [
            worker.get_node_and_gpu_ids.remote()  # type: ignore
            for worker in active_workers
        ]
        worker_node_and_gpu_ids = ray.get(pending_refs)

        # Set environment variables for the driver and workers.
        all_args_to_update_environment_variables = [({
            "VLLM_TRACE_FUNCTION":
            str(envs.VLLM_TRACE_FUNCTION),
        }, ) for (_, _) in worker_node_and_gpu_ids]

        return all_args_to_update_environment_variables
|
|
|
|
|
|
class RayXPUExecutorAsync(RayXPUExecutor, RayGPUExecutorAsync):
    """Async variant that inherits from ``RayXPUExecutor`` and
    ``RayGPUExecutorAsync``."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent initializers, then set up
        the async driver entry point and the ``pp_locks`` slot."""
        super().__init__(*args, **kwargs)

        # Wrap the driver worker's execute_method with make_async.
        driver_execute = self.driver_worker.execute_method
        self.driver_exec_method = make_async(driver_execute)

        # Starts out unset; nothing in this class assigns the locks.
        self.pp_locks: Optional[List[asyncio.Lock]] = None
|