Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -26,23 +26,23 @@ def get_rope(
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: float,
|
||||
is_neox_style: bool = True,
|
||||
rope_scaling: dict[str, Any] | None = None,
|
||||
rope_parameters: dict[str, Any] | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
dual_chunk_attention_config: dict[str, Any] | None = None,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if rope_scaling is not None:
|
||||
if rope_parameters is not None:
|
||||
# Transforms every value that is a list into a tuple for caching calls
|
||||
rope_scaling_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items()
|
||||
rope_parameters_tuple = {
|
||||
k: tuple(v) if isinstance(v, list) else v
|
||||
for k, v in rope_parameters.items()
|
||||
}
|
||||
rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
rope_parameters_args = tuple(rope_parameters_tuple.items())
|
||||
else:
|
||||
rope_scaling_args = None
|
||||
rope_parameters_args = None
|
||||
|
||||
if dual_chunk_attention_config is not None:
|
||||
dual_chunk_attention_tuple = {
|
||||
@@ -60,15 +60,15 @@ def get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
rope_scaling_args,
|
||||
rope_parameters_args,
|
||||
dual_chunk_attention_args,
|
||||
dtype,
|
||||
)
|
||||
if key in _ROPE_DICT:
|
||||
return _ROPE_DICT[key]
|
||||
|
||||
base = rope_parameters["rope_theta"] if rope_parameters else 10000
|
||||
if dual_chunk_attention_config is not None:
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
@@ -84,18 +84,18 @@ def get_rope(
|
||||
dtype,
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif not rope_scaling:
|
||||
elif not rope_parameters:
|
||||
rotary_emb = RotaryEmbedding(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||
)
|
||||
else:
|
||||
scaling_type = rope_scaling["rope_type"]
|
||||
scaling_type = rope_parameters["rope_type"]
|
||||
|
||||
if scaling_type == "llama3":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
low_freq_factor = rope_scaling["low_freq_factor"]
|
||||
high_freq_factor = rope_scaling["high_freq_factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
low_freq_factor = rope_parameters["low_freq_factor"]
|
||||
high_freq_factor = rope_parameters["high_freq_factor"]
|
||||
original_max_position = rope_parameters["original_max_position_embeddings"]
|
||||
rotary_emb = Llama3RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
@@ -113,7 +113,7 @@ def get_rope(
|
||||
head_size, rotary_dim, max_position, base, is_neox_style, dtype
|
||||
)
|
||||
elif scaling_type == "default":
|
||||
if "mrope_section" in rope_scaling:
|
||||
if "mrope_section" in rope_parameters:
|
||||
rotary_emb = MRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
@@ -121,8 +121,8 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
|
||||
mrope_section=rope_parameters["mrope_section"],
|
||||
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
|
||||
)
|
||||
else:
|
||||
rotary_emb = RotaryEmbedding(
|
||||
@@ -134,7 +134,7 @@ def get_rope(
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "linear":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
rotary_emb = LinearScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
@@ -145,8 +145,8 @@ def get_rope(
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "ntk":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mixed_b = rope_scaling.get("mixed_b", None)
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
mixed_b = rope_parameters.get("mixed_b")
|
||||
rotary_emb = NTKScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
@@ -158,8 +158,8 @@ def get_rope(
|
||||
mixed_b,
|
||||
)
|
||||
elif scaling_type == "dynamic":
|
||||
if "alpha" in rope_scaling:
|
||||
scaling_alpha = rope_scaling["alpha"]
|
||||
if "alpha" in rope_parameters:
|
||||
scaling_alpha = rope_parameters["alpha"]
|
||||
rotary_emb = DynamicNTKAlphaRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
@@ -169,8 +169,8 @@ def get_rope(
|
||||
scaling_alpha,
|
||||
dtype,
|
||||
)
|
||||
elif "factor" in rope_scaling:
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
elif "factor" in rope_parameters:
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
@@ -185,11 +185,11 @@ def get_rope(
|
||||
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
|
||||
)
|
||||
elif scaling_type == "yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
original_max_position = rope_parameters["original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
for k, v in rope_parameters.items()
|
||||
if k
|
||||
in (
|
||||
"extrapolation_factor",
|
||||
@@ -199,7 +199,7 @@ def get_rope(
|
||||
"apply_yarn_scaling",
|
||||
)
|
||||
}
|
||||
if "mrope_section" in rope_scaling:
|
||||
if "mrope_section" in rope_parameters:
|
||||
extra_kwargs.pop("apply_yarn_scaling", None)
|
||||
rotary_emb = MRotaryEmbedding(
|
||||
head_size,
|
||||
@@ -208,8 +208,8 @@ def get_rope(
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
mrope_section=rope_scaling["mrope_section"],
|
||||
mrope_interleaved=rope_scaling.get("mrope_interleaved", False),
|
||||
mrope_section=rope_parameters["mrope_section"],
|
||||
mrope_interleaved=rope_parameters.get("mrope_interleaved", False),
|
||||
scaling_factor=scaling_factor,
|
||||
**extra_kwargs,
|
||||
)
|
||||
@@ -225,12 +225,12 @@ def get_rope(
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "deepseek_yarn":
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
original_max_position = rope_parameters["original_max_position_embeddings"]
|
||||
# assert max_position == original_max_position * scaling_factor
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
for k, v in rope_parameters.items()
|
||||
if k
|
||||
in (
|
||||
"extrapolation_factor",
|
||||
@@ -252,12 +252,12 @@ def get_rope(
|
||||
**extra_kwargs,
|
||||
)
|
||||
elif scaling_type == "longrope":
|
||||
short_factor = rope_scaling["short_factor"]
|
||||
long_factor = rope_scaling["long_factor"]
|
||||
original_max_position = rope_scaling["original_max_position_embeddings"]
|
||||
short_factor = rope_parameters["short_factor"]
|
||||
long_factor = rope_parameters["long_factor"]
|
||||
original_max_position = rope_parameters["original_max_position_embeddings"]
|
||||
extra_kwargs = {
|
||||
k: v
|
||||
for k, v in rope_scaling.items()
|
||||
for k, v in rope_parameters.items()
|
||||
if k in ("short_mscale", "long_mscale")
|
||||
}
|
||||
rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
|
||||
|
||||
Reference in New Issue
Block a user