[Bugfix] Fix RoPE error when loading models with different dtypes (#4835)
@@ -1,4 +1,4 @@
-from itertools import accumulate
+from itertools import accumulate, product
 from typing import List, Optional
 
 import pytest
@@ -207,3 +207,45 @@ def test_batched_rotary_embedding_multi_lora(
                           ref_key,
                           atol=get_default_atol(out_key),
                           rtol=get_default_rtol(out_key))
+
+
+@torch.inference_mode()
+def test_rope_module_cache():
+    MAX_POSITIONS = [123, 1234]
+    BASES = [10000, 1000000]
+    ROPE_SCALINGS = [
+        None, {
+            "type": "linear",
+            "factor": (1, )
+        }, {
+            "type": "dynamic",
+            "factor": 1
+        }
+    ]
+    settings = [
+        HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
+        ROPE_SCALINGS, DTYPES
+    ]
+    rope_setting_id_map = {}
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_style, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, rope_scaling, dtype)
+        # Different settings must not share the same rope module.
+        assert id(rope) not in rope_setting_id_map.values()
+        assert all(x.dtype == dtype for x in rope.buffers())
+        assert all(x.dtype == dtype for x in rope.parameters())
+        rope_setting_id_map[str(setting)] = id(rope)
+
+    for setting in product(*settings):
+        head_size, rotary_dim, max_position, base, \
+            is_neox_style, rope_scaling, dtype = setting
+        if rotary_dim is None:
+            rotary_dim = head_size
+        rope = get_rope(head_size, rotary_dim, max_position, base,
+                        is_neox_style, rope_scaling, dtype)
+        # Check that the cache takes effect for repeated settings.
+        assert id(rope) == rope_setting_id_map[str(setting)]
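The new test only passes if get_rope treats every constructor argument, including dtype, as part of its cache key, so that models loaded with different dtypes never share a RoPE module whose buffers carry the wrong dtype. Below is a minimal sketch of such a dtype-aware cache; the names _SketchRotaryEmbedding, get_rope_sketch, _ROPE_CACHE, and cos_sin_cache are hypothetical illustrations, not vLLM's actual implementation.

    from typing import Dict, Optional, Tuple

    import torch
    import torch.nn as nn


    class _SketchRotaryEmbedding(nn.Module):
        """Hypothetical stand-in for the real rotary-embedding module."""

        def __init__(self, rotary_dim: int, max_position: int, base: int,
                     dtype: torch.dtype) -> None:
            super().__init__()
            inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2).float() /
                                     rotary_dim))
            t = torch.arange(max_position).float()
            freqs = torch.outer(t, inv_freq)
            # The cos/sin table is materialized in the requested dtype, which
            # is what the test checks by inspecting rope.buffers().
            cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(dtype)
            self.register_buffer("cos_sin_cache", cache, persistent=False)


    _ROPE_CACHE: Dict[Tuple, nn.Module] = {}


    def get_rope_sketch(head_size: int, rotary_dim: int, max_position: int,
                        base: int, is_neox_style: bool,
                        rope_scaling: Optional[dict],
                        dtype: torch.dtype) -> nn.Module:
        """Return a cached RoPE module; dtype is part of the cache key."""
        # Freeze the (possibly unhashable) scaling dict so it can be hashed.
        scaling_key = (None if rope_scaling is None else tuple(
            sorted(rope_scaling.items())))
        key = (head_size, rotary_dim, max_position, base, is_neox_style,
               scaling_key, dtype)
        if key not in _ROPE_CACHE:
            _ROPE_CACHE[key] = _SketchRotaryEmbedding(rotary_dim,
                                                      max_position, base,
                                                      dtype)
        return _ROPE_CACHE[key]

Keyed this way, loading a bfloat16 model after a float16 model with otherwise identical RoPE settings yields two distinct modules, rather than reusing a cached module whose buffers have the wrong dtype, which matches what the two loops in the test assert.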