[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import random
|
||||
-from typing import Tuple
|
||||
+from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -63,7 +63,7 @@ def test_copy_blocks(
|
||||
src_blocks = random.sample(range(num_blocks), num_mappings)
|
||||
remainig_blocks = list(set(range(num_blocks)) - set(src_blocks))
|
||||
dst_blocks = random.sample(remainig_blocks, 2 * num_mappings)
|
||||
-block_mapping = []
|
||||
+block_mapping: List[Tuple[int, int]] = []
|
||||
for i in range(num_mappings):
|
||||
src = src_blocks[i]
|
||||
dst1 = dst_blocks[2 * i]
|
||||
@@ -131,8 +131,8 @@ def test_reshape_and_cache(
|
||||
torch.set_default_device(device)
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
-slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
-slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)
|
||||
+slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||
+slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long)
|
||||
|
||||
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
@@ -170,12 +170,12 @@ def test_reshape_and_cache(
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
|
||||
-block_indicies = block_indicies.cpu().tolist()
|
||||
+block_indicies_lst = block_indicies.cpu().tolist()
|
||||
block_offsets = slot_mapping % block_size
|
||||
-block_offsets = block_offsets.cpu().tolist()
|
||||
+block_offsets_lst = block_offsets.cpu().tolist()
|
||||
for i in range(num_tokens):
|
||||
-block_idx = block_indicies[i]
|
||||
-block_offset = block_offsets[i]
|
||||
+block_idx = block_indicies_lst[i]
|
||||
+block_offset = block_offsets_lst[i]
|
||||
cloned_key_cache[block_idx, :, :, block_offset, :] = reshaped_key[i]
|
||||
cloned_value_cache[block_idx, :, :, block_offset] = value[i]
|
||||
|
||||
@@ -224,8 +224,10 @@ def test_reshape_and_cache_flash(
|
||||
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
-slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
-slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)
|
||||
+slot_mapping_lst = random.sample(range(num_slots), num_tokens)
|
||||
+slot_mapping = torch.tensor(slot_mapping_lst,
|
||||
+dtype=torch.long,
|
||||
+device=device)
|
||||
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
@@ -257,13 +259,13 @@ def test_reshape_and_cache_flash(
|
||||
slot_mapping, kv_cache_dtype)
|
||||
|
||||
# Run the reference implementation.
|
||||
-block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||
-block_indicies = block_indicies.cpu().tolist()
|
||||
+block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
|
||||
+block_indicies_lst = block_indicies.cpu().tolist()
|
||||
block_offsets = slot_mapping % block_size
|
||||
-block_offsets = block_offsets.cpu().tolist()
|
||||
+block_offsets_lst = block_offsets.cpu().tolist()
|
||||
for i in range(num_tokens):
|
||||
-block_idx = block_indicies[i]
|
||||
-block_offset = block_offsets[i]
|
||||
+block_idx = block_indicies_lst[i]
|
||||
+block_offset = block_offsets_lst[i]
|
||||
cloned_key_cache[block_idx, block_offset, :, :] = key[i]
|
||||
cloned_value_cache[block_idx, block_offset, :, :] = value[i]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user