Add CUDA graph-based all reduce launcher (#26)

Woosuk Kwon
2023-04-05 11:16:57 -07:00
committed by GitHub
parent 21b3671bbc
commit 12659a0bd7
7 changed files with 103 additions and 16 deletions


@@ -47,6 +47,7 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
@@ -205,6 +206,20 @@ def initialize_model_parallel(
    _set_global_memory_buffer()


def initialize_all_reduce_launcher(
    max_num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    disable_graph: bool = False,
) -> None:
    global _ALL_REDUCE_LAUNCHER
    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
        max_num_tokens=max_num_tokens,
        hidden_size=hidden_size,
        dtype=dtype,
        disable_graph=disable_graph,
    )
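As an aside (not part of this diff): a worker would typically call this right after model-parallel initialization, sizing the buffer for the largest batch it expects to serve. A minimal sketch, assuming a hypothetical startup path and illustrative parameter values:

    # Hypothetical worker startup (names and values are illustrative only).
    initialize_model_parallel(tensor_model_parallel_size=2)
    initialize_all_reduce_launcher(
        max_num_tokens=2048,   # upper bound on tokens per forward pass
        hidden_size=4096,      # model hidden dimension
        dtype=torch.float16,
    )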
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@@ -491,6 +506,9 @@ def get_global_memory_buffer():
    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
    return _GLOBAL_MEMORY_BUFFER


def get_all_reduce_launcher() -> 'GraphAllReduce':
    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
    return _ALL_REDUCE_LAUNCHER


def destroy_model_parallel():
    """Set the groups to none."""
@@ -520,3 +538,56 @@ def destroy_model_parallel():
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None

class GraphAllReduce:

    def __init__(
        self,
        max_num_tokens: int,
        hidden_size: int,
        dtype: torch.dtype,
        disable_graph: bool = False,
    ) -> None:
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.disable_graph = disable_graph

        tp_world_size = get_tensor_model_parallel_world_size()
        if tp_world_size == 1:
            # No tensor parallelism, so no all-reduce is needed.
            return

        self.group = get_tensor_model_parallel_group()
        # Static buffer that serves as the input/output of the captured all-reduce.
        self.buffer = torch.empty(
            size=(max_num_tokens, hidden_size),
            dtype=dtype,
            device='cuda',
        )

        # Build graphs for different numbers of tokens
        # (one graph per multiple of 8, up to max_num_tokens).
        if not self.disable_graph:
            self.graphs = {}
            for num_tokens in range(8, max_num_tokens + 1, 8):
                self.graphs[num_tokens] = self._build_graph(num_tokens)

    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
        # Warm up.
        torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
        torch.cuda.synchronize()

        # Capture the all-reduce on the buffer slice into a CUDA graph.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            torch.distributed.all_reduce(
                self.buffer[:num_tokens], group=self.group)
        torch.cuda.synchronize()
        return graph

    def launch(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: x must be a slice of self.buffer.
        num_tokens = x.shape[0]
        if self.disable_graph:
            torch.distributed.all_reduce(x, group=self.group)
        else:
            # Replay the pre-captured graph, which all-reduces
            # self.buffer[:num_tokens] in place.
            self.graphs[num_tokens].replay()
        return x
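For context, a usage sketch (not part of this commit) of how a tensor-parallel layer might consume the launcher: it stages its partial output in a slice of the shared buffer and calls launch, which either replays the captured graph or falls back to an eager all-reduce. With graphs enabled, the token count must match one of the captured sizes (a multiple of 8 up to max_num_tokens). Names such as partial_out and num_tokens below are illustrative.

    # Hypothetical call site inside a model's forward pass.
    launcher = get_all_reduce_launcher()
    buffer = launcher.buffer[:num_tokens]  # must be a slice of the launcher's buffer
    buffer.copy_(partial_out)              # stage the partial result in the static buffer
    out = launcher.launch(buffer)          # all-reduce across the tensor-parallel group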