[Frontend] Dynamic RoPE scaling (#4638)

sasha0552
2024-05-22 05:32:35 +00:00
committed by GitHub
parent 99eff67ba9
commit 9b9a10d6cb
5 changed files with 89 additions and 12 deletions
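
The diff excerpt below covers the config-loading half of the change: get_config() gains an optional rope_scaling argument and, when it is given, logs and overrides the rope_scaling entry of the loaded Hugging Face config. As a rough sketch of the intended end-user view, assuming the remaining changed files in this commit expose the same option as a rope_scaling engine argument (an assumption based on the commit title, not shown in the excerpt):

# Sketch only: the rope_scaling keyword on LLM() is assumed to be wired
# through by the rest of this commit; model id and factor are illustrative.
from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    rope_scaling={"type": "dynamic", "factor": 2.0},  # dynamic NTK scaling
)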


@@ -2,9 +2,12 @@ from typing import Dict, Optional
 
 from transformers import AutoConfig, PretrainedConfig
 
+from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
 
+logger = init_logger(__name__)
+
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+               code_revision: Optional[str] = None,
+               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
@@ -41,6 +45,10 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
+    if rope_scaling is not None:
+        logger.info("Updating rope_scaling from %r to %r",
+                    getattr(config, "rope_scaling", None), rope_scaling)
+        config.update({"rope_scaling": rope_scaling})
     return config
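
For reference, a minimal usage sketch of the new parameter (the model id and scaling values are illustrative; get_config is imported from vllm.transformers_utils.config):

# Sketch: override rope_scaling while loading a model's Hugging Face config.
from vllm.transformers_utils.config import get_config

config = get_config(
    model="meta-llama/Llama-2-7b-hf",   # illustrative model id
    trust_remote_code=False,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # dynamic NTK scaling
)
print(config.rope_scaling)  # -> {'type': 'dynamic', 'factor': 2.0}

Updating the dict on the already-loaded PretrainedConfig via config.update() means dynamic or linear scaling can be switched on per deployment without editing the model's config.json.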