Make various updates and fixes (#198)
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
@@ -11,8 +12,28 @@ class DeviceRuntime {
|
||||
int num_sms = 0, tc_util = 0;
|
||||
std::shared_ptr<cudaDeviceProp> cached_prop;
|
||||
|
||||
// cuBLASLt utils
|
||||
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
|
||||
cublasLtHandle_t cublaslt_handle{};
|
||||
std::shared_ptr<torch::Tensor> cublaslt_workspace;
|
||||
|
||||
public:
|
||||
explicit DeviceRuntime() = default;
|
||||
explicit DeviceRuntime() {
|
||||
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));
|
||||
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
|
||||
}
|
||||
|
||||
~DeviceRuntime() noexcept(false) {
|
||||
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
|
||||
}
|
||||
|
||||
cublasLtHandle_t get_cublaslt_handle() const {
|
||||
return cublaslt_handle;
|
||||
}
|
||||
|
||||
torch::Tensor get_cublaslt_workspace() const {
|
||||
return *cublaslt_workspace;
|
||||
}
|
||||
|
||||
std::shared_ptr<cudaDeviceProp> get_prop() {
|
||||
if (cached_prop == nullptr)
|
||||
|
||||
Reference in New Issue
Block a user