Implement custom all reduce kernels (#2192)
This commit is contained in:
@@ -88,4 +88,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
&get_max_shared_memory_per_block_device_attribute,
|
||||
"Gets the maximum shared memory per block device attribute.");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Custom all-reduce kernels
|
||||
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
|
||||
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
|
||||
custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
|
||||
custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
|
||||
custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
|
||||
custom_ar.def("dispose", &dispose, "dispose");
|
||||
custom_ar.def("meta_size", &meta_size, "meta_size");
|
||||
custom_ar.def("register_buffer", ®ister_buffer, "register_buffer");
|
||||
custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
|
||||
"get_graph_buffer_ipc_meta");
|
||||
custom_ar.def("register_graph_buffers", ®ister_graph_buffers,
|
||||
"register_graph_buffers");
|
||||
#endif
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user