* Merge with private repo * Update README * Update README * Update README * Add PyTorch requirements * Fix sync scopes for MQA logits (#256) * Update README
52 lines
1.8 KiB
C++
52 lines
1.8 KiB
C++
#pragma once
|
|
|
|
#if DG_TENSORMAP_COMPATIBLE
|
|
#include "../jit/compiler.hpp"
|
|
#endif
|
|
#include "../jit/device_runtime.hpp"
|
|
#include "../jit_kernels/heuristics/runtime.hpp"
|
|
|
|
namespace deep_gemm::runtime {
|
|
|
|
static void register_apis(pybind11::module_& m) {
|
|
m.def("set_num_sms", [&](const int& new_num_sms) {
|
|
device_runtime->set_num_sms(new_num_sms);
|
|
});
|
|
m.def("get_num_sms", [&]() {
|
|
return device_runtime->get_num_sms();
|
|
});
|
|
m.def("set_tc_util", [&](const int& new_tc_util) {
|
|
device_runtime->set_tc_util(new_tc_util);
|
|
});
|
|
m.def("get_tc_util", [&]() {
|
|
return device_runtime->get_tc_util();
|
|
});
|
|
m.def("set_pdl", [&](const bool& new_enable_pdl) {
|
|
device_runtime->set_pdl(new_enable_pdl);
|
|
});
|
|
m.def("get_pdl", [&]() {
|
|
return device_runtime->get_pdl();
|
|
});
|
|
m.def("set_ignore_compile_dims", [&](const bool& new_value) {
|
|
heuristics_runtime->set_ignore_compile_dims(new_value);
|
|
});
|
|
m.def("set_block_size_multiple_of", [&](const std::variant<int, std::tuple<int, int>>& new_value) {
|
|
if (std::holds_alternative<int>(new_value)) {
|
|
auto x = std::get<int>(new_value);
|
|
heuristics_runtime->set_block_size_multiple_of(x, x);
|
|
} else {
|
|
auto [x, y] = std::get<std::tuple<int, int>>(new_value);
|
|
heuristics_runtime->set_block_size_multiple_of(x, y);
|
|
}
|
|
});
|
|
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
|
|
#if DG_TENSORMAP_COMPATIBLE
|
|
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
|
|
KernelRuntime::prepare_init(cuda_home_path_by_python);
|
|
IncludeParser::prepare_init(library_root_path);
|
|
#endif
|
|
});
|
|
}
|
|
|
|
} // namespace deep_gemm::runtime
|