Add llama 4 scaling support (#28145)

Signed-off-by: Julien Denize <julien.denize@mistral.ai>
Julien Denize
2025-11-06 19:55:17 +01:00
committed by GitHub
parent 5e0c1fe69c
commit 7a8375f8a0
4 changed files with 59 additions and 8 deletions
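The two hunks below add an apply_yarn_scaling flag to YaRN rope scaling: get_rope now forwards the key from the model's rope_scaling config (dropping it for mrope models), and YaRNScalingRotaryEmbedding uses it to skip the YaRN magnitude correction (mscale) when it is set to False. As a rough illustration, a config exercising the new path might look like the sketch below; the key names follow the usual YaRN rope_scaling layout and the values are hypothetical, not copied from an actual Llama 4 checkpoint.

# Hypothetical YaRN rope_scaling config exercising the new flag; values are
# illustrative only.
rope_scaling = {
    "rope_type": "yarn",
    "factor": 8.0,  # context-length scaling factor
    "original_max_position_embeddings": 8192,
    "attn_factor": 1.0,
    "apply_yarn_scaling": False,  # new: keep YaRN frequencies, skip the mscale correction
}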


@@ -191,9 +191,16 @@ def get_rope(
             k: v
             for k, v in rope_scaling.items()
             if k
-            in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
+            in (
+                "extrapolation_factor",
+                "attn_factor",
+                "beta_fast",
+                "beta_slow",
+                "apply_yarn_scaling",
+            )
         }
         if "mrope_section" in rope_scaling:
+            extra_kwargs.pop("apply_yarn_scaling", None)
             rotary_emb = MRotaryEmbedding(
                 head_size,
                 rotary_dim,

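For the hypothetical rope_scaling shown earlier, the filter in this hunk keeps only the recognized extra keys, so the new flag reaches YaRNScalingRotaryEmbedding as a keyword argument; for mrope models the flag is popped before MRotaryEmbedding is constructed. A minimal sketch of the result:

# Result of the extra_kwargs filter above for the hypothetical config shown
# earlier: only the recognized extra keys survive.
extra_kwargs = {
    k: v
    for k, v in rope_scaling.items()
    if k
    in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow", "apply_yarn_scaling")
}
assert extra_kwargs == {"attn_factor": 1.0, "apply_yarn_scaling": False}

# mrope path: the flag is removed before MRotaryEmbedding(...) is built.
if "mrope_section" in rope_scaling:
    extra_kwargs.pop("apply_yarn_scaling", None)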

@@ -27,6 +27,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         attn_factor: float = 1,
         beta_fast: int = 32,
         beta_slow: int = 1,
+        apply_yarn_scaling: bool = True,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -34,7 +35,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
         # Get n-d magnitude scaling corrected for interpolation
-        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+        self.mscale = (
+            float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+            if apply_yarn_scaling
+            else float(attn_factor)
+        )
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
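The net effect of the flag is confined to mscale. A worked example, assuming yarn_get_mscale follows the usual YaRN form 0.1 * ln(s) + 1.0 for s > 1 (restated below for the example; assumed to match the helper referenced in the hunk above):

import math


def yarn_get_mscale(scale: float = 1.0) -> float:
    # Usual YaRN magnitude correction; assumed equivalent to the helper used
    # by YaRNScalingRotaryEmbedding.
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0


scaling_factor, attn_factor = 8.0, 1.0

# apply_yarn_scaling=True (default, previous behaviour): log-scale correction applied.
mscale_corrected = float(yarn_get_mscale(scaling_factor) * attn_factor)  # ~1.208

# apply_yarn_scaling=False (the path added by this commit): attn_factor is used as-is.
mscale_plain = float(attn_factor)  # 1.0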