[Hardware][Neuron] Refactor neuron support (#3471)
This commit is contained in:
@@ -3,7 +3,6 @@ import copy
|
||||
from collections import defaultdict
|
||||
import os
|
||||
import pickle
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
@@ -25,12 +24,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# A map between the device type (in device config) to its worker module.
|
||||
DEVICE_TO_WORKER_MODULE_MAP = {
|
||||
"cuda": "vllm.worker.worker",
|
||||
"neuron": "vllm.worker.neuron_worker",
|
||||
}
|
||||
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
@@ -73,13 +66,6 @@ class RayGPUExecutor(ExecutorBase):
|
||||
if USE_RAY_COMPILED_DAG:
|
||||
self.forward_dag = self._compiled_ray_dag()
|
||||
|
||||
def _dispatch_worker(self):
|
||||
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
|
||||
self.device_config.device_type]
|
||||
imported_worker = importlib.import_module(worker_module)
|
||||
Worker = imported_worker.Worker
|
||||
return Worker
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
@@ -155,7 +141,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
Worker = self._dispatch_worker()
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
@@ -201,7 +187,7 @@ class RayGPUExecutor(ExecutorBase):
|
||||
|
||||
# FIXME(woosuk): We are not properly initializing cupy NCCL when
|
||||
# we have multiple nodes.
|
||||
self._run_workers("init_model",
|
||||
self._run_workers("init_device",
|
||||
cupy_port=get_open_port()
|
||||
if not model_config.enforce_eager else None)
|
||||
self._run_workers(
|
||||
|
||||
Reference in New Issue
Block a user