Support starcoder2 architecture (#3089)
This commit is contained in:
@@ -9,6 +9,7 @@ _CONFIG_REGISTRY = {
|
||||
"mpt": MPTConfig,
|
||||
"RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct)
|
||||
"RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct)
|
||||
"starcoder2": Starcoder2Config,
|
||||
}
|
||||
|
||||
|
||||
@@ -16,6 +17,15 @@ def get_config(model: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
code_revision: Optional[str] = None) -> PretrainedConfig:
|
||||
# FIXME(woosuk): This is a temporary fix for StarCoder2.
|
||||
# Remove this when the model is supported by HuggingFace transformers.
|
||||
if "bigcode" in model and "starcoder2" in model:
|
||||
config_class = _CONFIG_REGISTRY["starcoder2"]
|
||||
config = config_class.from_pretrained(model,
|
||||
revision=revision,
|
||||
code_revision=code_revision)
|
||||
return config
|
||||
|
||||
try:
|
||||
config = AutoConfig.from_pretrained(
|
||||
model,
|
||||
|
||||
Reference in New Issue
Block a user