[mypy] Enable type checking for test directory (#5017)

This commit is contained in:
Cyrus Leung
2024-06-15 12:45:31 +08:00
committed by GitHub
parent 1b8a0d71cf
commit 0e9164b40a
92 changed files with 509 additions and 378 deletions

View File

@@ -25,7 +25,7 @@ def ref_paged_attn(
block_tables = block_tables.cpu().numpy()
_, block_size, num_kv_heads, head_size = key_cache.shape
outputs = []
outputs: List[torch.Tensor] = []
start_idx = 0
for i in range(num_seqs):
query_len = query_lens[i]
@@ -70,7 +70,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES)
@torch.inference_mode
def test_flash_attn_with_paged_kv(
kv_lens: List[Tuple[int, int]],
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,