Support FP32 (#141)
@@ -270,9 +270,9 @@ def run_multi_query_kv_attention(
 def test_single_query_cached_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for block_size in [8, 16, 32, 64]:
-            for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for block_size in [8, 16, 32]:
+            for head_size in [64, 80, 96, 128]:
                 print(f'Testing single_query_cached_kv_attention with '
                       f'dtype={dtype}, block_size={block_size}, '
                       f'head_size={head_size}')
@@ -289,8 +289,8 @@ def test_single_query_cached_kv_attention() -> None:
 def test_multi_query_kv_attention() -> None:
     torch.random.manual_seed(TEST_SEED)
     torch.cuda.manual_seed(TEST_SEED)
-    for dtype in [torch.half, torch.bfloat16]:
-        for head_size in [32, 64, 80, 96, 128, 160, 192, 256]:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for head_size in [64, 80, 96, 128]:
         print(f'Testing multi_query_kv_attention with dtype={dtype}, '
               f'head_size={head_size}')
         run_multi_query_kv_attention(
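
With torch.float added to the sweep, tests of this kind typically check kernel output against a float32 reference using dtype-dependent tolerances. A minimal sketch of that pattern, assuming a hypothetical helper (the helper name and tolerance values below are illustrative assumptions, not part of this commit):

import torch

# Hypothetical helper (not from this commit): compare a kernel output in
# half/bfloat16/float against a float32 reference, loosening the absolute
# tolerance for the lower-precision dtypes.
_ATOL = {torch.float: 1e-5, torch.half: 1e-3, torch.bfloat16: 1e-2}

def assert_close_to_fp32_ref(out: torch.Tensor, ref_fp32: torch.Tensor) -> None:
    assert torch.allclose(out.float(), ref_fp32, atol=_ATOL[out.dtype])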