Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -88,8 +88,8 @@ class NemotronConfig(PretrainedConfig):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_parameters (`dict`, *optional*):
|
||||
The parameters of the RoPE embeddings.
|
||||
partial_rotary_factor (`float`, *optional*, defaults to 0.5):
|
||||
Percentage of the query and keys which will have rotary embedding.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
@@ -132,8 +132,7 @@ class NemotronConfig(PretrainedConfig):
|
||||
bos_token_id=2,
|
||||
eos_token_id=3,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
rope_parameters=None,
|
||||
partial_rotary_factor=0.5,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
@@ -160,8 +159,13 @@ class NemotronConfig(PretrainedConfig):
|
||||
self.initializer_range = initializer_range
|
||||
self.norm_eps = norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
|
||||
rope_scaling = kwargs.pop("rope_scaling", None)
|
||||
rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
|
||||
rope_theta = kwargs.pop("rope_theta", 10000.0)
|
||||
if "rope_theta" not in rope_parameters:
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
self.rope_parameters = rope_parameters
|
||||
# for backward compatibility
|
||||
partial_rotary_factor = (
|
||||
kwargs.get("rope_percent")
|
||||
@@ -169,7 +173,7 @@ class NemotronConfig(PretrainedConfig):
|
||||
or partial_rotary_factor
|
||||
)
|
||||
self.partial_rotary_factor = partial_rotary_factor
|
||||
self._rope_scaling_validation()
|
||||
self._rope_parameters_validation()
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
self.mlp_bias = mlp_bias
|
||||
@@ -182,31 +186,29 @@ class NemotronConfig(PretrainedConfig):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
def _rope_parameters_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
Validate the `rope_parameters` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
if self.rope_parameters is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
rope_type: str | None = self.rope_parameters.get("rope_type", None)
|
||||
factor: float | None = self.rope_parameters.get("factor", None)
|
||||
|
||||
if rope_type not in {"default", "linear", "dynamic"}:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with two fields, "
|
||||
f"`type` and `factor`, got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s type field must be one of ['linear', "
|
||||
f"'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if (
|
||||
rope_scaling_factor is None
|
||||
or not isinstance(rope_scaling_factor, float)
|
||||
or rope_scaling_factor <= 1.0
|
||||
):
|
||||
raise ValueError(
|
||||
"`rope_scaling`'s factor field must be a float > 1, got "
|
||||
f"{rope_scaling_factor}"
|
||||
"`rope_type` must be one of ['default', 'linear', 'dynamic'], "
|
||||
f"got {rope_type}"
|
||||
)
|
||||
if rope_type != "default":
|
||||
if factor is None:
|
||||
raise ValueError(
|
||||
"If `rope_type` is not 'default', `rope_parameters` "
|
||||
"must include a `factor` field. Got `None`."
|
||||
)
|
||||
if not isinstance(factor, float) or factor <= 1.0:
|
||||
raise ValueError(
|
||||
"`rope_parameters`'s factor field must be a float > 1, got "
|
||||
f"{factor}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user