Support FP32 (#141)

This commit is contained in:
Woosuk Kwon
2023-06-07 00:40:21 -07:00
committed by GitHub
parent 376725ce74
commit e38074b1e6
8 changed files with 65 additions and 54 deletions

View File

@@ -164,7 +164,7 @@ def _get_and_verify_dtype(
config_dtype = torch.float32
dtype = dtype.lower()
- if dtype == "default":
+ if dtype == "auto":
if config_dtype == torch.float32:
# Following the common practice, we use float16 for float32 models.
torch_dtype = torch.float16
@@ -184,9 +184,8 @@ def _get_and_verify_dtype(
# Downcasting from float32 to float16 or bfloat16 is allowed.
pass
else:
- # Casting between float16 and bfloat16 is not allowed.
- raise ValueError(
- f"Cannot use {torch_dtype} for {config_dtype} model.")
+ # Casting between float16 and bfloat16 is allowed with a warning.
+ logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
# Check if the GPU supports the dtype.
if torch_dtype == torch.bfloat16: