[Bugfix] neuron: enable tensor parallelism (#7562)

Signed-off-by: omrishiv <327609+omrishiv@users.noreply.github.com>
This commit is contained in:
omrishiv
2024-08-26 15:13:13 -07:00
committed by GitHub
parent 05826c887b
commit 760e9f71a8
3 changed files with 44 additions and 11 deletions

View File

@@ -4,7 +4,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import make_async
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
@@ -24,14 +25,17 @@ class NeuronExecutor(ExecutorBase):
def _init_worker(self):
from vllm.worker.neuron_worker import NeuronWorker
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = NeuronWorker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
self.cache_config,
)
model_config=self.model_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
device_config=self.device_config,
cache_config=self.cache_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method)
self.driver_worker.init_device()
self.driver_worker.load_model()