Add CUDA graph-based all reduce launcher (#26)
This commit is contained in:
@@ -22,7 +22,7 @@ class Server:
|
||||
dtype: str,
|
||||
seed: int,
|
||||
swap_space: int,
|
||||
max_batch_size: int,
|
||||
max_num_batched_tokens: int,
|
||||
num_nodes: int,
|
||||
num_devices_per_node: int,
|
||||
distributed_init_method: str,
|
||||
@@ -43,7 +43,7 @@ class Server:
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
)
|
||||
self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
|
||||
max_num_batched_tokens=max_batch_size)
|
||||
max_num_batched_tokens=max_num_batched_tokens)
|
||||
self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
|
||||
swap_space=swap_space)
|
||||
print(f'# GPU blocks: {self.num_gpu_blocks}, '
|
||||
@@ -66,6 +66,7 @@ class Server:
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
model_path=model_path,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
self.controllers.append(controller)
|
||||
|
||||
@@ -75,7 +76,7 @@ class Server:
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=self.num_gpu_blocks,
|
||||
num_cpu_blocks=self.num_cpu_blocks,
|
||||
max_num_batched_tokens=max_batch_size,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
# Connect the controllers.
|
||||
for i in range(len(self.controllers) - 1):
|
||||
@@ -168,8 +169,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
|
||||
help='model path to download and load the weights')
|
||||
# Parallel arguments
|
||||
parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
|
||||
parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
|
||||
parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
|
||||
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
|
||||
# KV cache arguments
|
||||
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
|
||||
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
|
||||
@@ -177,5 +178,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
|
||||
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
|
||||
parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
|
||||
parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
|
||||
return parser
|
||||
|
||||
Reference in New Issue
Block a user