Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
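For reference, a minimal before/after sketch of the call-site change this commit makes to get_rope (assuming vLLM is installed; the import path may differ by version, and the example sizes are illustrative rather than taken from this diff):

from vllm.model_executor.layers.rotary_embedding import get_rope

head_size, rotary_dim, max_position, is_neox_style = 64, 64, 8192, True

# Before this change, the base frequency and scaling config were separate arguments:
# rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)

# After this change, rope_type and rope_theta travel together in one
# rope_parameters dict, mirroring the Transformers v5 naming.
rope_parameters = {"rope_type": "default", "rope_theta": 10000}
rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)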
@@ -74,7 +74,7 @@ def test_rotary_embedding(
     device: str,
     use_key: bool,
     max_position: int = 8192,
-    base: float = 10000,
+    rope_theta: float = 10000,
 ) -> None:
     if rotary_dim is None:
         rotary_dim = head_size
@@ -83,7 +83,8 @@ def test_rotary_embedding(
     torch.set_default_device(device)
     if rotary_dim is None:
         rotary_dim = head_size
-    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
+    rope = get_rope(head_size, rotary_dim, max_position, is_neox_style, rope_parameters)
     rope = rope.to(dtype=dtype, device=torch.get_default_device())

     positions = torch.randint(0, max_position, (batch_size, seq_len))
@@ -120,9 +121,9 @@ def test_rope_module_cache():
 @torch.inference_mode()
 def test_rope_module_cache():
     MAX_POSITIONS = [123, 1234]
-    BASES = [10000, 1000000]
-    ROPE_SCALINGS = (
-        None,
+    ROPE_THETAS = [10000, 1000000]
+    ROPE_PARAMETERS = (
+        {"rope_type": "default"},
         {"rope_type": "linear", "factor": (1,)},
         {"rope_type": "dynamic", "factor": 1},
     )
@@ -130,9 +131,9 @@ def test_rope_module_cache():
         HEAD_SIZES,
         ROTARY_DIMS,
         MAX_POSITIONS,
-        BASES,
+        ROPE_THETAS,
         IS_NEOX_STYLE,
-        ROPE_SCALINGS,
+        ROPE_PARAMETERS,
         DTYPES,
     )
     rope_setting_id_map: dict[str, int] = {}
@@ -141,20 +142,20 @@ def test_rope_module_cache():
             head_size,
             rotary_dim,
             max_position,
-            base,
-            is_neox_stype,
-            rope_scaling,
+            rope_theta,
+            is_neox_style,
+            rope_parameters,
             dtype,
         ) = setting
         if rotary_dim is None:
             rotary_dim = head_size
+        rope_parameters["rope_theta"] = rope_theta
         rope = get_rope(
             head_size,
             rotary_dim,
             max_position,
-            base,
-            is_neox_stype,
-            rope_scaling,
+            is_neox_style,
+            rope_parameters,
             dtype,
         )
         # different settings cannot share the same rope module
@@ -168,20 +169,20 @@ def test_rope_module_cache():
             head_size,
             rotary_dim,
             max_position,
-            base,
-            is_neox_stype,
-            rope_scaling,
+            rope_theta,
+            is_neox_style,
+            rope_parameters,
             dtype,
         ) = setting
         if rotary_dim is None:
             rotary_dim = head_size
+        rope_parameters["rope_theta"] = rope_theta
         rope = get_rope(
             head_size,
             rotary_dim,
             max_position,
-            base,
-            is_neox_stype,
-            rope_scaling,
+            is_neox_style,
+            rope_parameters,
             dtype,
         )
         # check if cache take effect
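Restated outside the pytest parametrization, the caching behavior the second test exercises looks roughly like this (a sketch assuming get_rope caches modules keyed on its full argument set, as the test's comments imply; the import path, sizes, and values are illustrative):

from vllm.model_executor.layers.rotary_embedding import get_rope

params = {"rope_type": "default", "rope_theta": 10000}
rope_a = get_rope(64, 64, 8192, True, params)
rope_b = get_rope(64, 64, 8192, True, params)
# Identical settings: the test expects the cached module to be reused.
assert rope_a is rope_b

rope_c = get_rope(64, 64, 8192, True, {"rope_type": "default", "rope_theta": 1000000})
# A different rope_theta: the test expects a distinct module.
assert rope_c is not rope_a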