[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)
Signed-off-by: Hanzhi Zhou <hanzhi713@gmail.com>
This commit is contained in:
@@ -196,8 +196,8 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
def is_cross_device_reduce_2stage(op_name: str):
|
||||
return "cross_device_reduce_2stage" in op_name
|
||||
|
||||
def is_custom_ar_all_reduce_unreg(op_name: str):
|
||||
return "_C_custom_ar::all_reduce_unreg" in op_name
|
||||
def is_custom_ar_all_reduce(op_name: str):
|
||||
return "_C_custom_ar::all_reduce" in op_name
|
||||
|
||||
def is_reduce_kernel(op_name: str):
|
||||
return "reduce_kernel" in op_name
|
||||
@@ -246,9 +246,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
filter(lambda x: is_cross_device_reduce_2stage(x), ops))
|
||||
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
|
||||
|
||||
custom_ar_all_reduce_unreg_ops = list(
|
||||
filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
|
||||
ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))
|
||||
custom_ar_all_reduce_ops = list(
|
||||
filter(lambda x: is_custom_ar_all_reduce(x), ops))
|
||||
ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))
|
||||
|
||||
reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
|
||||
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
|
||||
@@ -289,21 +289,21 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
|
||||
if len(cross_device_reduce_2stage_ops):
|
||||
trace_df['cross_device_reduce_2stage_ops'] = trace_df[
|
||||
cross_device_reduce_2stage_ops].agg("sum", axis=1)
|
||||
if len(custom_ar_all_reduce_unreg_ops):
|
||||
trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
|
||||
custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
|
||||
if len(custom_ar_all_reduce_ops):
|
||||
trace_df['custom_ar_all_reduce_ops'] = trace_df[
|
||||
custom_ar_all_reduce_ops].agg("sum", axis=1)
|
||||
if len(reduce_kernel_ops):
|
||||
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
|
||||
axis=1)
|
||||
|
||||
trace_df.drop(
|
||||
attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
|
||||
mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
|
||||
nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
|
||||
cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
|
||||
reduce_kernel_ops,
|
||||
axis=1,
|
||||
inplace=True)
|
||||
trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
|
||||
vocab_embed_ops + mem_ops + elementwise_ops +
|
||||
nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
|
||||
nccl_other_ops + cross_device_reduce_1stage_ops +
|
||||
cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
|
||||
reduce_kernel_ops,
|
||||
axis=1,
|
||||
inplace=True)
|
||||
return trace_df
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user