diff --git a/vllm/config.py b/vllm/config.py index d986ab6b0..7a9bc8a4f 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2285,7 +2285,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"] class DeviceConfig: """Configuration for the device to use for vLLM execution.""" - device: SkipValidation[Union[Device, torch.device]] = "auto" + device: SkipValidation[Optional[Union[Device, torch.device]]] = "auto" """Device type for vLLM execution. This parameter is deprecated and will be removed in a future release. @@ -2327,7 +2327,10 @@ class DeviceConfig: "to turn on verbose logging to help debug the issue.") else: # Device type is assigned explicitly - self.device_type = self.device + if isinstance(self.device, str): + self.device_type = self.device + elif isinstance(self.device, torch.device): + self.device_type = self.device.type # Some device types require processing inputs on CPU if self.device_type in ["neuron"]: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f599d7a3b..a0e099a19 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1018,7 +1018,8 @@ class EngineArgs: from vllm.platforms import current_platform current_platform.pre_register_and_update() - device_config = DeviceConfig(device=current_platform.device_type) + device_config = DeviceConfig( + device=cast(Device, current_platform.device_type)) model_config = self.create_model_config() # * If VLLM_USE_V1 is unset, we enable V1 for "supported features"