[mypy] Enable type checking for test directory (#5017)
This commit is contained in:
@@ -72,27 +72,27 @@ def ref_single_query_cached_kv_attention(
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs = query.shape[0]
|
||||
|
||||
block_tables = block_tables.cpu().tolist()
|
||||
seq_lens = seq_lens.cpu().tolist()
|
||||
block_tables_lst = block_tables.cpu().tolist()
|
||||
seq_lens_lst = seq_lens.cpu().tolist()
|
||||
for i in range(num_seqs):
|
||||
q = query[i].unsqueeze(0)
|
||||
block_table = block_tables[i]
|
||||
seq_len = int(seq_lens[i])
|
||||
block_table = block_tables_lst[i]
|
||||
seq_len = int(seq_lens_lst[i])
|
||||
|
||||
keys = []
|
||||
values = []
|
||||
keys_lst: List[torch.Tensor] = []
|
||||
values_lst: List[torch.Tensor] = []
|
||||
for j in range(seq_len):
|
||||
block_number = int(block_table[j // block_size])
|
||||
block_offset = j % block_size
|
||||
|
||||
k = key_cache[block_number, :, :, block_offset, :]
|
||||
k = k.reshape(num_kv_heads, head_size)
|
||||
keys.append(k)
|
||||
keys_lst.append(k)
|
||||
|
||||
v = value_cache[block_number, :, :, block_offset]
|
||||
values.append(v)
|
||||
keys = torch.stack(keys, dim=0)
|
||||
values = torch.stack(values, dim=0)
|
||||
values_lst.append(v)
|
||||
keys = torch.stack(keys_lst, dim=0)
|
||||
values = torch.stack(values_lst, dim=0)
|
||||
if num_queries_per_kv > 1:
|
||||
# Handle MQA and GQA
|
||||
keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1)
|
||||
@@ -157,14 +157,15 @@ def test_paged_attention(
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
|
||||
block_tables = []
|
||||
block_tables_lst: List[List[int]] = []
|
||||
for _ in range(num_seqs):
|
||||
block_table = [
|
||||
random.randint(0, NUM_BLOCKS - 1)
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int)
|
||||
block_tables_lst.append(block_table)
|
||||
|
||||
block_tables = torch.tensor(block_tables_lst, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
@@ -283,7 +284,7 @@ def ref_multi_query_kv_attention(
|
||||
dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
num_seqs = len(cu_seq_lens) - 1
|
||||
ref_outputs = []
|
||||
ref_outputs: List[torch.Tensor] = []
|
||||
for i in range(num_seqs):
|
||||
start_idx = cu_seq_lens[i]
|
||||
end_idx = cu_seq_lens[i + 1]
|
||||
@@ -303,8 +304,8 @@ def ref_multi_query_kv_attention(
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
ref_outputs.append(ref_output)
|
||||
ref_output = torch.cat(ref_outputs, dim=0)
|
||||
return ref_output
|
||||
|
||||
return torch.cat(ref_outputs, dim=0)
|
||||
|
||||
|
||||
# TODO(woosuk): Add tests for USE_ALIBI=True.
|
||||
|
||||
Reference in New Issue
Block a user