Add CUDA graph-based all reduce launcher (#26)
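The diff below shows only the call sites in RowParallelLinear; the launcher itself is not part of this change. As a rough, hypothetical sketch of the technique named in the title: a CUDA graph-based all-reduce launcher preallocates a static buffer, captures one in-place all-reduce on that buffer into a CUDA graph, and replays the graph on every forward pass to skip per-step kernel launch overhead. The class below is an illustration under those assumptions, not the actual cacheflow implementation; the class name, the constructor arguments max_num_tokens/hidden_size, and the warm-up details are all guesses (the diff confirms only the buffer attribute, the launch() method, and the get_all_reduce_launcher() accessor).

import torch
import torch.distributed as dist

class AllReduceLauncher:
    """Hypothetical sketch of a CUDA graph-based all-reduce launcher.

    Callers write matmul results directly into a slice of `buffer`, then
    call `launch()`, which replays a pre-captured graph instead of issuing
    a fresh all-reduce kernel launch on every step.
    """

    def __init__(self, max_num_tokens: int, hidden_size: int,
                 dtype: torch.dtype = torch.half) -> None:
        # CUDA graphs require stable memory addresses, so one static buffer
        # is allocated up front and sliced per batch (buffer[:num_tokens]).
        self.buffer = torch.empty(
            max_num_tokens, hidden_size, dtype=dtype, device='cuda')
        # Warm up the collective on a side stream before capture, per the
        # torch.cuda.graph capture rules.
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(stream):
            dist.all_reduce(self.buffer)
        torch.cuda.current_stream().wait_stream(stream)
        # Capture a single in-place all-reduce over the static buffer.
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            dist.all_reduce(self.buffer)

    def launch(self, output_buffer: torch.Tensor) -> torch.Tensor:
        # `output_buffer` must be a slice of `self.buffer`. Replaying the
        # graph all-reduces the whole static buffer in place; rows past the
        # slice hold stale data and are simply ignored by the caller.
        self.graph.replay()
        return output_buffer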
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_all_reduce_launcher,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
@@ -407,8 +408,7 @@ class RowParallelLinear(torch.nn.Module):
             self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-
-
+        self.weight_t = self.weight.t()
 
     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -425,11 +425,18 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        if get_tensor_model_parallel_world_size() == 1:
+            # Matrix multiply.
+            output_ = F.linear(input_parallel, self.weight)
+        else:
+            # Matrix multiply.
+            all_reduce_launcher = get_all_reduce_launcher()
+            num_tokens = input_parallel.shape[0]
+            output_buffer = all_reduce_launcher.buffer[:num_tokens]
+            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
+            # All-reduce across all the partitions.
+            output_ = all_reduce_launcher.launch(output_buffer)
 
-        # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
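A note on the new self.weight_t attribute: F.linear(x, w) computes x @ w.t() but exposes no out= argument, so it cannot write its result into the launcher's static graph buffer. torch.matmul does accept out=, which is presumably why the tensor-parallel branch pre-transposes the weight once in __init__ and calls torch.matmul(input_parallel, self.weight_t, out=output_buffer). A minimal standalone equivalence check (shapes are illustrative, not from this diff):

import torch
import torch.nn.functional as F

x = torch.randn(4, 8)      # [num_tokens, in_features]
w = torch.randn(16, 8)     # [out_features, in_features], nn.Linear layout
out = torch.empty(4, 16)   # preallocated destination, e.g. a buffer slice

y = F.linear(x, w)                  # y = x @ w.t(); no out= available
torch.matmul(x, w.t(), out=out)     # same result, written into `out`
assert torch.allclose(y, out)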