Files
DeepGEMM/csrc/jit_kernels/impls/smxx_cublaslt.hpp
2025-09-25 16:19:07 +08:00

152 lines
8.4 KiB
C++

#pragma once
#include <cublasLt.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
namespace deep_gemm {
// Create a cuBLASLt matrix layout for a `rows x cols` matrix of element type `type`
// with leading dimension `ld` (cuBLASLt's default order is column-major).
//
// If `batch_count` is given, the layout describes a strided batch of `batch_count`
// such matrices whose starting points are `batch_offset` elements apart; in that
// case `batch_offset` is required. A `batch_offset` without a `batch_count` is ignored.
//
// `ld` and `batch_offset` are 64-bit because cuBLASLt takes them as `int64_t` and
// callers pass `Tensor::stride()` values, which are `int64_t` as well.
//
// The caller owns the returned handle and must release it with
// `cublasLtMatrixLayoutDestroy` (in this file, `call_cublaslt_api` does that).
static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int64_t& ld,
                                const std::optional<int>& batch_count = std::nullopt,
                                const std::optional<int64_t>& batch_offset = std::nullopt) {
    // Validate before allocating so a failed assertion cannot leak the layout handle
    DG_HOST_ASSERT(not batch_count.has_value() or batch_offset.has_value());

    cublasLtMatrixLayout_t layout;
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));
    if (batch_count.has_value()) {
        // `CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET` must be passed as an `int64_t`
        const int64_t batch_offset_int64 = batch_offset.value();
        DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
                                                           &batch_count.value(), sizeof(batch_count.value())));
        DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                                                           &batch_offset_int64, sizeof(batch_offset_int64)));
    }
    return layout;
}
// Run D = op(A) @ op(B) (+ D when `accumulate`) through cuBLASLt on the current CUDA stream.
//
// Operand binding: tensor `b` is passed as cuBLASLt operand A (described by `layout_a`)
// and tensor `a` as operand B (described by `layout_b`). This is presumably so that
// row-major PyTorch tensors can be fed to the column-major library; callers in this file
// build the layouts accordingly. `d` is used both as the C input and the D output, so
// with `accumulate` set, its current contents are added to the product (beta = 1).
//
// Ownership: the three layout handles are consumed. They are destroyed before this
// function returns, including when no algorithm is found.
static void call_cublaslt_api(const cublasOperation_t& trans_a,
                              const cublasOperation_t& trans_b,
                              const cublasLtMatrixLayout_t& layout_a,
                              const cublasLtMatrixLayout_t& layout_b,
                              const cublasLtMatrixLayout_t& layout_d,
                              const torch::Tensor& a,
                              const torch::Tensor& b,
                              const torch::Tensor& d,
                              const bool& accumulate) {
    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
    cudaDataType_t scale_type = CUDA_R_32F;
    const int math_sms = device_runtime->get_num_sms();
    // cuBLASLt reads `CUBLASLT_MATMUL_DESC_FAST_ACCUM` as an `int8_t`, not a `bool`
    const int8_t fp8_fast_accumulate = 0;

    // Operation description
    cublasLtMatmulDesc_t desc;
    DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
    if (a.scalar_type() == torch::kFloat8_e4m3fn)
        DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
                                                         &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));

    // Releases the operation description and the caller-created layouts;
    // shared by the success path and the no-algorithm path below
    const auto destroy_descriptors = [&]() {
        DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a));
        DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b));
        DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d));
        DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
    };

    // Get cuBLASLt handle, workspace, and stream
    const auto& handle = device_runtime->get_cublaslt_handle();
    const auto& workspace = device_runtime->get_cublaslt_workspace();
    // `CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES` is a `uint64_t` attribute
    const uint64_t workspace_bytes = workspace.nbytes();
    const auto& stream = at::cuda::getCurrentCUDAStream();

    // Algorithm selection: ask the heuristic for its single best candidate
    cublasLtMatmulPreference_t pref;
    cublasLtMatmulHeuristicResult_t heuristic = {};
    int num_heuristic_results = 0;
    uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                           &workspace_bytes, sizeof(workspace_bytes)));
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
                                                           &reduction_scheme_mask, sizeof(reduction_scheme_mask)));
    DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d,
                                                     pref, 1, &heuristic, &num_heuristic_results));
    // The preference object is only needed for the heuristic query
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
    if (num_heuristic_results != 1) {
        // Release the remaining handles before reporting, so a caller that
        // recovers from the failed assertion does not leak them
        destroy_descriptors();
        DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");
    }

    // Call: D = alpha * (A @ B) + beta * C, with C aliased to D
    const float alpha = 1.0f;
    const float beta = accumulate ? 1.0f : 0.0f;
    DG_CUBLASLT_CHECK(cublasLtMatmul(handle,                                  // Light handle
                                     desc,                                    // Operation description
                                     &alpha,                                  // Alpha
                                     b.data_ptr(), layout_a,                  // A
                                     a.data_ptr(), layout_b,                  // B
                                     &beta,                                   // Beta
                                     d.data_ptr(), layout_d,                  // C
                                     d.data_ptr(), layout_d,                  // D
                                     &heuristic.algo,                         // Algorithm
                                     workspace.data_ptr(), workspace_bytes,   // Workspace
                                     stream));                                // Stream

    // Free memory
    destroy_descriptors();
}
// Plain GEMM through cuBLASLt: out = lhs @ rhs^T (+ acc, if given).
//
// cuBLASLt is column-major, so `out` (m x n, with `out.stride(0)` as its row stride and
// presumably unit stride along n) is described as an n x m column-major matrix D.
// It is computed as D = op(A) @ op(B), with A read from `rhs` and B from `lhs`.
// The `*_major` arguments select, per operand, which stride is the leading dimension
// and whether cuBLASLt must transpose it.
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
                          const std::optional<torch::Tensor>& acc,
                          const torch::Tensor& out,
                          const int& m, const int& n, const int& k,
                          const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) {
    const bool rhs_is_k_major = b_major == cute::UMMA::Major::K;
    const bool lhs_is_k_major = a_major == cute::UMMA::Major::K;
    const cublasOperation_t op_rhs = rhs_is_k_major ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t op_lhs = lhs_is_k_major ? CUBLAS_OP_N : CUBLAS_OP_T;

    // Stage the accumulator in `out`, which doubles as cuBLASLt's C operand
    // TODO: remove this
    if (acc.has_value()) {
        const bool aliases_out = acc->data_ptr() == out.data_ptr();
        if (not aliases_out) {
            out.copy_(acc.value());
        } else {
            // In-place accumulation: `acc` must describe exactly the same view as `out`
            DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides());
        }
    }

    // Element types, in operand order A (rhs), B (lhs), D (out)
    const auto rhs_dtype = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
    const auto lhs_dtype = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
    const auto out_dtype = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());

    // Operand A: `rhs`, read as k x n (K-major) or n x k (MN-major)
    cublasLtMatrixLayout_t rhs_layout;
    if (rhs_is_k_major) {
        rhs_layout = get_cublaslt_layout(rhs_dtype, k, n, rhs.stride(0));
    } else {
        rhs_layout = get_cublaslt_layout(rhs_dtype, n, k, rhs.stride(1));
    }

    // Operand B: `lhs`, read as k x m (K-major) or m x k (MN-major)
    cublasLtMatrixLayout_t lhs_layout;
    if (lhs_is_k_major) {
        lhs_layout = get_cublaslt_layout(lhs_dtype, k, m, lhs.stride(0));
    } else {
        lhs_layout = get_cublaslt_layout(lhs_dtype, m, k, lhs.stride(1));
    }

    // Output D (and C): n x m column-major view of `out`
    const auto out_layout = get_cublaslt_layout(out_dtype, n, m, out.stride(0));

    call_cublaslt_api(op_rhs, op_lhs, rhs_layout, lhs_layout, out_layout, lhs, rhs, out, acc.has_value());
}
// Head-batched GEMM. Judging by the name, presumably lhs is [b, h, r], rhs is [h, d, r]
// and out is [b, h, d], i.e. out[:, i, :] = lhs[:, i, :] @ rhs[i]^T for every head i.
// The strides used below are consistent with that reading.
//
// Per head, in cuBLASLt's column-major terms: D (d x b) = A^T (d x r) @ B (r x b), with
//   A = rhs: r x d, leading dim rhs.stride(1), batch stride rhs.stride(0)
//   B = lhs: r x b, leading dim lhs.stride(0), batch stride lhs.stride(1)
//   D = out: d x b, leading dim out.stride(0), batch stride out.stride(1)
// All layouts hard-code BF16 elements; tensor dtypes are not checked here.
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    const auto rhs_layout = get_cublaslt_layout(CUDA_R_16BF, r, d, rhs.stride(1), h, rhs.stride(0));
    const auto lhs_layout = get_cublaslt_layout(CUDA_R_16BF, r, b, lhs.stride(0), h, lhs.stride(1));
    const auto out_layout = get_cublaslt_layout(CUDA_R_16BF, d, b, out.stride(0), h, out.stride(1));

    // `accumulate = false`: `out` is overwritten rather than added to
    call_cublaslt_api(CUBLAS_OP_T, CUBLAS_OP_N, rhs_layout, lhs_layout, out_layout, lhs, rhs, out, false);
}
// Head-batched GEMM. Judging by the name, presumably lhs is [b, h, d], rhs is [h, d, r]
// and out is [b, h, r], i.e. out[:, i, :] = lhs[:, i, :] @ rhs[i] for every head i.
// The strides used below are consistent with that reading.
//
// Per head, in cuBLASLt's column-major terms: D (r x b) = A (r x d) @ B (d x b), with
//   A = rhs: r x d, leading dim rhs.stride(1), batch stride rhs.stride(0)
//   B = lhs: d x b, leading dim lhs.stride(0), batch stride lhs.stride(1)
//   D = out: r x b, leading dim out.stride(0), batch stride out.stride(1)
// All layouts hard-code BF16 elements; tensor dtypes are not checked here.
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    const auto rhs_layout = get_cublaslt_layout(CUDA_R_16BF, r, d, rhs.stride(1), h, rhs.stride(0));
    const auto lhs_layout = get_cublaslt_layout(CUDA_R_16BF, d, b, lhs.stride(0), h, lhs.stride(1));
    const auto out_layout = get_cublaslt_layout(CUDA_R_16BF, r, b, out.stride(0), h, out.stride(1));

    // Neither operand is transposed; `accumulate = false` overwrites `out`
    call_cublaslt_api(CUBLAS_OP_N, CUBLAS_OP_N, rhs_layout, lhs_layout, out_layout, lhs, rhs, out, false);
}
} // namespace deep_gemm