[Misc] Standardize RoPE handling for Qwen2-VL (#9250)
This commit is contained in:
@@ -105,7 +105,7 @@ def test_batched_rotary_embedding(
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||
"type": "linear",
|
||||
"rope_type": "linear",
|
||||
"factor": (1, )
|
||||
})
|
||||
rope = rope.to(dtype=dtype)
|
||||
@@ -166,7 +166,7 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
rotary_dim = head_size
|
||||
scaling_factors: List[int] = [1, 2, 4]
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, {
|
||||
"type": "linear",
|
||||
"rope_type": "linear",
|
||||
"factor": tuple(scaling_factors)
|
||||
})
|
||||
rope = rope.to(dtype=dtype)
|
||||
@@ -211,10 +211,10 @@ def test_rope_module_cache():
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
BASES = [10000, 1000000]
|
||||
ROPE_SCALINGS = (None, {
|
||||
"type": "linear",
|
||||
"rope_type": "linear",
|
||||
"factor": (1, )
|
||||
}, {
|
||||
"type": "dynamic",
|
||||
"rope_type": "dynamic",
|
||||
"factor": 1
|
||||
})
|
||||
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
||||
|
||||
Reference in New Issue
Block a user