[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -25,7 +25,7 @@ def ref_paged_attn(
|
||||
block_tables = block_tables.cpu().numpy()
|
||||
_, block_size, num_kv_heads, head_size = key_cache.shape
|
||||
|
||||
outputs = []
|
||||
outputs: List[torch.Tensor] = []
|
||||
start_idx = 0
|
||||
for i in range(num_seqs):
|
||||
query_len = query_lens[i]
|
||||
@@ -70,7 +70,7 @@ def ref_paged_attn(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@torch.inference_mode
|
||||
def test_flash_attn_with_paged_kv(
|
||||
kv_lens: List[Tuple[int, int]],
|
||||
kv_lens: List[int],
|
||||
num_heads: Tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
|
||||
Reference in New Issue
Block a user