Make various updates and fixes (#198)

This commit is contained in:
Ray Wang
2025-09-25 16:19:07 +08:00
committed by GitHub
parent 79f48ee15a
commit 3f71de7aa9
45 changed files with 3281 additions and 1060 deletions

View File

@@ -1,5 +1,6 @@
#pragma once
#include <cublasLt.h>
#include <ATen/cuda/CUDAContext.h>
#include "../utils/exception.hpp"
@@ -11,8 +12,28 @@ class DeviceRuntime {
int num_sms = 0, tc_util = 0;
std::shared_ptr<cudaDeviceProp> cached_prop;
// cuBLASLt utils
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
cublasLtHandle_t cublaslt_handle{};
std::shared_ptr<torch::Tensor> cublaslt_workspace;
public:
explicit DeviceRuntime() = default;
explicit DeviceRuntime() {
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
}
~DeviceRuntime() noexcept(false) {
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
}
cublasLtHandle_t get_cublaslt_handle() const {
return cublaslt_handle;
}
torch::Tensor get_cublaslt_workspace() const {
return *cublaslt_workspace;
}
std::shared_ptr<cudaDeviceProp> get_prop() {
if (cached_prop == nullptr)