Simplify weight loading logic (#2133)

This commit is contained in:
Roy
2023-12-17 04:41:23 +08:00
committed by GitHub
parent 2acd76f346
commit eed74a558f
3 changed files with 33 additions and 37 deletions

View File

@@ -122,15 +122,10 @@ class ModelConfig:
# TODO: Remove this check once HF updates the pt weights of Mixtral.
architectures = getattr(self.hf_config, "architectures", [])
if "MixtralForCausalLM" in architectures:
if load_format == "pt":
raise ValueError(
"Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. ")
elif load_format == "auto":
# Do not fall back to pt weights.
load_format = "safetensors"
if "MixtralForCausalLM" in architectures and load_format == "pt":
raise ValueError(
"Currently, the 'pt' format is not supported for Mixtral. "
"Please use the 'safetensors' format instead. ")
self.load_format = load_format
def _verify_tokenizer_mode(self) -> None: