- Add BF16 support for SM90 and SM100 - Refactor Python APIs - Other fixes and code refactoring
29 lines
857 B
C++
29 lines
857 B
C++
#pragma once
|
|
|
|
#include "../jit/compiler.hpp"
|
|
#include "../jit/device_runtime.hpp"
|
|
|
|
namespace deep_gemm::runtime {
|
|
|
|
static void register_apis(pybind11::module_& m) {
|
|
m.def("set_num_sms", [&](const int& new_num_sms) {
|
|
device_runtime->set_num_sms(new_num_sms);
|
|
});
|
|
m.def("get_num_sms", [&]() {
|
|
return device_runtime->get_num_sms();
|
|
});
|
|
m.def("set_tc_util", [&](const int& new_tc_util) {
|
|
device_runtime->set_tc_util(new_tc_util);
|
|
});
|
|
m.def("get_tc_util", [&]() {
|
|
return device_runtime->get_tc_util();
|
|
});
|
|
|
|
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
|
|
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
|
|
KernelRuntime::prepare_init(cuda_home_path_by_python);
|
|
});
|
|
}
|
|
|
|
} // namespace deep_gemm::runtime
|