[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)

Signed-off-by: Hanzhi Zhou <hanzhi713@gmail.com>
This commit is contained in:
Hanzhi Zhou
2024-11-06 23:50:47 -08:00
committed by GitHub
parent d7263a1bb8
commit 6192e9b8fe
9 changed files with 218 additions and 260 deletions

View File

@@ -196,8 +196,8 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
def is_cross_device_reduce_2stage(op_name: str):
return "cross_device_reduce_2stage" in op_name
def is_custom_ar_all_reduce_unreg(op_name: str):
return "_C_custom_ar::all_reduce_unreg" in op_name
def is_custom_ar_all_reduce(op_name: str):
return "_C_custom_ar::all_reduce" in op_name
def is_reduce_kernel(op_name: str):
return "reduce_kernel" in op_name
@@ -246,9 +246,9 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
filter(lambda x: is_cross_device_reduce_2stage(x), ops))
ops = list(filter(lambda x: x not in cross_device_reduce_2stage_ops, ops))
custom_ar_all_reduce_unreg_ops = list(
filter(lambda x: is_custom_ar_all_reduce_unreg(x), ops))
ops = list(filter(lambda x: x not in custom_ar_all_reduce_unreg_ops, ops))
custom_ar_all_reduce_ops = list(
filter(lambda x: is_custom_ar_all_reduce(x), ops))
ops = list(filter(lambda x: x not in custom_ar_all_reduce_ops, ops))
reduce_kernel_ops = list(filter(lambda x: is_reduce_kernel(x), ops))
ops = list(filter(lambda x: x not in reduce_kernel_ops, ops))
@@ -289,21 +289,21 @@ def group_trace_by_operations(trace_df: pd.DataFrame) -> pd.DataFrame:
if len(cross_device_reduce_2stage_ops):
trace_df['cross_device_reduce_2stage_ops'] = trace_df[
cross_device_reduce_2stage_ops].agg("sum", axis=1)
if len(custom_ar_all_reduce_unreg_ops):
trace_df['custom_ar_all_reduce_unreg_ops'] = trace_df[
custom_ar_all_reduce_unreg_ops].agg("sum", axis=1)
if len(custom_ar_all_reduce_ops):
trace_df['custom_ar_all_reduce_ops'] = trace_df[
custom_ar_all_reduce_ops].agg("sum", axis=1)
if len(reduce_kernel_ops):
trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum",
axis=1)
trace_df.drop(
attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
reduce_kernel_ops,
axis=1,
inplace=True)
trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops +
vocab_embed_ops + mem_ops + elementwise_ops +
nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops +
nccl_other_ops + cross_device_reduce_1stage_ops +
cross_device_reduce_2stage_ops + custom_ar_all_reduce_ops +
reduce_kernel_ops,
axis=1,
inplace=True)
return trace_df