Add trust_remote_code arg to get_config (#405)

This commit is contained in:
Woosuk Kwon
2023-07-08 15:24:17 -07:00
committed by GitHub
parent b6fbb9a565
commit ddfdf470ae
3 changed files with 20 additions and 6 deletions

View File

@@ -7,8 +7,21 @@ _CONFIG_REGISTRY = {
}
def get_config(model: str) -> PretrainedConfig:
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
err_msg = (
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)