[Frontend] Dynamic RoPE scaling (#4638)
@@ -2,9 +2,12 @@ from typing import Dict, Optional
 
 from transformers import AutoConfig, PretrainedConfig
 
+from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
                                              JAISConfig, MPTConfig, RWConfig)
 
+logger = init_logger(__name__)
+
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
     "dbrx": DbrxConfig,
@@ -18,7 +21,8 @@ _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
 def get_config(model: str,
                trust_remote_code: bool,
                revision: Optional[str] = None,
-               code_revision: Optional[str] = None) -> PretrainedConfig:
+               code_revision: Optional[str] = None,
+               rope_scaling: Optional[dict] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
             model,
@@ -41,6 +45,10 @@ def get_config(model: str,
         config = config_class.from_pretrained(model,
                                               revision=revision,
                                               code_revision=code_revision)
+    if rope_scaling is not None:
+        logger.info("Updating rope_scaling from %r to %r",
+                    getattr(config, "rope_scaling", None), rope_scaling)
+        config.update({"rope_scaling": rope_scaling})
     return config
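The change threads an optional rope_scaling dict through get_config and, when it is provided, overwrites whatever rope_scaling value the loaded Hugging Face config already carries, logging the old and new values first. Below is a minimal usage sketch, assuming the hunks above belong to vllm.transformers_utils.config; the module path, model name, and scaling values are illustrative, not part of this diff.

    # Hypothetical usage sketch; model name and scaling values are assumptions.
    from vllm.transformers_utils.config import get_config  # assumed module path

    config = get_config(
        "some-org/some-llama-model",                      # illustrative model
        trust_remote_code=False,
        rope_scaling={"type": "dynamic", "factor": 2.0},  # illustrative override
    )
    print(config.rope_scaling)  # -> {'type': 'dynamic', 'factor': 2.0}

Because PretrainedConfig.update assigns attributes directly, the override replaces the whole rope_scaling dict rather than merging keys, so a complete scaling spec has to be passed. Given the [Frontend] tag in the title, the dict presumably arrives from an engine-level argument; only the config-loading side of the change is shown in this diff.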