[Model] Add support for MPT (#334)

This commit is contained in:
Woosuk Kwon
2023-07-03 16:47:53 -07:00
committed by GitHub
parent 7717d0838b
commit 404422f42e
11 changed files with 388 additions and 4 deletions

View File

@@ -0,0 +1,15 @@
from transformers import AutoConfig, PretrainedConfig
from vllm.transformers_utils.configs import * # pylint: disable=wildcard-import
_CONFIG_REGISTRY = {
"mpt": MPTConfig,
}
def get_config(model: str) -> PretrainedConfig:
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)
return config