Support FP32 (#141)
This commit is contained in:
@@ -164,7 +164,7 @@ def _get_and_verify_dtype(
|
||||
config_dtype = torch.float32
|
||||
|
||||
dtype = dtype.lower()
|
||||
if dtype == "default":
|
||||
if dtype == "auto":
|
||||
if config_dtype == torch.float32:
|
||||
# Following the common practice, we use float16 for float32 models.
|
||||
torch_dtype = torch.float16
|
||||
@@ -184,9 +184,8 @@ def _get_and_verify_dtype(
|
||||
# Downcasting from float32 to float16 or bfloat16 is allowed.
|
||||
pass
|
||||
else:
|
||||
# Casting between float16 and bfloat16 is not allowed.
|
||||
raise ValueError(
|
||||
f"Cannot use {torch_dtype} for {config_dtype} model.")
|
||||
# Casting between float16 and bfloat16 is allowed with a warning.
|
||||
logger.warn(f"Casting {config_dtype} to {torch_dtype}.")
|
||||
|
||||
# Check if the GPU supports the dtype.
|
||||
if torch_dtype == torch.bfloat16:
|
||||
|
||||
Reference in New Issue
Block a user