[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)

Signed-off-by: Hanzhi Zhou <hanzhi713@gmail.com>
This commit is contained in:
Hanzhi Zhou
2024-11-06 23:50:47 -08:00
committed by GitHub
parent d7263a1bb8
commit 6192e9b8fe
9 changed files with 218 additions and 260 deletions

View File

@@ -411,27 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar.def(
"init_custom_ar(Tensor meta, Tensor rank_data, "
"str[] handles, int[] offsets, int rank, "
"bool full_nvlink) -> int");
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool full_nvlink) -> int");
custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()");
custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg);
custom_ar.def(
"all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> "
"()");
custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg);
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
"int reg_buffer_sz_bytes) -> ()");
custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);
custom_ar.def("dispose", &dispose);
custom_ar.def("meta_size", &meta_size);
custom_ar.def(
"register_buffer(int fa, Tensor t, str[] handles, "
"int[] offsets) -> ()");
custom_ar.impl("register_buffer", torch::kCUDA, &register_buffer);
custom_ar.def("register_buffer", &register_buffer);
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
custom_ar.def("register_graph_buffers", &register_graph_buffers);
}