#pragma once #include #include "../utils/exception.hpp" #include "../utils/lazy_init.hpp" namespace deep_gemm { class DeviceRuntime { int num_sms = 0, tc_util = 0; std::shared_ptr cached_prop; public: explicit DeviceRuntime() = default; std::shared_ptr get_prop() { if (cached_prop == nullptr) cached_prop = std::make_shared(*at::cuda::getCurrentDeviceProperties()); return cached_prop; } std::pair get_arch_pair() { const auto prop = get_prop(); return {prop->major, prop->minor}; } int get_arch() { const auto& [major, minor] = get_arch_pair(); return major * 10 + minor; } int get_arch_major() { return get_arch_pair().first; } void set_num_sms(const int& new_num_sms) { DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount); num_sms = new_num_sms; } int get_num_sms() { if (num_sms == 0) num_sms = get_prop()->multiProcessorCount; return num_sms; } void set_tc_util(const int& new_tc_util) { DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100); tc_util = new_tc_util; } int get_tc_util() const { return tc_util == 0 ? 100 : tc_util; } }; static auto device_runtime = LazyInit([](){ return std::make_shared(); }); } // namespace deep_gemm