[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from itertools import accumulate, product
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
|
||||
query,
|
||||
key,
|
||||
offsets=torch.zeros(batch_size * seq_len,
|
||||
dtype=int,
|
||||
dtype=torch.long,
|
||||
device=device))
|
||||
# Compare the results.
|
||||
assert torch.allclose(out_query,
|
||||
@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
|
||||
def test_rope_module_cache():
|
||||
MAX_POSITIONS = [123, 1234]
|
||||
BASES = [10000, 1000000]
|
||||
ROPE_SCALINGS = [
|
||||
None, {
|
||||
"type": "linear",
|
||||
"factor": (1, )
|
||||
}, {
|
||||
"type": "dynamic",
|
||||
"factor": 1
|
||||
}
|
||||
]
|
||||
settings = [
|
||||
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
||||
ROPE_SCALINGS, DTYPES
|
||||
]
|
||||
rope_setting_id_map = {}
|
||||
ROPE_SCALINGS = (None, {
|
||||
"type": "linear",
|
||||
"factor": (1, )
|
||||
}, {
|
||||
"type": "dynamic",
|
||||
"factor": 1
|
||||
})
|
||||
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
|
||||
ROPE_SCALINGS, DTYPES)
|
||||
rope_setting_id_map: Dict[str, int] = {}
|
||||
for setting in product(*settings):
|
||||
head_size, rotary_dim, max_position, base, \
|
||||
is_neox_stype, rope_scaling, dtype = setting
|
||||
|
||||
Reference in New Issue
Block a user