Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-11-19 18:06:36 +01:00
committed by GitHub
parent d44e9df7d4
commit a8b70304d6
104 changed files with 542 additions and 910 deletions

View File

@@ -26,23 +26,23 @@ def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
is_neox_style: bool = True,
rope_scaling: dict[str, Any] | None = None,
rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if rope_scaling is not None:
if rope_parameters is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
rope_parameters_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_parameters.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
rope_parameters_args = tuple(rope_parameters_tuple.items())
else:
rope_scaling_args = None
rope_parameters_args = None
if dual_chunk_attention_config is not None:
dual_chunk_attention_tuple = {
@@ -60,15 +60,15 @@ def get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling_args,
rope_parameters_args,
dual_chunk_attention_args,
dtype,
)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
base = rope_parameters["rope_theta"] if rope_parameters else 10000
if dual_chunk_attention_config is not None:
extra_kwargs = {
k: v
@@ -84,18 +84,18 @@ def get_rope(
dtype,
**extra_kwargs,
)
elif not rope_scaling:
elif not rope_parameters:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
else:
scaling_type = rope_scaling["rope_type"]
scaling_type = rope_parameters["rope_type"]
if scaling_type == "llama3":
scaling_factor = rope_scaling["factor"]
low_freq_factor = rope_scaling["low_freq_factor"]
high_freq_factor = rope_scaling["high_freq_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
scaling_factor = rope_parameters["factor"]
low_freq_factor = rope_parameters["low_freq_factor"]
high_freq_factor = rope_parameters["high_freq_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
rotary_emb = Llama3RotaryEmbedding(
head_size,
rotary_dim,
@@ -113,7 +113,7 @@ def get_rope(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
)
elif scaling_type == "default":
if "mrope_section" in rope_scaling:
if "mrope_section" in rope_parameters:
rotary_emb = MRotaryEmbedding(
head_size,
rotary_dim,
@@ -121,8 +121,8 @@ def get_rope(
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
)
else:
rotary_emb = RotaryEmbedding(
@@ -134,7 +134,7 @@ def get_rope(
dtype,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
scaling_factor = rope_parameters["factor"]
rotary_emb = LinearScalingRotaryEmbedding(
head_size,
rotary_dim,
@@ -145,8 +145,8 @@ def get_rope(
dtype,
)
elif scaling_type == "ntk":
scaling_factor = rope_scaling["factor"]
mixed_b = rope_scaling.get("mixed_b", None)
scaling_factor = rope_parameters["factor"]
mixed_b = rope_parameters.get("mixed_b")
rotary_emb = NTKScalingRotaryEmbedding(
head_size,
rotary_dim,
@@ -158,8 +158,8 @@ def get_rope(
mixed_b,
)
elif scaling_type == "dynamic":
if "alpha" in rope_scaling:
scaling_alpha = rope_scaling["alpha"]
if "alpha" in rope_parameters:
scaling_alpha = rope_parameters["alpha"]
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
head_size,
rotary_dim,
@@ -169,8 +169,8 @@ def get_rope(
scaling_alpha,
dtype,
)
elif "factor" in rope_scaling:
scaling_factor = rope_scaling["factor"]
elif "factor" in rope_parameters:
scaling_factor = rope_parameters["factor"]
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
@@ -185,11 +185,11 @@ def get_rope(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
@@ -199,7 +199,7 @@ def get_rope(
"apply_yarn_scaling",
)
}
if "mrope_section" in rope_scaling:
if "mrope_section" in rope_parameters:
extra_kwargs.pop("apply_yarn_scaling", None)
rotary_emb = MRotaryEmbedding(
head_size,
@@ -208,8 +208,8 @@ def get_rope(
base,
is_neox_style,
dtype,
mrope_section=rope_scaling["mrope_section"],
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
mrope_section=rope_parameters["mrope_section"],
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
scaling_factor=scaling_factor,
**extra_kwargs,
)
@@ -225,12 +225,12 @@ def get_rope(
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
scaling_factor = rope_scaling["factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
scaling_factor = rope_parameters["factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
for k, v in rope_parameters.items()
if k
in (
"extrapolation_factor",
@@ -252,12 +252,12 @@ def get_rope(
**extra_kwargs,
)
elif scaling_type == "longrope":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling["original_max_position_embeddings"]
short_factor = rope_parameters["short_factor"]
long_factor = rope_parameters["long_factor"]
original_max_position = rope_parameters["original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
for k, v in rope_parameters.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(