[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)
Signed-off-by: Hanzhi Zhou <hanzhi713@gmail.com>
This commit is contained in:
@@ -411,27 +411,18 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) {
// Registers the custom all-reduce ops in the
// <TORCH_EXTENSION_NAME>_custom_ar library namespace so they are callable
// from Python via torch.ops. Schema strings follow the TorchScript operator
// schema syntax; ops that touch CUDA tensors are bound to the kCUDA dispatch
// key, while pure host-side helpers are registered with a direct def().
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
  // Custom all-reduce kernels

  // Creates the custom all-reduce context from pre-exchanged IPC buffers
  // (one `int` pointer per rank) plus this rank's registration scratch
  // tensor; returns an opaque handle (`int fa`) used by the ops below.
  custom_ar.def(
      "init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
      "int rank, bool full_nvlink) -> int");
  custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);

  // Single all-reduce entry point: `reg_buffer`/`reg_buffer_sz_bytes`
  // describe the pre-registered staging buffer used for unregistered inputs.
  custom_ar.def(
      "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
      "int reg_buffer_sz_bytes) -> ()");
  custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce);

  // Host-side lifecycle/introspection helpers; no CUDA dispatch needed.
  custom_ar.def("dispose", &dispose);
  custom_ar.def("meta_size", &meta_size);

  // Buffer registration now takes plain pointers on the host side, so it is
  // registered without a schema string or a CUDA dispatch key.
  custom_ar.def("register_buffer", &register_buffer);
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
  custom_ar.def("register_graph_buffers", &register_graph_buffers);
}
Reference in New Issue
Block a user