Add CUDA graph-based all reduce launcher (#26)

This commit is contained in:
Woosuk Kwon
2023-04-05 11:16:57 -07:00
committed by GitHub
parent 21b3671bbc
commit 12659a0bd7
7 changed files with 103 additions and 16 deletions

View File

@@ -27,6 +27,7 @@ class Controller:
dtype: str,
seed: int,
model_path: str,
max_num_batched_tokens: int,
) -> None:
self.stage_id = stage_id
self.stage_devices = stage_devices
@@ -57,6 +58,7 @@ class Controller:
tensor_parallel_size=tensor_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
model_path=model_path,
max_num_batched_tokens=max_num_batched_tokens,
)
self.workers.append(worker)