[Neuron] Support inference with transformers-neuronx (#2569)
@@ -8,7 +8,7 @@ from transformers import PretrainedConfig
 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
+from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
 
 logger = init_logger(__name__)
 
@@ -380,13 +380,21 @@ class ParallelConfig:
         disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, assign TP=1 here to avoid sharding
+            # within vLLM directly. Transformers-neuronx reads the
+            # neuron_tp_degree attribute and distributes work across NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
 
-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray worker is not supported for Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()
 
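Note: with this hunk, a Neuron run keeps vLLM's own world size at pipeline_parallel_size * 1, so no Ray workers are spawned; the requested TP degree is handed off to transformers-neuronx instead. A minimal standalone sketch of that remapping logic, for illustration only (resolve_parallel is a hypothetical helper, not part of vLLM):

def resolve_parallel(pipeline_parallel_size: int,
                     tensor_parallel_size: int,
                     on_neuron: bool) -> dict:
    if on_neuron:
        # vLLM does not shard the model itself; transformers-neuronx
        # distributes the work across NeuronCores via the TP degree.
        vllm_tp, neuron_tp_degree = 1, tensor_parallel_size
    else:
        vllm_tp, neuron_tp_degree = tensor_parallel_size, None
    world_size = pipeline_parallel_size * vllm_tp
    # Ray workers are only used for multi-device, non-Neuron runs.
    worker_use_ray = world_size > 1 and not on_neuron
    return {"tensor_parallel_size": vllm_tp,
            "neuron_tp_degree": neuron_tp_degree,
            "world_size": world_size,
            "worker_use_ray": worker_use_ray}

# Requesting TP=2 on Neuron keeps vLLM's world size at 1:
assert resolve_parallel(1, 2, on_neuron=True)["world_size"] == 1
assert resolve_parallel(1, 2, on_neuron=False)["world_size"] == 2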
@@ -465,8 +473,29 @@ class SchedulerConfig:
 
 class DeviceConfig:
 
-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"
 
 
 @dataclass
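For reference, a hedged usage sketch of the new DeviceConfig. It assumes a vLLM build that includes this change, that DeviceConfig is importable from vllm.config, and a host with Neuron devices but no CUDA; the printed values follow from the branch taken above:

from vllm.config import DeviceConfig

cfg = DeviceConfig("auto")   # picks "cuda" when available, else "neuron"
print(cfg.device_type)       # on a Neuron-only host: "neuron"
print(cfg.device)            # device("cpu") -- Neuron inputs are prepared on CPU
print(cfg.is_neuron)         # True

Passing an explicit string (e.g. DeviceConfig("cuda")) skips detection and assigns the device type directly.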