Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -19,30 +19,33 @@ NUM_HEADS = [17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 8192] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
USE_KEY = [True, False]
|
||||
|
||||
|
||||
def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
def _get_flat_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads * head_size)
|
||||
|
||||
|
||||
# For testing sliced tensors
|
||||
def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
def _get_padded_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size + 64)
|
||||
|
||||
|
||||
def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int,
|
||||
head_size: int) -> tuple[int, ...]:
|
||||
def _get_batch_tensor_shape(
|
||||
batch_size: int, seq_len: int, num_heads: int, head_size: int
|
||||
) -> tuple[int, ...]:
|
||||
return (batch_size, seq_len, num_heads, head_size)
|
||||
|
||||
|
||||
TENSORS_SHAPES_FN = [
|
||||
_get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape
|
||||
_get_batch_tensor_shape,
|
||||
_get_flat_tensor_shape,
|
||||
_get_padded_tensor_shape,
|
||||
]
|
||||
|
||||
|
||||
@@ -97,41 +100,63 @@ def test_rotary_embedding(
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
out_query, out_key = rope.forward(positions, query, key)
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(out_query,
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query))
|
||||
torch.testing.assert_close(
|
||||
out_query,
|
||||
ref_query,
|
||||
atol=get_default_atol(out_query),
|
||||
rtol=get_default_rtol(out_query),
|
||||
)
|
||||
if use_key:
|
||||
torch.testing.assert_close(out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key))
|
||||
torch.testing.assert_close(
|
||||
out_key,
|
||||
ref_key,
|
||||
atol=get_default_atol(out_key),
|
||||
rtol=get_default_rtol(out_key),
|
||||
)
|
||||
else:
|
||||
assert ref_key is None and out_key is None, \
|
||||
"expected returned key to be None"
|
||||
assert ref_key is None and out_key is None, "expected returned key to be None"
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_rope_module_cache():
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
BASES = [10000, 1000000]
|
||||
ROPE_SCALINGS = (None, {
|
||||
"rope_type": "linear",
|
||||
"factor": (1, )
|
||||
}, {
|
||||
"rope_type": "dynamic",
|
||||
"factor": 1
|
||||
})
|
||||
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
||||
ROPE_SCALINGS, DTYPES)
|
||||
ROPE_SCALINGS = (
|
||||
None,
|
||||
{"rope_type": "linear", "factor": (1,)},
|
||||
{"rope_type": "dynamic", "factor": 1},
|
||||
)
|
||||
settings = (
|
||||
HEAD_SIZES,
|
||||
ROTARY_DIMS,
|
||||
MAX_POSITIONS,
|
||||
BASES,
|
||||
IS_NEOX_STYLE,
|
||||
ROPE_SCALINGS,
|
||||
DTYPES,
|
||||
)
|
||||
rope_setting_id_map: dict[str, int] = {}
|
||||
for setting in product(*settings):
|
||||
head_size, rotary_dim, max_position, base, \
|
||||
is_neox_stype, rope_scaling, dtype = setting
|
||||
(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_stype,
|
||||
rope_scaling,
|
||||
dtype,
|
||||
) = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_stype, rope_scaling, dtype)
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_stype,
|
||||
rope_scaling,
|
||||
dtype,
|
||||
)
|
||||
# different settings cannot share the same rope module
|
||||
assert id(rope) not in rope_setting_id_map.values()
|
||||
assert all(x.dtype == dtype for x in rope.buffers())
|
||||
@@ -139,11 +164,25 @@ def test_rope_module_cache():
|
||||
rope_setting_id_map[str(setting)] = id(rope)
|
||||
|
||||
for setting in product(*settings):
|
||||
head_size, rotary_dim, max_position, base, \
|
||||
is_neox_stype, rope_scaling, dtype = setting
|
||||
(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_stype,
|
||||
rope_scaling,
|
||||
dtype,
|
||||
) = setting
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_stype, rope_scaling, dtype)
|
||||
rope = get_rope(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_stype,
|
||||
rope_scaling,
|
||||
dtype,
|
||||
)
|
||||
# check if cache take effect
|
||||
assert id(rope) == rope_setting_id_map[str(setting)]
|
||||
|
||||
Reference in New Issue
Block a user