[mypy] Enable type checking for test directory (#5017)

This commit is contained in:
Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions

View File

@@ -1,5 +1,5 @@
from itertools import accumulate, product
from typing import List, Optional
from typing import Dict, List, Optional
import pytest
import torch
@@ -126,7 +126,7 @@ def test_batched_rotary_embedding(
query,
key,
offsets=torch.zeros(batch_size * seq_len,
dtype=int,
dtype=torch.long,
device=device))
# Compare the results.
assert torch.allclose(out_query,
@@ -214,20 +214,16 @@ def test_batched_rotary_embedding_multi_lora(
def test_rope_module_cache():
MAX_POSITIONS = [123, 1234]
BASES = [10000, 1000000]
ROPE_SCALINGS = [
None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
}
]
settings = [
HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES
]
rope_setting_id_map = {}
ROPE_SCALINGS = (None, {
"type": "linear",
"factor": (1, )
}, {
"type": "dynamic",
"factor": 1
})
settings = (HEAD_SIZES, ROTARY_DIMS, MAX_POSITIONS, BASES, IS_NEOX_STYLE,
ROPE_SCALINGS, DTYPES)
rope_setting_id_map: Dict[str, int] = {}
for setting in product(*settings):
head_size, rotary_dim, max_position, base, \
is_neox_stype, rope_scaling, dtype = setting