Add CUDA graph-based all reduce launcher (#26)

This commit is contained in:
Woosuk Kwon
2023-04-05 11:16:57 -07:00
committed by GitHub
parent 21b3671bbc
commit 12659a0bd7
7 changed files with 103 additions and 16 deletions

View File

@@ -22,7 +22,7 @@ class Server:
dtype: str,
seed: int,
swap_space: int,
max_batch_size: int,
max_num_batched_tokens: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
@@ -43,7 +43,7 @@ class Server:
tensor_parallel_size=tensor_parallel_size,
)
self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=max_batch_size)
max_num_batched_tokens=max_num_batched_tokens)
self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
swap_space=swap_space)
print(f'# GPU blocks: {self.num_gpu_blocks}, '
@@ -66,6 +66,7 @@ class Server:
dtype=dtype,
seed=seed,
model_path=model_path,
max_num_batched_tokens=max_num_batched_tokens,
)
self.controllers.append(controller)
@@ -75,7 +76,7 @@ class Server:
block_size=block_size,
num_gpu_blocks=self.num_gpu_blocks,
num_cpu_blocks=self.num_cpu_blocks,
max_num_batched_tokens=max_batch_size,
max_num_batched_tokens=max_num_batched_tokens,
)
# Connect the controllers.
for i in range(len(self.controllers) - 1):
@@ -168,8 +169,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
help='model path to download and load the weights')
# Parallel arguments
parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
@@ -177,5 +178,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
return parser