[Bugfix] fix rope error when load models with different dtypes (#4835)

This commit is contained in:
Jinzhen Lin
2024-05-17 17:43:34 +08:00
committed by GitHub
parent 26148120b3
commit 33e0823de5
2 changed files with 64 additions and 13 deletions

View File

@@ -1,4 +1,4 @@
from itertools import accumulate
from itertools import accumulate, product
from typing import List, Optional
import pytest
@@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
ref_key,
atol=get_default_atol(out_key),
rtol=get_default_rtol(out_key))
@torch.inference_mode()
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = [
None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
}
]
settings = [
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES
]
rope_setting_id_map = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_stype, rope_scaling, dtype)
# different settings cannot share the same rope module
assert id(rope) not in rope_setting_id_map.values()
assert all(x.dtype == dtype for x in rope.buffers())
assert all(x.dtype == dtype for x in rope.parameters())
rope_setting_id_map[str(setting)] = id(rope)
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base,
is_neox_stype, rope_scaling, dtype)
# check if cache take effect
assert id(rope) == rope_setting_id_map[str(setting)]