[Misc] Auto fallback to float16 for pre-Ampere GPUs when a bfloat16 config is detected (#17265)
Signed-off-by: Isotr0py <2037008807@qq.com>
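In short: with `dtype="auto"`, a checkpoint whose config requests bfloat16 used to abort on hardware without bfloat16 support (e.g. pre-Ampere NVIDIA GPUs, compute capability < 8.0), asking the user to pass `--dtype float16` manually. After this commit, vLLM instead falls back to the first entry of the platform's `supported_dtypes` list and logs a warning. A minimal sketch of the resulting rule; `resolve_auto_dtype` is an illustrative helper, not a vLLM function:

import torch

def resolve_auto_dtype(config_dtype: torch.dtype,
                       supported_dtypes: list[torch.dtype]) -> torch.dtype:
    # Keep the checkpoint dtype when the platform supports it; otherwise
    # fall back to the platform's preferred (first-listed) dtype.
    if config_dtype in supported_dtypes:
        return config_dtype
    return supported_dtypes[0]

# Turing (sm_75) supports fp16 but not bf16, so a bf16 config falls back:
assert resolve_auto_dtype(torch.bfloat16,
                          [torch.float16, torch.float32]) == torch.float16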
diff --git a/vllm/config.py b/vllm/config.py
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -7,7 +7,6 @@ import hashlib
 import inspect
 import json
 import re
-import sys
 import textwrap
 import warnings
 from collections import Counter
@@ -34,7 +33,7 @@ from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                      QuantizationMethods,
                                                      get_quantization_config)
 from vllm.model_executor.models import ModelRegistry
-from vllm.platforms import CpuArchEnum, current_platform
+from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
@@ -2988,6 +2987,7 @@ def _get_and_verify_dtype(
     if isinstance(dtype, str):
         dtype = dtype.lower()
         if dtype == "auto":
+            # Set default dtype from model config
             if config_dtype == torch.float32:
                 # Following common practice, we use float16 for float32 models
                 torch_dtype = torch.float16
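For context, the rule this new comment documents is unchanged: under `dtype="auto"`, float32 checkpoints are downcast to float16 and every other config dtype is kept. As a one-line sketch:

import torch

config_dtype = torch.float32  # e.g. an fp32 checkpoint
torch_dtype = torch.float16 if config_dtype == torch.float32 else config_dtype
assert torch_dtype == torch.float16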
@@ -2995,37 +2995,33 @@ def _get_and_verify_dtype(
                 torch_dtype = config_dtype
 
             if config.model_type == "plamo2":
-                logger.info(
+                logger.warning(
                     "For PLaMo2, we cast models to bfloat16 instead of using "
                     "float16 by default. This is because float16 does not work."
                 )
                 torch_dtype = torch.bfloat16
 
+            # Deal with torch dtype fallback for device compatibility.
             from vllm.platforms import current_platform
-            if (current_platform.is_cpu()
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.POWERPC
-                    and (config_dtype == torch.float16
-                         or config_dtype == torch.float32)):
-                logger.info(
-                    "For POWERPC, we cast models to bfloat16 instead of "
-                    "using float16 by default. Float16 is not currently "
-                    "supported for POWERPC.")
-                torch_dtype = torch.bfloat16
-
-            # TODO: change this condition to check if the platform support bf16
-            # instead of checking the OS. For instance M2 shall supports bf16
-            # already. But we need to modify `cpu_extension.cmake` to activate
-            # the feature in the build.
-            if (current_platform.is_cpu() and sys.platform.startswith("darwin")
-                    and current_platform.get_cpu_architecture()
-                    == CpuArchEnum.ARM and config_dtype == torch.bfloat16):
-                logger.info("For macOS with Apple Silicon, currently bfloat16 "
-                            "is not supported. Setting dtype to float16.")
-                torch_dtype = torch.float16
+            if torch_dtype not in current_platform.supported_dtypes:
+                device_name = current_platform.get_device_name()
+                if ((capability := current_platform.get_device_capability())
+                        is None):
+                    compute_str = ""
+                else:
+                    version_str = capability.as_version_str()
+                    compute_str = f" (with compute capability {version_str})"
+                fallback_dtype = current_platform.supported_dtypes[0]
+                logger.warning(
+                    "Your %s device%s doesn't support %s. " \
+                    "Falling back to %s for compatibility.",
+                    device_name, compute_str, torch_dtype, fallback_dtype
+                )
+                torch_dtype = fallback_dtype
 
-            if current_platform.is_hpu() and config_dtype == torch.float16:
-                logger.info(
+            if current_platform.is_hpu() and torch_dtype == torch.float16:
+                logger.warning(
                     "For HPU, we cast models to bfloat16 instead of "
                     "using float16 by default. Please specify `dtype` if you "
                     "want to use float16.")
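Concretely, loading a bfloat16 checkpoint on a Turing T4 (compute capability 7.5) now takes the new branch instead of erroring out. The sketch below hard-codes hypothetical device values in place of `current_platform` to show the message that gets logged:

import torch

# Hypothetical stand-ins for what current_platform reports on a T4:
device_name = "Tesla T4"
capability = (7, 5)                       # pre-Ampere, so no bfloat16
supported_dtypes = [torch.float16, torch.float32]
torch_dtype = torch.bfloat16              # requested by the model config

if torch_dtype not in supported_dtypes:
    compute_str = f" (with compute capability {capability[0]}.{capability[1]})"
    fallback_dtype = supported_dtypes[0]
    print(f"Your {device_name} device{compute_str} doesn't support "
          f"{torch_dtype}. Falling back to {fallback_dtype} for compatibility.")
    torch_dtype = fallback_dtype

assert torch_dtype == torch.float16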
diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py
--- a/vllm/platforms/cpu.py
+++ b/vllm/platforms/cpu.py
@@ -10,7 +10,7 @@ import torch
 
 from vllm.logger import init_logger
 
-from .interface import Platform, PlatformEnum, _Backend
+from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend
 
 logger = init_logger(__name__)
 
@@ -26,6 +26,20 @@ class CpuPlatform(Platform):
     device_type: str = "cpu"
     dispatch_key: str = "CPU"
 
+    @property
+    def supported_dtypes(self) -> list[torch.dtype]:
+        if self.get_cpu_architecture() == CpuArchEnum.POWERPC:
+            return [torch.bfloat16, torch.float32]
+        elif sys.platform.startswith(
+                "darwin") and self.get_cpu_architecture() == CpuArchEnum.ARM:
+            # TODO: change this condition to check if the platform support bf16
+            # instead of checking the OS. For instance M2 shall supports bf16
+            # already. But we need to modify `cpu_extension.cmake` to activate
+            # the feature in the build.
+            return [torch.bfloat16, torch.float32]
+        # x86/aarch64 CPU has supported both bf16 and fp16 natively.
+        return [torch.bfloat16, torch.float16, torch.float32]
+
     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
         return "cpu"
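With the property in place, callers can read the host's dtype matrix straight off `current_platform`; a quick usage sketch (requires vLLM installed, and the output depends on the machine):

import torch
from vllm.platforms import current_platform

# On x86/aarch64 Linux this lists bf16, fp16 and fp32; on POWERPC and
# Apple Silicon, fp16 is absent and bf16 (the first entry) is the fallback.
print(current_platform.supported_dtypes)
assert torch.float32 in current_platform.supported_dtypes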
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -73,6 +73,19 @@ class CudaPlatformBase(Platform):
     ray_device_key: str = "GPU"
     device_control_env_var: str = "CUDA_VISIBLE_DEVICES"
 
+    @property
+    def supported_dtypes(self) -> List[torch.dtype]:
+        if self.has_device_capability(80):
+            # Ampere and Hopper or later NVIDIA GPUs.
+            return [torch.bfloat16, torch.float16, torch.float32]
+        elif (not self.has_device_capability(80)
+              ) and self.has_device_capability(60):
+            # Pascal, Volta and Turing NVIDIA GPUs, BF16 is not supported
+            return [torch.float16, torch.float32]
+        # Kepler and Maxwell NVIDIA GPUs, only FP32 is supported,
+        # though vLLM doesn't support these GPUs.
+        return [torch.float32]
+
     @classmethod
     def get_device_capability(cls,
                               device_id: int = 0
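The CUDA branches boil down to two compute-capability thresholds. A self-contained sketch of the same mapping, with the capability encoded as `major * 10 + minor` the way `has_device_capability(80)` reads it:

import torch

def cuda_supported_dtypes(major: int, minor: int) -> list[torch.dtype]:
    capability = major * 10 + minor
    if capability >= 80:   # Ampere, Hopper and later: full bf16 support
        return [torch.bfloat16, torch.float16, torch.float32]
    if capability >= 60:   # Pascal, Volta, Turing: fp16 but no bf16
        return [torch.float16, torch.float32]
    return [torch.float32]  # Kepler/Maxwell: fp32 only (and unsupported)

assert cuda_supported_dtypes(7, 5) == [torch.float16, torch.float32]  # T4
assert torch.bfloat16 in cuda_supported_dtypes(8, 0)                  # A100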
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -122,6 +122,14 @@ class Platform:
 
     additional_env_vars: list[str] = []
 
+    @property
+    def supported_dtypes(self) -> list[torch.dtype]:
+        """Returns the supported dtypes for the current platform."""
+        # Be careful with the order of the dtypes. The first dtype will
+        # be used as the default dtype fallback for the current platform,
+        # when encountering unsupported dtypes in "auto" dtype.
+        return [torch.bfloat16, torch.float16, torch.float32]
+
     def is_cuda(self) -> bool:
         return self._enum == PlatformEnum.CUDA
 
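Because the first element doubles as the fallback, a platform controls its fallback preference purely by list order when overriding this property. A self-contained sketch with a hypothetical platform class (not part of vLLM):

import torch

class ToyPlatform:
    # Mirrors Platform.supported_dtypes: the first entry is the fallback.
    @property
    def supported_dtypes(self) -> list[torch.dtype]:
        return [torch.float16, torch.float32]  # prefers fp16, no bf16

platform = ToyPlatform()
requested = torch.bfloat16
if requested not in platform.supported_dtypes:
    requested = platform.supported_dtypes[0]  # falls back to float16
assert requested == torch.float16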