[Neuron] Support inference with transformers-neuronx (#2569)

This commit is contained in:
Liangfu Chen
2024-02-28 09:34:34 -08:00
committed by GitHub
parent e46fa5d52e
commit 3b7178cfa4
18 changed files with 516 additions and 42 deletions

View File

@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
logger = init_logger(__name__)
@@ -380,13 +380,21 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
if is_neuron():
# For Neuron device support, assign TP=1 here to avoid sharding within vLLM directly.
# transformers-neuronx takes the neuron_tp_degree attribute and distributes the workload
# across multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size
else:
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray worker is not supported for Neuron backend.
if self.world_size > 1 and not is_neuron():
self.worker_use_ray = True
self._verify_args()
@@ -465,8 +473,29 @@ class SchedulerConfig:
class DeviceConfig:
def __init__(self, device: str = "cuda") -> None:
self.device = torch.device(device)
def __init__(self, device: str = "auto") -> None:
# Resolve `device` into two attributes:
#   self.device_type - the backend name ("cuda", "neuron", or whatever
#                      explicit string the caller passed in)
#   self.device      - the torch.device derived from that backend name
# Raises RuntimeError when device == "auto" and neither CUDA nor Neuron
# is detected.
if device == "auto":
# Automated device type detection
# CUDA is probed first, so it takes precedence when both are available.
if torch.cuda.is_available():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron"
else:
raise RuntimeError("No supported device detected.")
else:
# Device type is assigned explicitly
# The string is stored as given; it is not validated at this point.
self.device_type = device
# Some device types require processing inputs on CPU
# For "neuron", self.device is the CPU device while self.device_type
# still records "neuron".
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
else:
# Set device with device type
self.device = torch.device(self.device_type)
@property
def is_neuron(self):
# True when the resolved device type is "neuron". This checks
# self.device_type, not self.device (which is CPU for Neuron).
return self.device_type == "neuron"
@dataclass