[Core] Rework dtype resolution (#18751)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
202
vllm/config.py
202
vllm/config.py
@@ -24,6 +24,7 @@ import torch
|
||||
from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
|
||||
model_validator)
|
||||
from pydantic.dataclasses import dataclass
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from torch.distributed import ProcessGroup, ReduceOp
|
||||
from transformers import PretrainedConfig
|
||||
from typing_extensions import deprecated, runtime_checkable
|
||||
@@ -42,15 +43,16 @@ from vllm.transformers_utils.config import (
|
||||
ConfigFormat, get_config, get_hf_image_processor_config,
|
||||
get_hf_text_config, get_pooling_config,
|
||||
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
|
||||
try_get_generation_config, uses_mrope)
|
||||
try_get_generation_config, try_get_safetensors_metadata, uses_mrope)
|
||||
from vllm.transformers_utils.s3_utils import S3Model
|
||||
from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
|
||||
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
|
||||
MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
|
||||
POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
|
||||
LayerBlockType, cuda_device_count_stateless,
|
||||
get_cpu_memory, get_open_port, is_torch_equal_or_newer,
|
||||
random_uuid, resolve_obj_by_qualname)
|
||||
LayerBlockType, common_broadcastable_dtype,
|
||||
cuda_device_count_stateless, get_cpu_memory,
|
||||
get_open_port, is_torch_equal_or_newer, random_uuid,
|
||||
resolve_obj_by_qualname)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
@@ -540,7 +542,24 @@ class ModelConfig:
|
||||
self.encoder_config = self._get_encoder_config()
|
||||
self.hf_image_processor_config = get_hf_image_processor_config(
|
||||
self.model, hf_token=self.hf_token, revision=self.revision)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
|
||||
|
||||
supported_tasks, task = self._resolve_task(self.task)
|
||||
self.supported_tasks = supported_tasks
|
||||
self.task = task
|
||||
if self.task in ("draft", "generate"):
|
||||
self.truncation_side = "left"
|
||||
else:
|
||||
self.truncation_side = "right"
|
||||
|
||||
self.pooler_config = self._init_pooler_config()
|
||||
|
||||
self.dtype = _get_and_verify_dtype(
|
||||
self.model,
|
||||
self.hf_config,
|
||||
self.dtype,
|
||||
is_pooling_model=self.runner_type == "pooling",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Workaround for Gemma 2 which uses interleaved sliding window
|
||||
# attention, but it's not specified in its config. TODO: remove this
|
||||
@@ -597,16 +616,6 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
"`override_neuron_config` is only supported on Neuron.")
|
||||
|
||||
supported_tasks, task = self._resolve_task(self.task)
|
||||
self.supported_tasks = supported_tasks
|
||||
self.task = task
|
||||
if self.task in ("draft", "generate"):
|
||||
self.truncation_side = "left"
|
||||
else:
|
||||
self.truncation_side = "right"
|
||||
|
||||
self.pooler_config = self._init_pooler_config()
|
||||
|
||||
self._verify_quantization()
|
||||
self._verify_cuda_graph()
|
||||
self._verify_bnb_config()
|
||||
@@ -692,7 +701,6 @@ class ModelConfig:
|
||||
self.model, self.revision)
|
||||
|
||||
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
|
||||
|
||||
if self.runner_type == "pooling":
|
||||
if isinstance(self.override_pooler_config, dict):
|
||||
self.override_pooler_config = PoolerConfig(
|
||||
@@ -3074,13 +3082,37 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
_ROCM_NOT_SUPPORTED_DTYPE: list[str] = [] #
|
||||
# model_type -> reason
|
||||
_FLOAT16_NOT_SUPPORTED_MODELS = {
|
||||
"gemma2": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
"gemma3": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
"plamo2": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
"glm4": "Numerical instability. Please use bfloat16 or float32 instead.",
|
||||
}
|
||||
|
||||
|
||||
def _get_and_verify_dtype(
|
||||
def _is_valid_dtype(model_type: str, dtype: torch.dtype):
    """Report whether ``dtype`` may be used with models of ``model_type``.

    The only combination rejected is ``torch.float16`` together with a
    model type listed in ``_FLOAT16_NOT_SUPPORTED_MODELS``; everything else
    is considered valid.
    """
    float16_blocked = model_type in _FLOAT16_NOT_SUPPORTED_MODELS
    return not (float16_blocked and dtype == torch.float16)
|
||||
|
||||
|
||||
def _check_valid_dtype(model_type: str, dtype: torch.dtype):
    """Validate that ``dtype`` is usable with models of ``model_type``.

    Returns:
        ``True`` when the combination is accepted.

    Raises:
        ValueError: If ``dtype`` is ``torch.float16`` and ``model_type`` is
            listed in ``_FLOAT16_NOT_SUPPORTED_MODELS``. The message carries
            the reason recorded for that model type.
    """
    # Anything other than float16 on a listed model type is accepted.
    if (model_type not in _FLOAT16_NOT_SUPPORTED_MODELS
            or dtype != torch.float16):
        return True

    why = _FLOAT16_NOT_SUPPORTED_MODELS[model_type]
    raise ValueError(f"The model type {model_type!r} "
                     f"does not support float16. Reason: {why}")
|
||||
|
||||
|
||||
def _find_dtype(
|
||||
model_id: str,
|
||||
config: PretrainedConfig,
|
||||
dtype: Union[str, torch.dtype],
|
||||
) -> torch.dtype:
|
||||
*,
|
||||
revision: Optional[str],
|
||||
):
|
||||
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
|
||||
# because config.torch_dtype can be None.
|
||||
config_dtype = getattr(config, "torch_dtype", None)
|
||||
@@ -3092,75 +3124,111 @@ def _get_and_verify_dtype(
|
||||
if config_dtype is None and hasattr(config, "vision_config"):
|
||||
config_dtype = getattr(config.vision_config, "torch_dtype", None)
|
||||
|
||||
# Try to read the dtype of the weights if they are in safetensors format
|
||||
if config_dtype is None:
|
||||
repo_mt = try_get_safetensors_metadata(model_id, revision=revision)
|
||||
|
||||
if repo_mt and (files_mt := repo_mt.files_metadata):
|
||||
param_dtypes: set[torch.dtype] = {
|
||||
_SAFETENSORS_TO_TORCH_DTYPE[dtype_str]
|
||||
for file_mt in files_mt.values()
|
||||
for dtype_str in file_mt.parameter_count
|
||||
if dtype_str in _SAFETENSORS_TO_TORCH_DTYPE
|
||||
}
|
||||
|
||||
if param_dtypes:
|
||||
return common_broadcastable_dtype(param_dtypes)
|
||||
|
||||
if config_dtype is None:
|
||||
config_dtype = torch.float32
|
||||
|
||||
return config_dtype
|
||||
|
||||
|
||||
def _resolve_auto_dtype(
    model_type: str,
    config_dtype: torch.dtype,
    *,
    is_pooling_model: bool,
):
    """Pick the dtype to use when the user requested ``dtype="auto"``.

    Args:
        model_type: Model type string, used to filter out dtypes that
            ``_is_valid_dtype`` rejects for this model.
        config_dtype: The dtype found for the checkpoint.
        is_pooling_model: When true, ``torch.float16`` is preferred if it
            is among the usable dtypes.

    Returns:
        ``config_dtype`` if the platform supports it for this model
        (float32 is first replaced by the preferred dtype); otherwise the
        preferred dtype, after logging a warning about the fallback.

    Raises:
        ValueError: If none of the platform's supported dtypes is valid for
            ``model_type``.
    """
    # Imported inside the function, as in the original code
    # (presumably to avoid an import cycle -- confirm).
    from vllm.platforms import current_platform

    supported_dtypes = [
        dtype for dtype in current_platform.supported_dtypes
        if _is_valid_dtype(model_type, dtype)
    ]

    # Without this guard, ``supported_dtypes[0]`` below would fail with a
    # bare IndexError that says nothing about the model or platform.
    if not supported_dtypes:
        raise ValueError(
            f"None of the dtypes supported by the current platform "
            f"({list(current_platform.supported_dtypes)}) can be used with "
            f"model type {model_type!r}.")

    if is_pooling_model and torch.float16 in supported_dtypes:
        preferred_dtype = torch.float16
    else:
        # Otherwise take the first usable dtype in the platform's own order.
        preferred_dtype = supported_dtypes[0]

    # Downcast for float32 models
    if config_dtype == torch.float32:
        config_dtype = preferred_dtype

    if config_dtype in supported_dtypes:
        return config_dtype

    # Ensure device compatibility
    device_name = current_platform.get_device_name()
    device_capability = current_platform.get_device_capability()

    if device_capability is None:
        device_str = f"{device_name!r}"
    else:
        version_str = device_capability.as_version_str()
        device_str = f"{device_name!r} (with compute capability {version_str})"

    logger.warning(
        "Your device %s doesn't support %s. "
        "Falling back to %s for compatibility.",
        device_str,
        config_dtype,
        preferred_dtype,
    )

    return preferred_dtype
|
||||
|
||||
|
||||
def _get_and_verify_dtype(
|
||||
model_id: str,
|
||||
config: PretrainedConfig,
|
||||
dtype: Union[str, torch.dtype],
|
||||
*,
|
||||
is_pooling_model: bool,
|
||||
revision: Optional[str] = None,
|
||||
) -> torch.dtype:
|
||||
config_dtype = _find_dtype(model_id, config, revision=revision)
|
||||
model_type = config.model_type
|
||||
|
||||
if isinstance(dtype, str):
|
||||
dtype = dtype.lower()
|
||||
if dtype == "auto":
|
||||
# Set default dtype from model config
|
||||
if config_dtype == torch.float32:
|
||||
# Following common practice, we use float16 for float32 models
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = config_dtype
|
||||
|
||||
if config.model_type == "plamo2":
|
||||
logger.warning(
|
||||
"For PLaMo2, we cast models to bfloat16 instead of using "
|
||||
"float16 by default. This is because float16 does not work."
|
||||
)
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
# Deal with torch dtype fallback for device compatibility.
|
||||
from vllm.platforms import current_platform
|
||||
if torch_dtype not in current_platform.supported_dtypes:
|
||||
device_name = current_platform.get_device_name()
|
||||
|
||||
if ((capability := current_platform.get_device_capability())
|
||||
is None):
|
||||
compute_str = ""
|
||||
else:
|
||||
version_str = capability.as_version_str()
|
||||
compute_str = f" (with compute capability {version_str})"
|
||||
fallback_dtype = current_platform.supported_dtypes[0]
|
||||
logger.warning(
|
||||
"Your %s device%s doesn't support %s. " \
|
||||
"Falling back to %s for compatibility.",
|
||||
device_name, compute_str, torch_dtype, fallback_dtype
|
||||
)
|
||||
torch_dtype = fallback_dtype
|
||||
|
||||
if current_platform.is_hpu() and torch_dtype == torch.float16:
|
||||
logger.warning(
|
||||
"For HPU, we cast models to bfloat16 instead of "
|
||||
"using float16 by default. Please specify `dtype` if you "
|
||||
"want to use float16.")
|
||||
torch_dtype = torch.bfloat16
|
||||
elif dtype == "float16" and config.model_type == "plamo2":
|
||||
logger.warning(
|
||||
"For PLaMo2, using float16 is unstable and might cause "
|
||||
"unexpected behavior. Please use bfloat16 or float32 instead.")
|
||||
torch_dtype = torch.float16
|
||||
torch_dtype = _resolve_auto_dtype(
|
||||
model_type,
|
||||
config_dtype,
|
||||
is_pooling_model=is_pooling_model,
|
||||
)
|
||||
else:
|
||||
if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
raise ValueError(f"Unknown dtype: {dtype!r}")
|
||||
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
|
||||
elif isinstance(dtype, torch.dtype):
|
||||
torch_dtype = dtype
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
# Verify the dtype.
|
||||
_check_valid_dtype(model_type, torch_dtype)
|
||||
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
# Upcasting to float32 is allowed.
|
||||
logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
elif config_dtype == torch.float32:
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
|
||||
|
||||
Reference in New Issue
Block a user