[Neuron] [Bugfix] Fix neuron startup (#9374)
Co-authored-by: Jerzy Zagorski <jzagorsk@amazon.com>
This commit is contained in:
@@ -58,6 +58,13 @@ try:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
is_neuron = False
|
||||
try:
|
||||
import transformers_neuronx # noqa: F401
|
||||
is_neuron = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if is_tpu:
|
||||
# people might install pytorch built with cuda but run on tpu
|
||||
# so we need to check tpu first
|
||||
@@ -75,6 +82,9 @@ elif is_xpu:
|
||||
elif is_cpu:
|
||||
from .cpu import CpuPlatform
|
||||
current_platform = CpuPlatform()
|
||||
elif is_neuron:
|
||||
from .neuron import NeuronPlatform
|
||||
current_platform = NeuronPlatform()
|
||||
else:
|
||||
current_platform = UnspecifiedPlatform()
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ class PlatformEnum(enum.Enum):
|
||||
TPU = enum.auto()
|
||||
XPU = enum.auto()
|
||||
CPU = enum.auto()
|
||||
NEURON = enum.auto()
|
||||
UNSPECIFIED = enum.auto()
|
||||
|
||||
|
||||
@@ -48,6 +49,9 @@ class Platform:
|
||||
def is_cpu(self) -> bool:
|
||||
return self._enum == PlatformEnum.CPU
|
||||
|
||||
def is_neuron(self) -> bool:
|
||||
return self._enum == PlatformEnum.NEURON
|
||||
|
||||
def is_cuda_alike(self) -> bool:
|
||||
"""Stateless version of :func:`torch.cuda.is_available`."""
|
||||
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
|
||||
|
||||
9
vllm/platforms/neuron.py
Normal file
9
vllm/platforms/neuron.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
|
||||
class NeuronPlatform(Platform):
|
||||
_enum = PlatformEnum.NEURON
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
return "neuron"
|
||||
Reference in New Issue
Block a user