Add CUDA graph-based all reduce launcher (#26)

This commit is contained in:
Woosuk Kwon
2023-04-05 11:16:57 -07:00
committed by GitHub
parent 21b3671bbc
commit 12659a0bd7
7 changed files with 103 additions and 16 deletions

View File

@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_all_reduce_launcher,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
@@ -407,8 +408,7 @@ class RowParallelLinear(torch.nn.Module):
self.bias.zero_()
else:
self.register_parameter('bias', None)
self.weight_t = self.weight.t()
def forward(self, input_):
"""Forward of RowParallelLinear
@@ -425,11 +425,18 @@ class RowParallelLinear(torch.nn.Module):
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
if get_tensor_model_parallel_world_size() == 1:
# Matrix multiply.
output_ = F.linear(input_parallel, self.weight)
else:
# Matrix multiply.
all_reduce_launcher = get_all_reduce_launcher()
num_tokens = input_parallel.shape[0]
output_buffer = all_reduce_launcher.buffer[:num_tokens]
torch.matmul(input_parallel, self.weight_t, out=output_buffer)
# All-reduce across all the partitions.
output_ = all_reduce_launcher.launch(output_buffer)
# All-reduce across all the partitions.
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None