[Neuron] [Bugfix] Fix neuron startup (#9374)

Co-authored-by: Jerzy Zagorski <jzagorsk@amazon.com>
This commit is contained in:
xendo
2024-10-22 14:51:41 +02:00
committed by GitHub
parent a48e3ec052
commit 9dbcce84a7
7 changed files with 37 additions and 18 deletions

View File

@@ -58,6 +58,13 @@ try:
except Exception:
pass
is_neuron = False
try:
import transformers_neuronx # noqa: F401
is_neuron = True
except ImportError:
pass
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
@@ -75,6 +82,9 @@ elif is_xpu:
elif is_cpu:
from .cpu import CpuPlatform
current_platform = CpuPlatform()
elif is_neuron:
from .neuron import NeuronPlatform
current_platform = NeuronPlatform()
else:
current_platform = UnspecifiedPlatform()

View File

@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
TPU = enum.auto()
XPU = enum.auto()
CPU = enum.auto()
NEURON = enum.auto()
UNSPECIFIED = enum.auto()
@@ -48,6 +49,9 @@ class Platform:
def is_cpu(self) -> bool:
return self._enum == PlatformEnum.CPU
def is_neuron(self) -> bool:
return self._enum == PlatformEnum.NEURON
def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

9
vllm/platforms/neuron.py Normal file
View File

@@ -0,0 +1,9 @@
from .interface import Platform, PlatformEnum
class NeuronPlatform(Platform):
_enum = PlatformEnum.NEURON
@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
return "neuron"