Support block size 32 (#35)

Woosuk Kwon
2023-04-09 23:07:18 -07:00
committed by GitHub
parent ee88a7e5f3
commit b9926f7f66
4 changed files with 49 additions and 5 deletions
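
For context, the block size sets how many tokens each KV-cache block holds, so the number of blocks a sequence needs follows directly from it. A minimal sketch (the helper name is illustrative, not taken from this commit):

import math

def num_cache_blocks(seq_len: int, block_size: int) -> int:
    # Each block stores the keys/values for `block_size` tokens, so a
    # sequence of `seq_len` tokens occupies ceil(seq_len / block_size) blocks.
    return math.ceil(seq_len / block_size)

# With the newly supported block size of 32, a 100-token sequence
# spans 4 blocks, the last one only partially filled.
assert num_cache_blocks(100, 32) == 4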


@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
     # note that the test is also more likely to fail because the much
     # larger number of tokens in the input may increase the variance.
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing multi_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
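
The variance concern in the comment above usually surfaces as dtype-dependent tolerances when the kernel output is checked against a reference. A sketch of that pattern (helper name and tolerance values are illustrative, not taken from this test):

import torch

def check_close(out: torch.Tensor, ref: torch.Tensor) -> None:
    # Looser tolerances for half precision: with many input tokens, the
    # accumulated rounding error in the attention reduction grows.
    atol = 1e-3 if out.dtype == torch.half else 1e-5
    torch.testing.assert_close(out, ref, atol=atol, rtol=1e-3)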