Access partial_rotary_factor from rope_parameters (#29966)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-12-04 18:42:49 +00:00
committed by GitHub
parent ece2825a29
commit e10c84e06a
21 changed files with 43 additions and 62 deletions

View File

@@ -30,7 +30,6 @@ def get_rope(
is_neox_style: bool = True,
rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype | None = None,
partial_rotary_factor: float = 1.0,
dual_chunk_attention_config: dict[str, Any] | None = None,
) -> RotaryEmbedding:
if dtype is None:
@@ -55,6 +54,10 @@ def get_rope(
else:
dual_chunk_attention_args = None
partial_rotary_factor = 1.0
if rope_parameters is not None:
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
if partial_rotary_factor < 1.0:
rotary_dim = int(rotary_dim * partial_rotary_factor)
key = (