[Bugfix] Use 'sum' reduction instead of 'avg' in Async TP reduce-scatter (#33088)
Signed-off-by: Xingran Wang <wangxingran123456@outlook.com>
Signed-off-by: Hongjian Zhang <hirokenovo@gmail.com>
Co-authored-by: Hongjian Zhang <hirokenovo@gmail.com>
This commit is contained in:
@@ -53,7 +53,7 @@ class GEMMReduceScatterPattern(BasePattern):
|
|||||||
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
gemm_rs = torch.ops.symm_mem.fused_matmul_reduce_scatter(
|
||||||
mul,
|
mul,
|
||||||
mm_weight,
|
mm_weight,
|
||||||
"avg",
|
"sum",
|
||||||
scatter_dim=0,
|
scatter_dim=0,
|
||||||
group_name=self.tp.device_group.group_name,
|
group_name=self.tp.device_group.group_name,
|
||||||
)
|
)
|
||||||
@@ -150,7 +150,7 @@ class ScaledMMReduceScatterPattern(BasePattern):
|
|||||||
mat2,
|
mat2,
|
||||||
scale_a,
|
scale_a,
|
||||||
scale_b,
|
scale_b,
|
||||||
"avg",
|
"sum",
|
||||||
scatter_dim, # orig_scatter_dim
|
scatter_dim, # orig_scatter_dim
|
||||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||||
self.tp.device_group.group_name,
|
self.tp.device_group.group_name,
|
||||||
@@ -285,7 +285,7 @@ class CutlassScaledMMReduceScatterPattern(BasePattern):
|
|||||||
mat2,
|
mat2,
|
||||||
scale_a,
|
scale_a,
|
||||||
scale_b,
|
scale_b,
|
||||||
"avg",
|
"sum",
|
||||||
scatter_dim, # orig_scatter_dim
|
scatter_dim, # orig_scatter_dim
|
||||||
scatter_dim, # scatter_dim_after_maybe_reshape
|
scatter_dim, # scatter_dim_after_maybe_reshape
|
||||||
self.tp.device_group.group_name,
|
self.tp.device_group.group_name,
|
||||||
|
|||||||
Reference in New Issue
Block a user