[Model] Add support for GPT-J (#226)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Andre Slavescu
2023-07-08 20:55:16 -04:00
committed by GitHub
parent 75beba29b5
commit c894836108
10 changed files with 269 additions and 7 deletions

View File

@@ -286,7 +286,7 @@ def test_single_query_cached_kv_attention() -> None:
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for block_size in [8, 16, 32]:
for head_size in [64, 80, 96, 128]:
for head_size in [64, 80, 96, 112, 128, 256]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
@@ -304,7 +304,7 @@ def test_multi_query_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for head_size in [64, 80, 96, 128]:
for head_size in [64, 80, 96, 112, 128, 256]:
print(f'Testing multi_query_kv_attention with dtype={dtype}, '
f'head_size={head_size}')
run_multi_query_kv_attention(