Support block size 32 (#35)

This commit is contained in:
Woosuk Kwon
2023-04-09 23:07:18 -07:00
committed by GitHub
parent ee88a7e5f3
commit b9926f7f66
4 changed files with 49 additions and 5 deletions

View File

@@ -15,9 +15,9 @@ class BlockAllocator:
block_size: int,
num_blocks: int,
) -> None:
if block_size not in [8, 16]:
if block_size not in [8, 16, 32]:
raise ValueError(f'Unsupported block size: {block_size}'
'The block size must be either 8 or 16.')
'The block size must be one of {8, 16, 32}.')
self.device = device
self.block_size = block_size
self.num_blocks = num_blocks

View File

@@ -174,7 +174,7 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16, 32], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).