Add llama 4 scaling support (#28145)
Signed-off-by: Julien Denize <julien.denize@mistral.ai>
@@ -191,9 +191,16 @@ def get_rope(
                 k: v
                 for k, v in rope_scaling.items()
                 if k
-                in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow")
+                in (
+                    "extrapolation_factor",
+                    "attn_factor",
+                    "beta_fast",
+                    "beta_slow",
+                    "apply_yarn_scaling",
+                )
             }
             if "mrope_section" in rope_scaling:
+                extra_kwargs.pop("apply_yarn_scaling", None)
                 rotary_emb = MRotaryEmbedding(
                     head_size,
                     rotary_dim,
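For context, a minimal sketch of how the patched keyword filter behaves, assuming a caller passes a YaRN-style rope_scaling dict that also carries the new flag (the concrete values below are illustrative, not taken from this commit). The new "apply_yarn_scaling" key now survives the whitelist and is popped again in the mrope branch, presumably because MRotaryEmbedding does not accept it:

    rope_scaling = {
        "rope_type": "yarn",
        "factor": 8.0,
        "attn_factor": 1.0,
        "beta_fast": 32,
        "beta_slow": 1,
        "apply_yarn_scaling": False,
        "mrope_section": [16, 24, 24],
    }

    # Same whitelist comprehension as in the patched get_rope.
    extra_kwargs = {
        k: v
        for k, v in rope_scaling.items()
        if k
        in (
            "extrapolation_factor",
            "attn_factor",
            "beta_fast",
            "beta_slow",
            "apply_yarn_scaling",
        )
    }

    # Mirrors the mrope branch: the flag is dropped before the
    # MRotaryEmbedding constructor would be called (constructor omitted here).
    if "mrope_section" in rope_scaling:
        extra_kwargs.pop("apply_yarn_scaling", None)

    print(extra_kwargs)
    # -> {'attn_factor': 1.0, 'beta_fast': 32, 'beta_slow': 1}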
@@ -27,6 +27,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         attn_factor: float = 1,
         beta_fast: int = 32,
         beta_slow: int = 1,
+        apply_yarn_scaling: bool = True,
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -34,7 +35,11 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         self.beta_fast = beta_fast
         self.beta_slow = beta_slow
         # Get n-d magnitude scaling corrected for interpolation
-        self.mscale = float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+        self.mscale = (
+            float(yarn_get_mscale(self.scaling_factor) * attn_factor)
+            if apply_yarn_scaling
+            else float(attn_factor)
+        )
         super().__init__(
             head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
         )
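A minimal sketch of what the new apply_yarn_scaling flag changes in the magnitude correction. yarn_get_mscale is written out below with the standard YaRN mscale formula, which is assumed to match the helper used by the patched module; the scaling_factor and attn_factor values are illustrative. With the flag off (presumably the configuration needed for Llama 4 scaling, per the commit title), the YaRN correction is skipped and mscale reduces to attn_factor:

    import math


    def yarn_get_mscale(scale: float = 1.0) -> float:
        # Standard YaRN magnitude-scaling formula (assumed to match the
        # helper imported by the patched module).
        if scale <= 1:
            return 1.0
        return 0.1 * math.log(scale) + 1.0


    scaling_factor = 8.0  # illustrative
    attn_factor = 1.0     # illustrative

    # apply_yarn_scaling=True (default): previous behaviour, unchanged.
    mscale_with_yarn = float(yarn_get_mscale(scaling_factor) * attn_factor)

    # apply_yarn_scaling=False: the correction is bypassed entirely.
    mscale_without_yarn = float(attn_factor)

    print(round(mscale_with_yarn, 4), mscale_without_yarn)
    # -> 1.2079 1.0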