Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-11-19 18:06:36 +01:00
committed by GitHub
parent d44e9df7d4
commit a8b70304d6
104 changed files with 542 additions and 910 deletions

View File

@@ -6,7 +6,7 @@
#
# The CSV file (named with current date/time) contains these columns:
# model_name, tp_size, num_tokens, num_heads, num_kv_heads, head_dim, max_position,
# rope_theta, is_neox_style, rope_scaling, dtype, torch_mean, torch_median, torch_p99,
# is_neox_style, rope_parameters, dtype, torch_mean, torch_median, torch_p99,
# torch_min, torch_max, triton_mean, triton_median, triton_p99, triton_min, triton_max,
# speedup
#
@@ -86,9 +86,8 @@ def benchmark_mrope(
num_heads: int,
num_kv_heads: int,
max_position: int = 8192,
rope_theta: float = 10000,
is_neox_style: bool = True,
rope_scaling: dict[str, Any] = None,
rope_parameters: dict[str, Any] | None = None,
dtype: torch.dtype = torch.bfloat16,
seed: int = 0,
warmup_iter: int = 10,
@@ -102,9 +101,8 @@ def benchmark_mrope(
head_size=head_dim,
rotary_dim=head_dim,
max_position=max_position,
base=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=rope_scaling,
rope_parameters=rope_parameters,
dtype=dtype,
).to(device=device)
@@ -203,9 +201,8 @@ def benchmark_mrope(
num_kv_heads,
head_dim,
max_position,
rope_theta,
is_neox_style,
str(rope_scaling),
str(rope_parameters),
str(dtype).split(".")[-1],
torch_stats["mean"],
torch_stats["median"],
@@ -255,9 +252,8 @@ if __name__ == "__main__":
"num_kv_heads",
"head_dim",
"max_position",
"rope_theta",
"is_neox_style",
"rope_scaling",
"rope_parameters",
"dtype",
"torch_mean",
"torch_median",
@@ -303,7 +299,7 @@ if __name__ == "__main__":
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
is_neox_style = True
rope_theta = config.rope_theta
rope_parameters = config.rope_parameters
max_position = config.max_position_embeddings
for num_tokens in num_tokens_list:
@@ -315,9 +311,8 @@ if __name__ == "__main__":
num_heads=num_heads,
num_kv_heads=num_kv_heads,
max_position=max_position,
rope_theta=rope_theta,
is_neox_style=is_neox_style,
rope_scaling=config.rope_scaling,
rope_parameters=rope_parameters,
dtype=getattr(torch, args.dtype),
seed=args.seed,
warmup_iter=args.warmup_iter,