Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim (#30389)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -83,8 +83,12 @@ def test_rotary_embedding(
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
-rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
|
||||
-rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
|
||||
+rope_parameters = {
|
||||
+"rope_type": "default",
|
||||
+"rope_theta": rope_theta,
|
||||
+"partial_rotary_factor": rotary_dim / head_size,
|
||||
+}
|
||||
+rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
|
||||
rope = rope.to(dtype=dtype, device=torch.get_default_device())
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
@@ -150,9 +154,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
+rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
||||
-rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
@@ -177,9 +181,9 @@ def test_rope_module_cache():
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope_parameters["rope_theta"] = rope_theta
|
||||
+rope_parameters["partial_rotary_factor"] = rotary_dim / head_size
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
||||
-rotary_dim,
|
||||
max_position,
|
||||
is_neox_style,
|
||||
rope_parameters,
|
||||
|
||||
Reference in New Issue
Block a user