[Core] Rework dtype resolution (#18751)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

Authored by Cyrus Leung on 2025-06-01 11:04:23 +08:00, committed by GitHub
parent 1bc86a3da1
commit 6aa8f9a4e7

13 changed files with 314 additions and 119 deletions


@@ -24,6 +24,7 @@ import torch
 from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
                       model_validator)
 from pydantic.dataclasses import dataclass
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
 from typing_extensions import deprecated, runtime_checkable
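
Note: safetensors.torch._TYPES maps the dtype strings stored in safetensors
file headers (e.g. "F16", "BF16") to torch.dtype objects, which is what lets
the new metadata-based fallback later in this diff translate weight-file
headers into torch dtypes. A quick illustration; since _TYPES is a private
mapping, its exact contents may vary across safetensors releases:

    from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

    # Safetensors headers encode dtypes as short strings.
    print(_SAFETENSORS_TO_TORCH_DTYPE.get("F16"))   # expected: torch.float16
    print(_SAFETENSORS_TO_TORCH_DTYPE.get("BF16"))  # expected: torch.bfloat16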
@@ -42,15 +43,16 @@ from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
     get_sentence_transformer_tokenizer_config, is_encoder_decoder,
-    try_get_generation_config, uses_mrope)
+    try_get_generation_config, try_get_safetensors_metadata, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                         MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                         POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
-                        LayerBlockType, cuda_device_count_stateless,
-                        get_cpu_memory, get_open_port, is_torch_equal_or_newer,
-                        random_uuid, resolve_obj_by_qualname)
+                        LayerBlockType, common_broadcastable_dtype,
+                        cuda_device_count_stateless, get_cpu_memory,
+                        get_open_port, is_torch_equal_or_newer, random_uuid,
+                        resolve_obj_by_qualname)
 
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
@@ -540,7 +542,24 @@ class ModelConfig:
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
             self.model, hf_token=self.hf_token, revision=self.revision)
-        self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
+        supported_tasks, task = self._resolve_task(self.task)
+        self.supported_tasks = supported_tasks
+        self.task = task
+
+        if self.task in ("draft", "generate"):
+            self.truncation_side = "left"
+        else:
+            self.truncation_side = "right"
+
+        self.pooler_config = self._init_pooler_config()
+
+        self.dtype = _get_and_verify_dtype(
+            self.model,
+            self.hf_config,
+            self.dtype,
+            is_pooling_model=self.runner_type == "pooling",
+            revision=self.revision,
+        )
 
         # Workaround for Gemma 2 which uses interleaved sliding window
         # attention, but it's not specified in its config. TODO: remove this
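
Note: the task and pooler resolution moves ahead of dtype resolution here
(it is deleted from its old location in the next hunk) because the reworked
_get_and_verify_dtype needs to know whether the model runs as a pooling
model, and runner_type is derived from the resolved task. A minimal sketch
of that dependency, assuming a hypothetical config object cfg with the
fields used above:

    # Hypothetical, abbreviated flow; the real __post_init__ does much more.
    supported_tasks, task = cfg._resolve_task(cfg.task)  # must happen first
    cfg.task = task
    is_pooling = cfg.runner_type == "pooling"            # derived from task

    cfg.dtype = _get_and_verify_dtype(                   # can now consume it
        cfg.model, cfg.hf_config, cfg.dtype,
        is_pooling_model=is_pooling, revision=cfg.revision)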
@@ -597,16 +616,6 @@ class ModelConfig:
             raise ValueError(
                 "`override_neuron_config` is only supported on Neuron.")
 
-        supported_tasks, task = self._resolve_task(self.task)
-        self.supported_tasks = supported_tasks
-        self.task = task
-
-        if self.task in ("draft", "generate"):
-            self.truncation_side = "left"
-        else:
-            self.truncation_side = "right"
-
-        self.pooler_config = self._init_pooler_config()
         self._verify_quantization()
         self._verify_cuda_graph()
         self._verify_bnb_config()
@@ -692,7 +701,6 @@
             self.model, self.revision)
 
     def _init_pooler_config(self) -> Optional["PoolerConfig"]:
-
         if self.runner_type == "pooling":
             if isinstance(self.override_pooler_config, dict):
                 self.override_pooler_config = PoolerConfig(
@@ -3074,13 +3082,37 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
     "bfloat16": torch.bfloat16,
 }
 
 _ROCM_NOT_SUPPORTED_DTYPE: list[str] = []  #
 
+# model_type -> reason
+_FLOAT16_NOT_SUPPORTED_MODELS = {
+    "gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
+    "gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
+    "plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
+    "glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
+}
+
 
-def _get_and_verify_dtype(
+def _is_valid_dtype(model_type: str, dtype: torch.dtype):
+    if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:  # noqa: E501, SIM103
+        return False
+
+    return True
+
+
+def _check_valid_dtype(model_type: str, dtype: torch.dtype):
+    if model_type in _FLOAT16_NOT_SUPPORTED_MODELS and dtype == torch.float16:
+        reason = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
+        raise ValueError(f"The model type {model_type!r} "
+                         f"does not support float16. Reason: {reason}")
+
+    return True
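
Note: the two guards split enforcement duties: _is_valid_dtype silently
filters candidates during "auto" resolution, while _check_valid_dtype raises
when the user explicitly requests an unsupported combination. Using only the
definitions above:

    import torch

    # Filtering: gemma2 + float16 is skipped as an "auto" candidate.
    assert _is_valid_dtype("gemma2", torch.float16) is False
    assert _is_valid_dtype("llama", torch.float16) is True

    # Enforcement: an explicit float16 request fails loudly.
    try:
        _check_valid_dtype("gemma2", torch.float16)
    except ValueError as e:
        print(e)  # The model type 'gemma2' does not support float16. Reason: ...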
+
+
+def _find_dtype(
+    model_id: str,
     config: PretrainedConfig,
-    dtype: Union[str, torch.dtype],
-) -> torch.dtype:
+    *,
+    revision: Optional[str],
+):
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
     config_dtype = getattr(config, "torch_dtype", None)
@@ -3092,75 +3124,111 @@ def _get_and_verify_dtype
     if config_dtype is None and hasattr(config, "vision_config"):
         config_dtype = getattr(config.vision_config, "torch_dtype", None)
 
+    # Try to read the dtype of the weights if they are in safetensors format
+    if config_dtype is None:
+        repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
+
+        if repo_mt and (files_mt := repo_mt.files_metadata):
+            param_dtypes: set[torch.dtype] = {
+                _SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
+                for file_mt in files_mt.values()
+                for dtype_str in file_mt.parameter_count
+                if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
+            }
+
+            if param_dtypes:
+                return common_broadcastable_dtype(param_dtypes)
+
     if config_dtype is None:
         config_dtype = torch.float32
 
+    return config_dtype
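
Note: try_get_safetensors_metadata (added to the imports above) wraps the
Hugging Face Hub safetensors-metadata endpoint, and the keys of each file's
parameter_count are dtype strings. A rough standalone sketch of the same
fallback using huggingface_hub directly; reduce(torch.promote_types, ...)
stands in for vLLM's common_broadcastable_dtype and is an assumption about
its behavior, not its actual implementation:

    from functools import reduce

    import torch
    from huggingface_hub import get_safetensors_metadata
    from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

    def guess_weights_dtype(model_id, revision=None):
        # Read per-file safetensors headers from the Hub without
        # downloading any weights.
        repo_mt = get_safetensors_metadata(model_id, revision=revision)
        param_dtypes = {
            _SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
            for file_mt in repo_mt.files_metadata.values()
            for dtype_str in file_mt.parameter_count  # keys: "F32", "BF16", ...
            if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
        }
        if not param_dtypes:
            return None
        # Promote until every observed dtype fits (approximation of
        # common_broadcastable_dtype).
        return reduce(torch.promote_types, param_dtypes)

    print(guess_weights_dtype("openai-community/gpt2"))  # expected: torch.float32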
+
+
+def _resolve_auto_dtype(
+    model_type: str,
+    config_dtype: torch.dtype,
+    *,
+    is_pooling_model: bool,
+):
+    from vllm.platforms import current_platform
+
+    supported_dtypes = [
+        dtype for dtype in current_platform.supported_dtypes
+        if _is_valid_dtype(model_type, dtype)
+    ]
+
+    if is_pooling_model and torch.float16 in supported_dtypes:
+        preferred_dtype = torch.float16
+    else:
+        preferred_dtype = supported_dtypes[0]
+
+    # Downcast for float32 models
+    if config_dtype == torch.float32:
+        config_dtype = preferred_dtype
+
+    if config_dtype in supported_dtypes:
+        return config_dtype
+
+    # Ensure device compatibility
+    device_name = current_platform.get_device_name()
+    device_capability = current_platform.get_device_capability()
+
+    if device_capability is None:
+        device_str = f"{device_name!r}"
+    else:
+        version_str = device_capability.as_version_str()
+        device_str = f"{device_name!r} (with compute capability {version_str})"
+
+    logger.warning(
+        "Your device %s doesn't support %s. "
+        "Falling back to %s for compatibility.",
+        device_str,
+        config_dtype,
+        preferred_dtype,
+    )
+
+    return preferred_dtype
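
Note: tracing _resolve_auto_dtype by hand makes the precedence concrete.
Assume a hypothetical platform whose supported_dtypes is
[torch.bfloat16, torch.float16, torch.float32] and a gemma2 model whose
config says float32:

    import torch

    # Step-by-step restatement of the function above for this scenario.
    supported = [torch.bfloat16, torch.float32]  # float16 filtered out by
                                                 # _is_valid_dtype("gemma2", ...)
    preferred = supported[0]                     # not pooling -> torch.bfloat16
    config_dtype = torch.float32
    if config_dtype == torch.float32:            # downcast float32 models
        config_dtype = preferred
    assert config_dtype in supported             # supported: no fallback warning
    print(config_dtype)                          # torch.bfloat16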
+
+
+def _get_and_verify_dtype(
+    model_id: str,
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+    *,
+    is_pooling_model: bool,
+    revision: Optional[str] = None,
+) -> torch.dtype:
+    config_dtype = _find_dtype(model_id, config, revision=revision)
+    model_type = config.model_type
+
     if isinstance(dtype, str):
         dtype = dtype.lower()
         if dtype == "auto":
-            # Set default dtype from model config
-            if config_dtype == torch.float32:
-                # Following common practice, we use float16 for float32 models
-                torch_dtype = torch.float16
-            else:
-                torch_dtype = config_dtype
-
-            if config.model_type == "plamo2":
-                logger.warning(
-                    "For PLaMo2, we cast models to bfloat16 instead of using "
-                    "float16 by default. This is because float16 does not work."
-                )
-                torch_dtype = torch.bfloat16
-
-            # Deal with torch dtype fallback for device compatibility.
-            from vllm.platforms import current_platform
-            if torch_dtype not in current_platform.supported_dtypes:
-                device_name = current_platform.get_device_name()
-                if ((capability := current_platform.get_device_capability())
-                        is None):
-                    compute_str = ""
-                else:
-                    version_str = capability.as_version_str()
-                    compute_str = f" (with compute capability {version_str})"
-                fallback_dtype = current_platform.supported_dtypes[0]
-                logger.warning(
-                    "Your %s device%s doesn't support %s. " \
-                    "Falling back to %s for compatibility.",
-                    device_name, compute_str, torch_dtype, fallback_dtype
-                )
-                torch_dtype = fallback_dtype
-
-            if current_platform.is_hpu() and torch_dtype == torch.float16:
-                logger.warning(
-                    "For HPU, we cast models to bfloat16 instead of "
-                    "using float16 by default. Please specify `dtype` if you "
-                    "want to use float16.")
-                torch_dtype = torch.bfloat16
-        elif dtype == "float16" and config.model_type == "plamo2":
-            logger.warning(
-                "For PLaMo2, using float16 is unstable and might cause "
-                "unexpected behavior. Please use bfloat16 or float32 instead.")
-            torch_dtype = torch.float16
+            torch_dtype = _resolve_auto_dtype(
+                model_type,
+                config_dtype,
+                is_pooling_model=is_pooling_model,
+            )
         else:
             if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
-                raise ValueError(f"Unknown dtype: {dtype}")
+                raise ValueError(f"Unknown dtype: {dtype!r}")
             torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
     elif isinstance(dtype, torch.dtype):
         torch_dtype = dtype
     else:
         raise ValueError(f"Unknown dtype: {dtype}")
 
     # Verify the dtype.
+    _check_valid_dtype(model_type, torch_dtype)
+
     if torch_dtype != config_dtype:
         if torch_dtype == torch.float32:
             # Upcasting to float32 is allowed.
-            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
+            pass
         elif config_dtype == torch.float32:
             # Downcasting from float32 to float16 or bfloat16 is allowed.
-            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
+            pass
         else:
             # Casting between float16 and bfloat16 is allowed with a warning.
             logger.warning("Casting %s to %s.", config_dtype, torch_dtype)