Add Model Revision Support (#1014)
Co-authored-by: Jasmond Loh <Jasmond.Loh@hotmail.com> Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
|
||||
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
|
||||
@@ -12,10 +14,12 @@ _CONFIG_REGISTRY = {
|
||||
}
|
||||
|
||||
|
||||
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
|
||||
def get_config(model: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None) -> PretrainedConfig:
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
model, trust_remote_code=trust_remote_code)
|
||||
model, trust_remote_code=trust_remote_code, revision=revision)
|
||||
except ValueError as e:
|
||||
if (not trust_remote_code and
|
||||
"requires you to execute the configuration file" in str(e)):
|
||||
@@ -29,5 +33,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
|
||||
raise e
|
||||
if config.model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[config.model_type]
|
||||
config = config_class.from_pretrained(model)
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user