Add CUDA graph-based all reduce launcher (#26)
This commit is contained in:
@@ -27,6 +27,7 @@ class Controller:
|
||||
dtype: str,
|
||||
seed: int,
|
||||
model_path: str,
|
||||
max_num_batched_tokens: int,
|
||||
) -> None:
|
||||
self.stage_id = stage_id
|
||||
self.stage_devices = stage_devices
|
||||
@@ -57,6 +58,7 @@ class Controller:
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
model_path=model_path,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user