Add llama 4 scaling support (#28145)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Julien Denize
2025-11-06 19:55:17 +01:00
committed by GitHub
parent 5e0c1fe69c
commit 7a8375f8a0
4 changed files with 59 additions and 8 deletions


@@ -24,6 +24,18 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
     if bool(config_dict.get("yarn")):
         config_dict = _remap_mistral_yarn_args(config_dict)
+    if bool(config_dict.get("llama_4_scaling")):
+        llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
+        assert all(
+            [
+                key in config_dict["llama_4_scaling"]
+                for key in llama_4_scaling_config_keys
+            ]
+        ), (
+            "llama_4_scaling config should define the keys: "
+            f"{','.join(llama_4_scaling_config_keys)}"
+        )
     is_vision = (config_dict.get("multimodal") or {}).get(
         "vision_encoder_args"
     ) or config_dict.get("vision_encoder")
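
For context, a minimal sketch of what the new check expects from a mistral-format config dict. Only the two required key names come from the diff; the values and the surrounding structure are assumed for illustration.

# Hypothetical excerpt of a loaded mistral params.json; the values are made up,
# only the required key names match the diff above.
config_dict = {
    "llama_4_scaling": {
        "original_max_position_embeddings": 8192,
        "beta": 0.5,
    },
}

llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
# Mirrors the new assertion: both keys must be present, otherwise
# adapt_config_dict raises an AssertionError listing the required keys.
assert all(
    key in config_dict["llama_4_scaling"]
    for key in llama_4_scaling_config_keys
)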
@@ -66,19 +78,24 @@ def _remap_mistral_vision_args(config: dict) -> dict:
 def _remap_mistral_yarn_args(config: dict) -> dict:
     # Direct remaps: yarn.X -> rope_scaling.Y
     # Source keys are from mistral.model.args.YarnArgs
-    _map = {
+    yarn_config_map = {
         "factor": "factor",
         "original_max_position_embeddings": "original_max_position_embeddings",
         "beta": "beta_fast",
         "alpha": "beta_slow",
         "apply_scale": "apply_yarn_scaling",
     }
     yarn_config = config.get("yarn") or {}
-    renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
     config["rope_scaling"] = {
         "rope_type": "yarn",
-        "mscale_all_dim": 1,  # We hardcoded this to 1
-        **renamed_yarn_config,
+        "mscale_all_dim": 1,
     }
+    for old_name, new_name in yarn_config_map.items():
+        if old_name in yarn_config:
+            config["rope_scaling"][new_name] = yarn_config.pop(old_name)
+    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
     return config
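
To illustrate the reworked remap, a small usage sketch; the yarn values are invented, while the key renames and the hardcoded rope_scaling fields follow the code above. Note the behavioral change: unknown yarn keys are no longer forwarded silently under their original names, they now trip the new assert.

# Invented mistral-style yarn section; only the key names match YarnArgs.
config = {
    "yarn": {
        "factor": 16.0,
        "original_max_position_embeddings": 4096,
        "beta": 32.0,
        "alpha": 1.0,
        "apply_scale": True,
    }
}
config = _remap_mistral_yarn_args(config)
# Each yarn key is popped and written under its rope_scaling name, so
# config["rope_scaling"] now holds:
# {
#     "rope_type": "yarn",
#     "mscale_all_dim": 1,
#     "factor": 16.0,
#     "original_max_position_embeddings": 4096,
#     "beta_fast": 32.0,
#     "beta_slow": 1.0,
#     "apply_yarn_scaling": True,
# }
# Any key left unpopped in config["yarn"] would raise the
# "Unparsed yarn config" AssertionError.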