Access partial_rotary_factor from rope_parameters (#29966)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -30,7 +30,6 @@ def get_rope(
|
||||
is_neox_style: bool = True,
|
||||
rope_parameters: dict[str, Any] | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
partial_rotary_factor: float = 1.0,
|
||||
dual_chunk_attention_config: dict[str, Any] | None = None,
|
||||
) -> RotaryEmbedding:
|
||||
if dtype is None:
|
||||
@@ -55,6 +54,10 @@ def get_rope(
|
||||
else:
|
||||
dual_chunk_attention_args = None
|
||||
|
||||
partial_rotary_factor = 1.0
|
||||
if rope_parameters is not None:
|
||||
partial_rotary_factor = rope_parameters.get("partial_rotary_factor", 1.0)
|
||||
|
||||
if partial_rotary_factor < 1.0:
|
||||
rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
key = (
|
||||
|
||||
Reference in New Issue
Block a user