Access partial_rotary_factor from rope_parameters (#29966)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-12-04 18:42:49 +00:00
committed by GitHub
parent ece2825a29
commit e10c84e06a
21 changed files with 43 additions and 62 deletions

@@ -89,9 +89,14 @@ class NemotronConfig(PretrainedConfig):
         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_parameters (`dict`, *optional*):
-            The parameters of the RoPE embeddings.
-        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
-            Percentage of the query and keys which will have rotary embedding.
+            The parameters of the RoPE embeddings. Expected contents:
+                `rope_theta` (`float`): The base period of the RoPE embeddings.
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear',
+                    'dynamic', 'yarn', 'longrope', 'llama3'], with 'default' being the
+                    original RoPE implementation.
+                `partial_rotary_factor` (`float`, *optional*, defaults to 0.5):
+                    Percentage of the query and keys which will have rotary embedding.
         attention_bias (`bool`, *optional*, defaults to `False`):
             Whether to use a bias in the query, key, value and output
             projection layers during self-attention.
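
As an illustration (not part of the diff), a minimal sketch of how the documented dict would be passed after this change, assuming a transformers build that includes this commit; the concrete values are hypothetical:

    from transformers import NemotronConfig

    # All three documented keys now live in one dict; partial_rotary_factor
    # is no longer a standalone constructor argument.
    config = NemotronConfig(
        rope_parameters={
            "rope_theta": 10000.0,         # base period of the RoPE embeddings
            "rope_type": "default",        # one of the listed sub-variants
            "partial_rotary_factor": 0.5,  # fraction of query/key dims rotated
        }
    )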
@@ -133,7 +138,6 @@ class NemotronConfig(PretrainedConfig):
         eos_token_id=3,
         tie_word_embeddings=False,
         rope_parameters=None,
-        partial_rotary_factor=0.5,
         attention_bias=False,
         attention_dropout=0.0,
         mlp_bias=False,
@@ -165,14 +169,16 @@
         rope_theta = kwargs.pop("rope_theta", 10000.0)
         if "rope_theta" not in rope_parameters:
             rope_parameters["rope_theta"] = rope_theta
-        self.rope_parameters = rope_parameters
         # for backward compatibility
         partial_rotary_factor = (
             kwargs.get("rope_percent")
             or kwargs.get("rope_percentage")
-            or partial_rotary_factor
+            or kwargs.get("partial_rotary_factor")
+            or 0.5
         )
-        self.partial_rotary_factor = partial_rotary_factor
+        if "partial_rotary_factor" not in rope_parameters:
+            rope_parameters["partial_rotary_factor"] = partial_rotary_factor
+        self.rope_parameters = rope_parameters
         self._rope_parameters_validation()
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
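
A hedged sketch of what the backward-compatibility chain above implies for callers (behaviour inferred from the diff; the legacy spellings checked, in order, are `rope_percent` and `rope_percentage`):

    from transformers import NemotronConfig

    # Legacy kwarg: absorbed via **kwargs and folded into rope_parameters.
    legacy = NemotronConfig(rope_percent=0.25)
    assert legacy.rope_parameters["partial_rotary_factor"] == 0.25

    # An explicit value inside rope_parameters wins over any legacy kwarg,
    # since the fold-in only runs when the key is absent from the dict.
    explicit = NemotronConfig(
        rope_parameters={"rope_type": "default", "partial_rotary_factor": 0.75},
        rope_percent=0.25,
    )
    assert explicit.rope_parameters["partial_rotary_factor"] == 0.75

Note that the resolution uses an `or` chain, so a falsy legacy value such as 0.0 would fall through to the 0.5 default rather than being kept.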