Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim (#30389)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-12-11 20:45:23 +00:00
committed by GitHub
parent 92fea56fd1
commit cf3eacfe58
83 changed files with 260 additions and 314 deletions

View File

@@ -32,8 +32,8 @@ def get_benchmark(head_size, rotary_dim, is_neox_style, device):
def benchmark(batch_size, seq_len, num_heads, provider):
dtype = torch.bfloat16
max_position = 8192
base = 10000
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope_parameters = {"partial_rotary_factor": rotary_dim / head_size}
rope = get_rope(head_size, max_position, is_neox_style, rope_parameters)
rope = rope.to(dtype=dtype, device=device)
cos_sin_cache = rope.cos_sin_cache.to(dtype=torch.float, device=device)