// Change notes:
//   - Add support for legacy CUDA versions; now compatible with CUDA 12.3 and newer.
//   - Add support for NVRTC compilation.
//   - Other fixes and code refactoring.
#pragma once
#include <functional>
#include <memory>
// Emits the out-of-class definition of a static data member, deriving the member's type from its
// in-class declaration via `decltype` so the type does not have to be spelled out a second time.
// Example: `DG_DECLARE_STATIC_VAR_IN_CLASS(Foo, bar);` expands to `decltype(Foo::bar) Foo::bar;`
// (the caller supplies the trailing `;`).
// NOTE(review): despite "DECLARE" in the name this produces a definition; presumably each use
// belongs in a single translation unit per member — confirm against the call sites.
#define DG_DECLARE_STATIC_VAR_IN_CLASS(klass, member) \
    decltype(klass::member) klass::member
namespace deep_gemm {

/// Wraps a factory and defers creating the `T` it produces until the first access.
///
/// The first `operator->()` / `operator*()` call invokes the factory and caches the returned
/// `std::shared_ptr<T>`; every later access returns that cached object without calling the
/// factory again. If the factory returns `nullptr`, nothing is cached and the factory is
/// invoked again on the next access.
///
/// NOTE(review): the "already created?" check is not synchronized. Concurrent first accesses
/// from several threads race on the cached pointer (and may run the factory more than once);
/// guard access externally if one instance is shared across threads.
template <typename T>
class LazyInit {
public:
    /// @param factory Callable that builds the object; it is stored here, not invoked.
    explicit LazyInit(std::function<std::shared_ptr<T>()> factory)
        : factory(std::move(factory)) {}

    /// Returns a pointer to the managed object, creating it on first use.
    /// @return The cached object, or `nullptr` if the factory produced nothing.
    T* operator -> () {
        if (ptr == nullptr) {
            ptr = factory();
            // Once an object exists the factory is never called again, so release it (and
            // whatever its closure captured). Keep it when it produced nothing so that a
            // later access can retry.
            if (ptr != nullptr)
                factory = nullptr;
        }
        return ptr.get();
    }

    /// Returns a reference to the managed object, creating it on first use.
    /// The factory must produce a non-null object; dereferencing a null result is undefined.
    T& operator * () {
        return *operator->();
    }

private:
    std::shared_ptr<T> ptr;
    std::function<std::shared_ptr<T>()> factory;
};

} // namespace deep_gemm