Support block size 32 (#35)
@@ -350,7 +350,7 @@ def test_attention(seed: int) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
@@ -368,7 +368,7 @@ def test_attention(seed: int) -> None:
     # Note that the test is also more likely to fail because the much
     # larger number of tokens in the input may increase the variance.
     for dtype in [torch.half, torch.float]:
-        for block_size in [8, 16]:
+        for block_size in [8, 16, 32]:
             for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
                 print(f'Testing multi_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
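
Both hunks extend the same three-axis sweep (dtype x block_size x head_size) by a single new block_size value. For illustration only, here is a minimal sketch of an equivalent sweep written with itertools.product; the sweep helper and its test_fn parameter are hypothetical names for this sketch, not code from the repository.

import itertools

import torch

DTYPES = [torch.half, torch.float]
BLOCK_SIZES = [8, 16, 32]  # 32 is the value this commit adds
HEAD_SIZES = [32, 64, 80, 96, 128, 160, 192, 256]

def sweep(test_fn) -> None:
    # Run test_fn once per (dtype, block_size, head_size) combination,
    # printing the same style of progress line the tests use.
    for dtype, block_size, head_size in itertools.product(
            DTYPES, BLOCK_SIZES, HEAD_SIZES):
        print(f'Testing {test_fn.__name__} with '
              f'dtype={dtype}, block_size={block_size}, '
              f'head_size={head_size}')
        test_fn(dtype=dtype, block_size=block_size, head_size=head_size)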