Add CUDA graph-based all reduce launcher (#26)
This commit is contained in:
@@ -47,6 +47,7 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
|
||||
# Memory buffers to avoid dynamic memory allocation
|
||||
_GLOBAL_MEMORY_BUFFER = None
|
||||
|
||||
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
@@ -205,6 +206,20 @@ def initialize_model_parallel(
|
||||
_set_global_memory_buffer()
|
||||
|
||||
|
||||
def initialize_all_reduce_launcher(
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
disable_graph: bool = False,
|
||||
) -> None:
|
||||
global _ALL_REDUCE_LAUNCHER
|
||||
_ALL_REDUCE_LAUNCHER = GraphAllReduce(
|
||||
max_num_tokens=max_num_tokens,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
disable_graph=disable_graph,
|
||||
)
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if model and data parallel groups are initialized."""
|
||||
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
||||
@@ -491,6 +506,9 @@ def get_global_memory_buffer():
|
||||
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
|
||||
return _GLOBAL_MEMORY_BUFFER
|
||||
|
||||
def get_all_reduce_launcher() -> 'GraphAllReduce':
|
||||
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
|
||||
return _ALL_REDUCE_LAUNCHER
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none."""
|
||||
@@ -520,3 +538,56 @@ def destroy_model_parallel():
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
global _GLOBAL_MEMORY_BUFFER
|
||||
_GLOBAL_MEMORY_BUFFER = None
|
||||
|
||||
|
||||
class GraphAllReduce:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
disable_graph: bool = False,
|
||||
) -> None:
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.disable_graph = disable_graph
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
if tp_world_size == 1:
|
||||
return
|
||||
|
||||
self.group = get_tensor_model_parallel_group()
|
||||
self.buffer = torch.empty(
|
||||
size=(max_num_tokens, hidden_size),
|
||||
dtype=dtype,
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
# Build graphs for different number of tokens.
|
||||
if not self.disable_graph:
|
||||
self.graphs = {}
|
||||
for num_tokens in range(8, max_num_tokens + 1, 8):
|
||||
self.graphs[num_tokens] = self._build_graph(num_tokens)
|
||||
|
||||
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
|
||||
# Warm up.
|
||||
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Build graph.
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
torch.distributed.all_reduce(
|
||||
self.buffer[:num_tokens], group=self.group)
|
||||
torch.cuda.synchronize()
|
||||
return graph
|
||||
|
||||
def launch(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: x must be a slice of self.buffer.
|
||||
num_tokens = x.shape[0]
|
||||
if self.disable_graph:
|
||||
torch.distributed.all_reduce(x, group=self.group)
|
||||
else:
|
||||
self.graphs[num_tokens].replay()
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user