Make various updates and fixes (#198)
This commit is contained in:
77
csrc/apis/attention.hpp
Normal file
77
csrc/apis/attention.hpp
Normal file
@@ -0,0 +1,77 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
|
||||
#include "layout.hpp"
|
||||
|
||||
namespace deep_gemm::attention {
|
||||
|
||||
// FP8 GEMM (`[M, K] @ [N, K].T`) launched with an `EpilogueHeadSplits<left, mid, right>` epilogue.
//
// Arguments:
//   a, b: (data, scaling factor) pairs; the data tensors must be 2D `torch::kFloat8_e4m3fn`
//         with shapes `[M, K]` and `[N, K]`
//   d: 2D output, `torch::kBFloat16` or `torch::kFloat`, with `M` rows and
//      `N + N / (left + right) * mid` columns
//   head_splits: `(left, mid, right)`; all non-negative, `left + right > 0`, and `N` must be
//                a multiple of `left + right`
//   recipe: scaling-factor recipe, `(1, 1, 128)` or `(1, 128, 128)`; derived from the
//           scaling-factor dtypes when not given
//   compiled_dims: forwarded unchanged to the kernel launcher
//   disable_ue8m0_cast: forwarded unchanged to the scaling-factor layout transform
//
// Fails a host assertion on any violated precondition; returns without launching when `M == 0`.
static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::tuple<int, int, int> &head_splits,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto& major_a = get_major_type_ab(a.first);
    const auto& major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto& [m , k ] = get_shape<2>(a.first);
    const auto& [n , k_] = get_shape<2>(b.first);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check head splits and N
    const auto& [left, mid, right] = head_splits;
    // Reject degenerate splits first: `left + right == 0` would otherwise be a division by zero
    // in the checks below instead of a readable assertion failure
    DG_HOST_ASSERT(left >= 0 and mid >= 0 and right >= 0 and left + right > 0);
    DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);

    // Do nothing if the problem is empty
    if (m == 0)
        return;

    // Transform SFA and SFB into compute-required layout
    if (not recipe.has_value())
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
    const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
    const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);

    // Dispatch into different implementations
    // NOTES: on SM90 only the 1D2D kernel (FP32 scaling factors, second recipe entry `!= 1`) is wired up here
    const auto& arch_major = device_runtime->get_arch_major();
    const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
        sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
}
|
||||
|
||||
// Register the attention GEMM entry points on the pybind11 module `m`.
static void register_apis(pybind11::module_& m) {
    // `a`, `b`, `d` and `head_splits` are required; the remaining arguments carry defaults
    m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("head_splits"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
}
|
||||
|
||||
} // namespace deep_gemm::attention
|
||||
115
csrc/apis/einsum.hpp
Normal file
115
csrc/apis/einsum.hpp
Normal file
@@ -0,0 +1,115 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
#include "../utils/format.hpp"
|
||||
#include "../utils/layout.hpp"
|
||||
|
||||
#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp"
|
||||
#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp"
|
||||
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
|
||||
|
||||
namespace deep_gemm::einsum {
|
||||
|
||||
// Hard-coded kernel for the einsum expression "bmk,bnk->mn".
//
// Arguments:
//   a: contiguous 3D tensor `[s, m, k]`
//   b: contiguous 3D tensor `[s, n, k]`
//   d: contiguous output, `torch::kFloat` or `torch::kBFloat16`
//   c: accumulator input; for FP32 output it is required and must alias `d` exactly
//      (same data pointer, sizes and strides); for BF16 output it must be absent
//
// For BF16 output, the result is accumulated into a zero-initialized FP32 workspace and then
// copied (with an implicit cast) into `d`.
static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d,
                       const std::optional<torch::Tensor>& c) {
    // Currently, FP32 output only supports the accumulating form (`c` aliases `d`)
    if (d.scalar_type() == torch::kFloat) {
        // `c` is dereferenced right below; an empty optional must fail cleanly instead of being UB
        DG_HOST_ASSERT(c.has_value());
        DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides());
    } else {
        DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
        DG_HOST_ASSERT(not c.has_value());

        // Accumulate in FP32: zero the workspace on the current stream, then reuse the FP32 path
        const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
        DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(),
                                              c10::cuda::getCurrentCUDAStream()));
        bmk_bnk_mn(a, b, workspace, workspace);

        // This line has an implicit FP32-to-BF16 casting
        d.copy_(workspace);
        return;
    }

    DG_HOST_ASSERT(a.is_contiguous());
    DG_HOST_ASSERT(b.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());

    const auto& [s , m, k ] = get_shape<3>(a);
    const auto& [s_, n, k_] = get_shape<3>(b);
    DG_HOST_ASSERT(s == s_ and k == k_);

    // Dispatch implementation
    const auto& arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else if (arch_major == 10) {
        sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
|
||||
|
||||
// Hard-coded kernel for the einsum expression "bhr,hdr->bhd", delegated to cuBLASLt.
//
// `A` is `[b, h, r]`, `B` is `[h, d, r]` and `D` is `[b, h, d]`; all three must be BF16
// with a unit stride on the last dimension.
static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
    // Unpack the three shapes and cross-check every shared dimension
    const auto& [b , h , r ] = get_shape<3>(A);
    const auto& [h_, d , r_] = get_shape<3>(B);
    const auto& [b_, h__, d_] = get_shape<3>(D);
    DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);

    // Data type and innermost-stride requirements
    DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
    DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
    DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);

    cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
}
|
||||
|
||||
// Hard-coded kernel for the einsum expression "bhd,hdr->bhr", delegated to cuBLASLt.
//
// `A` is `[b, h, d]`, `B` is `[h, d, r]` and `D` is `[b, h, r]`; all three must be BF16
// with a unit stride on the last dimension.
static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D) {
    // Unpack the three shapes and cross-check every shared dimension
    const auto& [b , h , d ] = get_shape<3>(A);
    const auto& [h_, d_ , r ] = get_shape<3>(B);
    const auto& [b_, h__, r_] = get_shape<3>(D);
    DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);

    // Data type and innermost-stride requirements
    DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
    DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
    DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);

    cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
}
|
||||
|
||||
static void einsum(const std::string& expr,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c) {
|
||||
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
||||
if (c.has_value()) {
|
||||
DG_HOST_ASSERT(c->scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
|
||||
}
|
||||
|
||||
// Some hardcoded Einstein sum kernels
|
||||
// TODO: support any expression
|
||||
// TODO: canonicalize expression
|
||||
if (expr == "bmk,bnk->mn") {
|
||||
bmk_bnk_mn(a, b, d, c);
|
||||
} else if (expr == "bhr,hdr->bhd") {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
bhr_hdr_bhd(a, b, d);
|
||||
} else if (expr == "bhd,hdr->bhr") {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
bhd_hdr_bhr(a, b, d);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
|
||||
}
|
||||
}
|
||||
|
||||
// Register the einsum entry point on the pybind11 module `m`.
static void register_apis(pybind11::module_& m) {
    // `c` is the only optional argument and defaults to "not provided"
    m.def("einsum", &einsum,
          py::arg("expr"),
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt);
}
|
||||
|
||||
} // namespace deep_gemm::einsum
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
@@ -52,13 +53,18 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
if (std::get<1>(recipe.value()) == 1) {
|
||||
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
}
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
@@ -261,6 +267,60 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
|
||||
}
|
||||
}
|
||||
|
||||
// K-grouped contiguous FP8 GEMM (NT layout), 1D1D kernel only.
//
// Arguments:
//   a, b: (data, scaling factor) pairs; the data tensors must be contiguous and hold
//         `m * sum(ks)` and `n * sum(ks)` elements respectively
//   d: contiguous 3D output `[num_groups, m, n]`
//   ks: per-group K sizes on the host
//   ks_tensor: tensor counterpart of `ks`, forwarded to the layout transform and the kernel
//   c: optional contiguous FP32 tensor, forwarded to the kernel
//   recipe: must be `(1, 1, 128)`
//   compiled_dims: forwarded unchanged to the kernel launcher
//
// Returns without launching when `sum(ks) == 0`. Only SM90 is supported.
static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const std::vector<int>& ks,
                                             const torch::Tensor& ks_tensor,
                                             const std::optional<torch::Tensor>& c,
                                             const std::tuple<int, int, int>& recipe,
                                             const std::string& compiled_dims) {
    // Must be 1D1D kernel
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

    // Shape checks
    // NOTE(review): `num_groups` is not compared against `ks.size()` here; confirm whether it should be
    const auto& [num_groups, m, n] = get_shape<3>(d);
    const auto& sum_mk = a.first.numel();
    const auto& sum_nk = b.first.numel();
    const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
    // Widen before multiplying: `numel()` is 64-bit, while `m * sum_k` may overflow 32-bit `int`
    DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(m) * sum_k);
    DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(n) * sum_k);

    // Contiguity checks
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().is_contiguous());
    }

    // Do nothing if empty
    if (sum_k == 0)
        return;

    // Transform SF with padding
    const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Allocate tensormap buffer
    // `4` means the double buffering for both A and B operands (2 * 2)
    const auto& num_sms = device_runtime->get_num_sms();
    const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
                                                 a.first.options().dtype(torch::kByte));

    // Dispatch implementation
    const auto& arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
                                     cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
|
||||
|
||||
static void bf16_gemm_nt(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
@@ -403,6 +463,43 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
|
||||
}
|
||||
}
|
||||
|
||||
// cuBLASLt-backed GEMM for `[M, K] @ [N, K].T`.
//
// `d` must be `[M, N]`; when `c` is given, it must have the same data type as `d`.
// Returns without calling cuBLASLt when `M == 0` or `N == 0`.
static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto& major_a = get_major_type_ab(a);
    const auto& major_b = get_major_type_ab(b);

    // Type and shape checks
    const auto& [m , k ] = get_shape<2>(a);
    const auto& [n , k_] = get_shape<2>(b);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);

    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == d.scalar_type());
    }

    // Do nothing if the problem is empty
    const bool is_empty = (m == 0) || (n == 0);
    if (is_empty)
        return;

    cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b);
}
|
||||
|
||||
// NN variant: swaps the two dimensions of `b` (a view, no copy is requested here) and
// reuses the NT entry point.
static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto& b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a, b_transposed, d, c);
}
|
||||
|
||||
// TN variant: swaps the two dimensions of both `a` and `b` and reuses the NT entry point.
static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto& a_transposed = a.transpose(0, 1);
    const auto& b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b_transposed, d, c);
}
|
||||
|
||||
// TT variant: swaps the two dimensions of `a` only and reuses the NT entry point.
static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto& a_transposed = a.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b, d, c);
}
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
// FP8 GEMMs
|
||||
m.def("fp8_gemm_nt", &fp8_gemm_nt,
|
||||
@@ -442,6 +539,11 @@ static void register_apis(pybind11::module_& m) {
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
|
||||
// BF16 GEMMs
|
||||
m.def("bf16_gemm_nt", &bf16_gemm_nt,
|
||||
@@ -466,6 +568,16 @@ static void register_apis(pybind11::module_& m) {
|
||||
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
||||
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
|
||||
|
||||
// cuBLASLt GEMMs
|
||||
m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::gemm
|
||||
|
||||
@@ -56,14 +56,14 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
|
||||
|
||||
// FP32 on SM90
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 9)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
return get_mn_major_tma_aligned_tensor(sf);
|
||||
|
||||
// FP32 on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);
|
||||
|
||||
// INT on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
if (sf.scalar_type() == torch::kInt and arch_major == 10)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown cases");
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
// GEMM kernels
|
||||
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
|
||||
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
|
||||
|
||||
// Einsum kernels
|
||||
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
|
||||
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
|
||||
|
||||
// Layout kernels
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
@@ -11,8 +12,28 @@ class DeviceRuntime {
|
||||
int num_sms = 0, tc_util = 0;
|
||||
std::shared_ptr<cudaDeviceProp> cached_prop;
|
||||
|
||||
// cuBLASLt utils
|
||||
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
|
||||
cublasLtHandle_t cublaslt_handle{};
|
||||
std::shared_ptr<torch::Tensor> cublaslt_workspace;
|
||||
|
||||
public:
|
||||
explicit DeviceRuntime() = default;
|
||||
explicit DeviceRuntime() {
|
||||
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));
|
||||
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
|
||||
}
|
||||
|
||||
// Releases the cuBLASLt handle owned by this runtime.
// NOTE(review): declared `noexcept(false)`, presumably because `DG_CUBLASLT_CHECK` can throw
// on failure — confirm that throwing during destruction is acceptable for all owners.
~DeviceRuntime() noexcept(false) {
    DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
}
|
||||
|
||||
// Returns the cuBLASLt handle held by this runtime (ownership stays with the runtime).
cublasLtHandle_t get_cublaslt_handle() const {
    return this->cublaslt_handle;
}
|
||||
|
||||
// Returns the cuBLASLt workspace tensor held by this runtime (returned by value as a
// `torch::Tensor`).
torch::Tensor get_cublaslt_workspace() const {
    const torch::Tensor& workspace = *cublaslt_workspace;
    return workspace;
}
|
||||
|
||||
std::shared_ptr<cudaDeviceProp> get_prop() {
|
||||
if (cached_prop == nullptr)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <deep_gemm/common/types.hpp>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/layout.hpp"
|
||||
|
||||
@@ -80,18 +82,19 @@ static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
|
||||
return divisible and num_sms % num_multicast == 0;
|
||||
}
|
||||
|
||||
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
|
||||
template <typename size_type_t>
|
||||
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
|
||||
// `> 0` means interleaving
|
||||
// 16B actually means non-swizzling (but interleaving)
|
||||
for (const int& mode: {128, 64, 32, 16}) {
|
||||
if ((block_size * elem_size) % mode == 0)
|
||||
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
|
||||
return mode;
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unreachable");
|
||||
}
|
||||
|
||||
template <typename ArchSpec>
|
||||
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type,
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
@@ -104,7 +107,7 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
|
||||
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
|
||||
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
|
||||
const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);
|
||||
const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0;
|
||||
|
||||
// Different archs have different epilogue pipelines
|
||||
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
|
||||
@@ -121,9 +124,11 @@ static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
// M-barriers and tensor memory pointers
|
||||
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
|
||||
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
|
||||
const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type);
|
||||
|
||||
// Sum them up
|
||||
int smem_size = 0;
|
||||
smem_size += smem_tensor_map;
|
||||
smem_size += smem_cd;
|
||||
smem_size += num_stages * smem_a_per_stage;
|
||||
smem_size += num_stages * smem_b_per_stage;
|
||||
@@ -151,15 +156,12 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
|
||||
|
||||
// Select M/N block sizes
|
||||
// TODO: support `% 16 == 8` block size on SM90
|
||||
auto block_ms = std::vector{64, 128, 256};
|
||||
if (gemm_type == GemmType::MGroupedContiguous)
|
||||
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
|
||||
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
|
||||
block_ms = std::vector{64, 128};
|
||||
std::vector<int> block_ns;
|
||||
for (int i = 16; i <= 256; i += 16)
|
||||
block_ns.push_back(i);
|
||||
const auto block_ns = ArchSpec::get_block_n_candidates(cd_dtype);
|
||||
|
||||
// K block size is selected in a fixed manner
|
||||
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
|
||||
@@ -214,9 +216,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
|
||||
|
||||
// Decide the number of TMA multicasts and whether broadcast on A
|
||||
MulticastConfig best_multicast_config = {1, true};
|
||||
MulticastConfig best_multicast_config = {1, false};
|
||||
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
|
||||
gemm_type, m, n, best_block_m, best_block_n, num_sms);
|
||||
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
|
||||
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
|
||||
bool order[2] = {false, true};
|
||||
if (best_block_m > best_block_n)
|
||||
@@ -232,11 +234,11 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
constexpr int smem_capacity = ArchSpec::smem_capacity;
|
||||
int best_num_stages = 0;
|
||||
SharedMemoryConfig best_smem_config;
|
||||
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
|
||||
for (int num_stages = 12; num_stages > 0; -- num_stages) {
|
||||
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
|
||||
continue;
|
||||
|
||||
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
|
||||
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
|
||||
m, n, k,
|
||||
best_block_m, best_block_n, block_k,
|
||||
major_a, major_b,
|
||||
|
||||
@@ -12,6 +12,15 @@ namespace deep_gemm {
|
||||
struct SM100ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
// Candidate N block sizes: 16, followed by every multiple of 32 up to and including 256.
// `cd_dtype` is not consulted here.
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
    // 16 is for better SM usage
    std::vector<int> block_ns{16};

    // Stride 32 is due to low-performance swizzle-16/32B
    int block_n = 32;
    while (block_n <= 256) {
        block_ns.push_back(block_n);
        block_n += 32;
    }
    return block_ns;
}
|
||||
|
||||
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
|
||||
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
|
||||
}
|
||||
@@ -29,6 +38,10 @@ struct SM100ArchSpec {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
// Whether the C/D tile uses a swizzled layout: always enabled for this arch spec,
// regardless of `cd_dtype`.
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
    constexpr bool kAlwaysSwizzle = true;
    return kAlwaysSwizzle;
}
|
||||
|
||||
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
|
||||
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
|
||||
constexpr int num_utccp_aligned_elems = 128;
|
||||
@@ -86,7 +99,7 @@ struct SM100ArchSpec {
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
// TODO: support other layouts
|
||||
@@ -138,12 +151,17 @@ struct SM100ArchSpec {
|
||||
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
|
||||
// NOTES: 1D2D kernel will not use the with-SF full barriers
|
||||
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
|
||||
return num_stages * 8 * 3 + 2 * 8 * 2;
|
||||
// NOTES: the last barrier is for tensor core utilization control
|
||||
return num_stages * 8 * 3 + 2 * 8 * 2 + 8;
|
||||
}
|
||||
|
||||
// Shared memory bytes reserved for the tensor memory pointer.
static int get_tmem_ptr_smem_size() {
    constexpr int kTmemPtrSmemBytes = 4;
    return kTmemPtrSmemBytes;
}
|
||||
|
||||
// Shared memory bytes reserved for tensor maps: none for this arch spec, for any GEMM type.
static int get_tensormap_smem_size(const GemmType& gemm_type) {
    constexpr int kNoTensorMapSmem = 0;
    return kNoTensorMapSmem;
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -11,6 +11,15 @@ namespace deep_gemm {
|
||||
struct SM90ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
// Candidate N block sizes, stepping by 16 up to and including 256.
// The sequence starts at 8 for FP32 output (8, 24, 40, ...) and at 16 otherwise (16, 32, ...).
static std::vector<int> get_block_n_candidates(const at::ScalarType& cd_dtype) {
    // Avoid bank conflicts for FP32 output
    int block_n = 16;
    if (cd_dtype == torch::kFloat)
        block_n = 8;

    std::vector<int> block_ns;
    while (block_n <= 256) {
        block_ns.push_back(block_n);
        block_n += 16;
    }
    return block_ns;
}
|
||||
|
||||
// M block size per load: always the full `block_m`; `multicast_config` is not consulted
// for this arch spec.
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
    const int load_block_m = block_m;
    return load_block_m;
}
|
||||
@@ -19,26 +28,35 @@ struct SM90ArchSpec {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
static int get_cd_store_block_m(const int& block_m) {
|
||||
return block_m;
|
||||
static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
|
||||
constexpr int wgmma_m = 64;
|
||||
return single_warpgroup_sync ? wgmma_m : block_m;
|
||||
}
|
||||
|
||||
// N block size per C/D store: the full `block_n`, unchanged.
static int get_cd_store_block_n(const int& block_n) {
    const int store_block_n = block_n;
    return store_block_n;
}
|
||||
|
||||
// Whether the C/D tile uses a swizzled layout: enabled for every output type except FP32.
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
    if (cd_dtype != torch::kFloat)
        return true;
    return false;
}
|
||||
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// FP32 output does not support `block_m == 256`
|
||||
// SM90 FP32 output does not support `block_m == 256`
|
||||
if (cd_dtype == at::kFloat and block_m == 256)
|
||||
return false;
|
||||
|
||||
// TODO: more general block N selection
|
||||
// Must be some fixed block N selections
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152))
|
||||
return false;
|
||||
// Avoid large C/D shared memory for FP32 output
|
||||
// Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
|
||||
if (block_n > 128 and cd_dtype == torch::kFloat) {
|
||||
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
|
||||
return false;
|
||||
if (kernel_type == KernelType::KernelNoSF and block_n > 200)
|
||||
return false;
|
||||
}
|
||||
|
||||
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
|
||||
// Or too many register spills
|
||||
@@ -66,9 +84,13 @@ struct SM90ArchSpec {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
// Disable multicast when the number of k-groups is large (a heuristic)
|
||||
if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
|
||||
return {false, false};
|
||||
|
||||
return {
|
||||
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
|
||||
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
|
||||
@@ -96,9 +118,10 @@ struct SM90ArchSpec {
|
||||
|
||||
int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
|
||||
int smem_sfb_per_stage = 0;
|
||||
// TODO: figure out here
|
||||
if (kernel_type == KernelType::Kernel1D1D)
|
||||
smem_sfb_per_stage = align(block_n * 4, block_k);
|
||||
if (kernel_type == KernelType::Kernel1D1D) {
|
||||
// NOTES: `128` is for 2D TMA alignment requirement
|
||||
smem_sfb_per_stage = align(block_n * 4, 128);
|
||||
}
|
||||
return {smem_sfa_per_stage, smem_sfb_per_stage};
|
||||
}
|
||||
|
||||
@@ -109,13 +132,16 @@ struct SM90ArchSpec {
|
||||
}
|
||||
|
||||
static int get_barrier_smem_size(const int& num_stages) {
|
||||
// For 1D1D kernels, there is an extra barrier for accumulation
|
||||
return (num_stages + 1) * 8 * 2;
|
||||
return num_stages * 8 * 2;
|
||||
}
|
||||
|
||||
// Shared memory bytes reserved for the tensor memory pointer: none for this arch spec.
static int get_tmem_ptr_smem_size() {
    constexpr int kTmemPtrSmemBytes = 0;
    return kTmemPtrSmemBytes;
}
|
||||
|
||||
static int get_tensormap_smem_size(const GemmType& gemm_type) {
|
||||
return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
12
csrc/jit_kernels/impls/epilogue.hpp
Normal file
12
csrc/jit_kernels/impls/epilogue.hpp
Normal file
@@ -0,0 +1,12 @@
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <string>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Resolve the epilogue type name: the caller-provided name when present,
// otherwise the fallback "EpilogueIdentity".
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
    if (epilogue_type.has_value())
        return *epilogue_type;
    return "EpilogueIdentity";
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -4,6 +4,8 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "../../utils/system.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -51,7 +53,11 @@ static std::string to_string(const at::ScalarType& dtype) {
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
|
||||
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
|
||||
const bool& allow_tf32) {
|
||||
if (allow_tf32 and dtype == torch::kFloat)
|
||||
return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32;
|
||||
|
||||
switch (dtype) {
|
||||
case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32;
|
||||
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
|
||||
@@ -61,9 +67,14 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
|
||||
}
|
||||
}
|
||||
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
|
||||
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
|
||||
if (base != 0) {
|
||||
DG_HOST_ASSERT(base == 32 and mode == 128);
|
||||
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
|
||||
}
|
||||
|
||||
switch (mode) {
|
||||
case 0: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 0:
|
||||
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
|
||||
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
|
||||
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
|
||||
@@ -76,7 +87,8 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
int gmem_inner_dim, int gmem_outer_dim,
|
||||
int smem_inner_dim, int smem_outer_dim,
|
||||
const int& gmem_outer_stride,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
const auto& elem_size = static_cast<int>(t.element_size());
|
||||
if (swizzle_mode != 0)
|
||||
smem_inner_dim = swizzle_mode / elem_size;
|
||||
@@ -87,14 +99,42 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
|
||||
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
|
||||
const cuuint32_t elem_strides[2] = {1, 1};
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
|
||||
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n",
|
||||
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
|
||||
gmem_outer_stride, swizzle_mode, elem_size);
|
||||
gmem_outer_stride, swizzle_mode, swizzle_base, elem_size);
|
||||
}
|
||||
DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
|
||||
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
|
||||
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
|
||||
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
|
||||
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
|
||||
return tensor_map;
|
||||
}
|
||||
|
||||
// Encodes a 3D tiled TMA descriptor for tensor `t`.
//   - `gmem_dim_0..2`: global extents, passed to the driver in this order
//   - `smem_dim_0..2`: box (tile) extents, passed to the driver in this order
//   - `gmem_stride_0/1`: the two global strides in elements (converted to bytes here)
//   - `swizzle_mode`, `swizzle_base`: forwarded to `mode_into_tensor_map_swizzle`;
//     with a non-zero mode, `smem_dim_0` must cover exactly `swizzle_mode` bytes
//   - `allow_tf32`: forwarded to `aten_dtype_to_tensor_map_dtype`
// Returns the encoded descriptor; driver failures go through `DG_CUDA_DRIVER_CHECK`.
static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
                                    const int& gmem_dim_0, const int& gmem_dim_1, const int& gmem_dim_2,
                                    const int& smem_dim_0, const int& smem_dim_1, const int& smem_dim_2,
                                    const int& gmem_stride_0, const int& gmem_stride_1,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    const auto& elem_size = static_cast<int>(t.element_size());
    if (swizzle_mode != 0)
        DG_HOST_ASSERT(smem_dim_0 == swizzle_mode / elem_size);

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2)};
    const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
    // Widen before multiplying: `stride * elem_size` may not fit into a 32-bit `int`
    const cuuint64_t gmem_strides[2] = {static_cast<cuuint64_t>(gmem_stride_0) * static_cast<cuuint64_t>(elem_size),
                                        static_cast<cuuint64_t>(gmem_stride_1) * static_cast<cuuint64_t>(elem_size)};
    const cuuint32_t elem_strides[3] = {1, 1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        // Also report the swizzle base, matching the 2D descriptor's debug output
        printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d (base: %d), elem size: %d\n",
               gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2,
               gmem_stride_0, gmem_stride_1, swizzle_mode, swizzle_base, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
        3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
|
||||
@@ -105,7 +145,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
|
||||
const int& block_m, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
if (num_groups > 1)
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
|
||||
@@ -114,7 +155,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
|
||||
gmem_inner_dim, gmem_outer_dim,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
@@ -123,7 +165,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
const int& block_n, const int& block_k,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
|
||||
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
|
||||
|
||||
@@ -132,7 +175,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
|
||||
gmem_inner_dim, gmem_outer_dim * num_groups,
|
||||
smem_inner_dim, smem_outer_dim,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
@@ -140,15 +184,16 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
|
||||
const int& block_m, const int& block_n,
|
||||
const int& outer_stride,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
// Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
|
||||
// bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
|
||||
return make_tma_2d_desc(t,
|
||||
shape_n, shape_m * num_groups,
|
||||
block_n, block_m,
|
||||
outer_stride,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
@@ -156,7 +201,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
int shape_mn, int shape_k,
|
||||
const int& block_mn, const int& block_k,
|
||||
const int& num_groups,
|
||||
const int& swizzle_mode) {
|
||||
const int& swizzle_mode, const int& swizzle_base = 0,
|
||||
const bool& allow_tf32 = false) {
|
||||
DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
|
||||
|
||||
// TODO: maybe swizzle SF as well
|
||||
@@ -167,7 +213,8 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
|
||||
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
|
||||
block_mn, 1,
|
||||
shape_mn,
|
||||
swizzle_mode);
|
||||
swizzle_mode, swizzle_base,
|
||||
allow_tf32);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -42,7 +42,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
@@ -56,7 +56,7 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
@@ -80,8 +80,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
// TODO: test other Ks
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
@@ -122,7 +121,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
|
||||
137
csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp
Normal file
137
csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp
Normal file
@@ -0,0 +1,137 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the SM100 `sm100_bmn_bnk_mn_gemm_impl` kernel: renders the
// instantiation source from `Args` and forwards the launch-time arguments.
class SM100BmkBnkMnRuntime final: public LaunchRuntime<SM100BmkBnkMnRuntime> {
public:
    struct Args {
        // Problem shape; `s` is passed at launch, `m`/`n`/`k` become template arguments
        int s, m, n, k;
        // Tile sizes
        int block_m, block_n, block_k;
        // Split-K factor
        int split_factor;
        // Swizzle modes for A/B and for the output
        int swizzle_ab_mode, swizzle_cd_mode;
        // Number of pipeline stages
        int num_stages;
        // Threads per block
        int num_threads;

        LaunchArgs launch_args;

        // TMA descriptors handed to the kernel
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_d;
    };

    // Produces the source snippet that instantiates the kernel template with
    // the compile-time configuration stored in `args`.
    static std::string generate_impl(const Args& args) {
        return fmt::format(R"(
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm100_bmn_bnk_mn_gemm_impl<
        {}, {}, {},
        {}, {}, {},
        {},
        {}, {},
        {}, {}
    >);
}};
)",
        // Shape
        args.m, args.n, args.k,
        // Tiles
        args.block_m, args.block_n, args.block_k,
        // Split-K
        args.split_factor,
        // Swizzling
        args.swizzle_ab_mode, args.swizzle_cd_mode,
        // Pipeline and threads
        args.num_stages, args.num_threads);
    }

    // Launches the compiled kernel with the runtime batch size and the three TMA descriptors.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        DG_CUDA_UNIFIED_CHECK(launch_kernel(
            kernel, config,
            args.s,
            args.tensor_map_a, args.tensor_map_b, args.tensor_map_d));
    }
};
|
||||
|
||||
|
||||
// Host-side launcher for the SM100 `sm100_bmn_bnk_mn_gemm_impl` kernel.
//   - `a` is addressed as a 2D `[s * m, k]` view and `b` as `[s * n, k]`, both with row stride `k`
//   - `d` is addressed as `[m, n]` with row stride `n`
//   - `s`, `m`, `n`, `k`: problem sizes; `k` must be a multiple of 64, `m` and `n`
//     multiples of 64, and all four must be positive
// NOTE(review): judging by the file name this presumably computes the einsum
// `bmk,bnk->mn` -- confirm against the kernel.
// Tile sizes are fixed; the split-K factor and stage count are derived from the
// SM count and the shared-memory capacity.
static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                  const torch::Tensor &b,
                                  const torch::Tensor &d,
                                  const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_threads = 128;
    // An empty problem would make the block counts below zero and lead to integer
    // divisions by zero, so reject it explicitly
    DG_HOST_ASSERT(s > 0 and m > 0 and n > 0 and k > 0);
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    // `s * m` and `s * n` are used as `int` TMA extents
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());

    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast<int>(d.element_size()));

    // Get best config
    // NOTES: each CTA handles up to `split_factor` of the `(s, k)` blocks of one `(m, n)` tile
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Select best number of stages
    // NOTES: we select 4 as start, as it is tested to be faster than values > 4
    int num_stages = 4, smem_size = 0;
    while (true) {
        const int smem_cd = block_m * swizzle_cd_mode * 2;
        const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));
        const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));
        const int smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages);
        const int smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size();

        smem_size = 0;
        smem_size += smem_cd;
        smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
        smem_size += smem_barrier;
        smem_size += smem_tmem_ptr;
        if (smem_size <= SM100ArchSpec::smem_capacity)
            break;

        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        // NOTES: the separator before "stages" was missing, gluing two fields together
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode);
    }

    const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
    const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);

    const SM100BmkBnkMnRuntime::Args& args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .swizzle_ab_mode = swizzle_ab_mode,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_threads = num_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100BmkBnkMnRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
    SM100BmkBnkMnRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
|
||||
#include "epilogue.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -18,6 +20,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -44,11 +47,12 @@ static void __instantiate_kernel() {{
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}, {}
|
||||
{}, {}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -57,11 +61,12 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.num_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype));
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -80,7 +85,8 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::Kernel1D1D,
|
||||
@@ -99,7 +105,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -129,6 +135,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -186,6 +193,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -244,6 +252,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -324,6 +333,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
.m = m, .n = n, .k = sum_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
@@ -18,6 +18,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -46,7 +47,8 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
{}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -59,7 +61,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype));
|
||||
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -78,7 +81,8 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
|
||||
const auto& aligned_k = align(k, 128);
|
||||
@@ -98,7 +102,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -111,6 +115,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -164,6 +169,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -218,6 +224,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
@@ -41,7 +41,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -53,7 +53,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -73,10 +74,10 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(not c.has_value());
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& aligned_k = align(k, 64);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
@@ -102,7 +103,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
|
||||
131
csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp
Normal file
131
csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp
Normal file
@@ -0,0 +1,131 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the SM90 `sm90_bmn_bnk_mn_gemm_impl` kernel: renders the
// instantiation source from `Args` and forwards the launch-time arguments.
class SM90BmkBnkMnRuntime final: public LaunchRuntime<SM90BmkBnkMnRuntime> {
public:
    struct Args {
        // Problem shape; `s` is passed at launch, `m`/`n`/`k` become template arguments
        int s, m, n, k;
        // Tile sizes
        int block_m, block_n, block_k;
        // Split-K factor
        int split_factor;
        // Number of pipeline stages
        int num_stages;
        // Thread counts for the TMA and math roles
        int num_tma_threads, num_math_threads;

        LaunchArgs launch_args;

        // TMA descriptors for the inputs and the raw FP32 output pointer
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        float* d;
    };

    // Produces the source snippet that instantiates the kernel template with
    // the compile-time configuration stored in `args`.
    static std::string generate_impl(const Args& args) {
        return fmt::format(R"(
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm90_bmn_bnk_mn_gemm_impl<
        {}, {}, {},
        {}, {}, {},
        {},
        {},
        {}, {}
    >);
}};
)",
        // Shape
        args.m, args.n, args.k,
        // Tiles
        args.block_m, args.block_n, args.block_k,
        // Split-K
        args.split_factor,
        // Pipeline
        args.num_stages,
        // Threads
        args.num_tma_threads, args.num_math_threads);
    }

    // Launches the compiled kernel with the runtime batch size, both input
    // descriptors and the output pointer.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        DG_CUDA_UNIFIED_CHECK(launch_kernel(
            kernel, config,
            args.s,
            args.tensor_map_a, args.tensor_map_b,
            args.d));
    }
};
|
||||
|
||||
|
||||
// Host-side launcher for the SM90 `sm90_bmn_bnk_mn_gemm_impl` kernel.
//   - `a` is addressed as a 2D `[s * m, k]` view and `b` as `[s * n, k]`, both with row stride `k`
//   - `d` is handed to the kernel as a raw `float*`, so it must be an FP32 tensor
//   - `s`, `m`, `n`, `k`: problem sizes; `k` must be a multiple of 64, `m` and `n`
//     multiples of 64, and all four must be positive
// NOTE(review): judging by the file name this presumably computes the einsum
// `bmk,bnk->mn` -- confirm against the kernel.
// Tile sizes are fixed; the split-K factor and stage count are derived from the
// SM count and the shared-memory capacity.
static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                 const torch::Tensor &b,
                                 const torch::Tensor &d,
                                 const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_tma_threads = 128;
    constexpr int num_math_threads = 256;
    // An empty problem would make the block counts below zero and lead to integer
    // divisions by zero, so reject it explicitly
    DG_HOST_ASSERT(s > 0 and m > 0 and n > 0 and k > 0);
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    // `s * m` and `s * n` are used as `int` TMA extents
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());

    // Only the 128-byte swizzle mode is accepted for A/B
    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    DG_HOST_ASSERT(swizzle_ab_mode == 128);

    // Get best config
    // NOTES: each CTA handles up to `split_factor` of the `(s, k)` blocks of one `(m, n)` tile
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Select best number of stages
    int num_stages = 4, smem_size = 0;
    while (true) {
        const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));
        const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(cutlass::bfloat16_t));
        const int smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages);

        smem_size = 0;
        smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
        smem_size += smem_barrier;

        if (smem_size <= SM90ArchSpec::smem_capacity)
            break;

        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        // NOTES: the separator before "stages" was missing, gluing two fields together
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode);
    }

    const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);

    const SM90BmkBnkMnRuntime::Args& args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .num_stages = num_stages,
        .num_tma_threads = num_tma_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .d = d.data_ptr<float>()
    };
    const auto& code = SM90BmkBnkMnRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
    SM90BmkBnkMnRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
214
csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Normal file
214
csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp
Normal file
@@ -0,0 +1,214 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the SM90 `sm90_fp8_gemm_1d1d_impl` kernel: renders the
// instantiation source from `Args` and forwards the launch-time arguments.
class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime<SM90FP8Gemm1D1DRuntime> {
public:
    struct Args {
        // Problem shape and group count
        int m, n, k, num_groups;
        // Passed to `get_compiled_dim` to decide how each of m/n/k is rendered
        const std::string& compiled_dims;

        GemmConfig gemm_config;
        LaunchArgs launch_args;

        // Raw pointers and TMA descriptors forwarded to the kernel as-is
        void *gmem_a_ptr;
        void *gmem_b_ptr;
        void *grouped_layout;
        void *tensor_map_buffer;
        CUtensorMap tensor_map_a_base;
        CUtensorMap tensor_map_b_base;
        CUtensorMap tensor_map_sfa;
        CUtensorMap tensor_map_sfb;
        CUtensorMap tensor_map_d;
    };

    // Produces the source snippet that instantiates the kernel template with
    // the compile-time configuration stored in `args`.
    static std::string generate_impl(const Args& args) {
        const auto& cfg = args.gemm_config;
        return fmt::format(R"(
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d1d_impl<
        {}, {}, {},
        {},
        {}, {}, {},
        {},
        {}, {},
        {}, {},
        {},
        {}, {}
    >);
}};
)",
        // Shape
        get_compiled_dim(args.m, 'm', args.compiled_dims),
        get_compiled_dim(args.n, 'n', args.compiled_dims),
        get_compiled_dim(args.k, 'k', args.compiled_dims),
        // Groups
        args.num_groups,
        // Tiles
        cfg.block_m, cfg.block_n, cfg.block_k,
        // Pipeline
        cfg.num_stages,
        // Threads
        cfg.thread_config.num_tma_threads, cfg.thread_config.num_math_threads,
        // Multicast
        cfg.multicast_config.num_multicast, cfg.multicast_config.is_multicast_on_a,
        // SM count, GEMM type and output type
        cfg.num_sms, to_string(cfg.gemm_type),
        to_string(cfg.cd_dtype));
    }

    // Launches the compiled kernel, passing pointers, shape and descriptors in
    // the order the kernel expects.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        DG_CUDA_UNIFIED_CHECK(launch_kernel(
            kernel, config,
            args.gmem_a_ptr, args.gmem_b_ptr,
            args.grouped_layout,
            args.tensor_map_buffer,
            args.m, args.n, args.k,
            args.tensor_map_a_base, args.tensor_map_b_base,
            args.tensor_map_sfa, args.tensor_map_sfb,
            args.tensor_map_d));
    }
};
|
||||
|
||||
// Builds the TMA descriptors for a single (non-grouped) SM90 FP8 1D1D GEMM, generates and
// compiles the kernel through the JIT compiler, and launches it.
//
// - `a` / `b`: operand tensors, described as `[m, k]` and `[n, k]`; both must be K-major.
// - `sfa` / `sfb`: scaling-factor tensors, described with MN-major TMA descriptors.
// - `c`: must hold a value; it is only inspected for presence in this function.
// - `d`: output tensor; must be `torch::kFloat`.
// - `compiled_dims`: forwarded to the code generator (stored by reference in `Args`).
static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                               const torch::Tensor& b, const torch::Tensor& sfb,
                               const std::optional<torch::Tensor>& c,
                               const torch::Tensor& d,
                               const int& m, const int& n, const int& k,
                               const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                               const std::string& compiled_dims) {
    using Runtime = SM90FP8Gemm1D1DRuntime;

    // Supported configuration: accumulator present, FP32 output, K-major operands
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D1D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // Operand descriptors: the box extents depend on the multicast configuration
    const auto load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
    const auto& desc_a = make_tma_a_desc(major_a, a, m, k,
                                         load_block_m, config.block_k,
                                         k, 1,
                                         config.smem_config.swizzle_a_mode);
    const auto load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
    const auto& desc_b = make_tma_b_desc(major_b, b, n, k,
                                         load_block_n, config.block_k,
                                         k, 1,
                                         config.smem_config.swizzle_b_mode);

    // Scaling-factor descriptors
    const auto& desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                            config.block_m, config.block_k, 1, 0);
    const auto& desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                            config.block_n, config.block_k, 1, 0);

    // Output descriptor; the row stride is read from the tensor and the last argument is 0
    const auto store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m, true);
    const auto store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n);
    const auto& desc_d = make_tma_cd_desc(d, m, n,
                                          store_block_m, store_block_n,
                                          static_cast<int>(d.stride(-2)), 1,
                                          0);

    // The non-grouped path passes null for the raw pointers, grouped layout and tensor-map buffer
    const Runtime::Args args = {
        .m = m, .n = n, .k = k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .gmem_a_ptr = nullptr,
        .gmem_b_ptr = nullptr,
        .grouped_layout = nullptr,
        .tensor_map_buffer = nullptr,
        .tensor_map_a_base = desc_a,
        .tensor_map_b_base = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_d = desc_d,
    };

    // Generate, build and launch
    const auto& code = Runtime::generate(args);
    const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
    Runtime::launch(runtime, args);
}
|
||||
|
||||
// Launches the SM90 FP8 1D1D kernel in K-grouped contiguous mode.
//
// - `a` / `b`: operand tensors; both must be K-major. Their raw data pointers are also
//   handed to the kernel.
// - `sfa` / `sfb`: scaling-factor tensors, described with MN-major TMA descriptors.
// - `c`: must hold a value; it is only inspected for presence in this function.
// - `d`: output tensor; must be `torch::kFloat`.
// - `ks`: per-group K sizes on the host; must be non-empty and every entry a multiple of 128.
// - `ks_tensor`: its data pointer is passed to the kernel as the grouped layout
//   (presumably the device copy of `ks` — confirm against the caller).
// - `tensor_map_buffer`: its data pointer is passed to the kernel.
// - `compiled_dims`: forwarded to the code generator (stored by reference in `Args`).
static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                         const torch::Tensor& b, const torch::Tensor& sfb,
                                         const std::optional<torch::Tensor>& c,
                                         const torch::Tensor& d,
                                         const int& m, const int& n,
                                         const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                         const torch::Tensor& tensor_map_buffer,
                                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                         const std::string& compiled_dims) {
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // `std::max_element` returns `end()` for an empty range, and dereferencing it is undefined
    DG_HOST_ASSERT(not ks.empty());

    // Get config using max K for better performance
    const auto& num_groups = static_cast<int>(ks.size());
    const auto& max_k = *std::max_element(ks.begin(), ks.end());
    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
        m, n, max_k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // `first_k` is the first non-zero group size and shapes the base A/B descriptors;
    // `sum_k` is the total K; `sum_sf_k` counts 128-wide scaling-factor blocks
    int first_k = 0, sum_k = 0, sum_sf_k = 0;
    for (int i = 0; i < num_groups; ++ i) {
        // Validate each group size before it is used
        DG_HOST_ASSERT(ks[i] % 128 == 0);
        if (first_k == 0 and ks[i] != 0)
            first_k = ks[i];
        sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128);
    }
    const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k,
                                                    SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                                    config.block_k, first_k, 1,
                                                    config.smem_config.swizzle_a_mode);
    const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k,
                                                    SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                                    config.block_k, first_k, 1,
                                                    config.smem_config.swizzle_b_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
                                                  config.block_m, config.block_k, 1, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
                                                  config.block_n, config.block_k, 1, 0);
    const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
                                                SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
                                                SM90ArchSpec::get_cd_store_block_n(config.block_n),
                                                static_cast<int>(d.stride(-2)), num_groups,
                                                config.smem_config.swizzle_cd_mode);

    // Launch
    const SM90FP8Gemm1D1DRuntime::Args& args = {
        .m = m, .n = n, .k = sum_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .gmem_a_ptr = a.data_ptr(),
        .gmem_b_ptr = b.data_ptr(),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_buffer = tensor_map_buffer.data_ptr(),
        .tensor_map_a_base = tensor_map_a_base,
        .tensor_map_b_base = tensor_map_b_base,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_d = tensor_map_d,
    };
    const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);

    SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -8,6 +8,8 @@
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
|
||||
#include "epilogue.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
@@ -17,6 +19,7 @@ public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
const std::optional<std::string>& epilogue_type;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
@@ -43,7 +46,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -55,7 +58,8 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
|
||||
get_default_epilogue_type(args.epilogue_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
@@ -74,7 +78,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
const std::string& compiled_dims,
|
||||
const std::optional<std::string>& epilogue_type = std::nullopt) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
|
||||
@@ -98,7 +103,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
@@ -111,6 +116,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = epilogue_type,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -170,6 +176,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
@@ -230,6 +237,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
|
||||
.m = m, .n = n, .k = aligned_k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.epilogue_type = std::nullopt,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
|
||||
151
csrc/jit_kernels/impls/smxx_cublaslt.hpp
Normal file
151
csrc/jit_kernels/impls/smxx_cublaslt.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDADataType.h>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld,
|
||||
const std::optional<int>& batch_count = std::nullopt,
|
||||
const std::optional<int>& batch_offset = std::nullopt) {
|
||||
cublasLtMatrixLayout_t layout;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));
|
||||
if (batch_count.has_value()) {
|
||||
DG_HOST_ASSERT(batch_offset.has_value());
|
||||
|
||||
const int64_t batch_offset_int64 = batch_offset.value();
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value())));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64)));
|
||||
}
|
||||
return layout;
|
||||
}
|
||||
|
||||
static void call_cublaslt_api(const cublasOperation_t& trans_a,
|
||||
const cublasOperation_t& trans_b,
|
||||
const cublasLtMatrixLayout_t& layout_a,
|
||||
const cublasLtMatrixLayout_t& layout_b,
|
||||
const cublasLtMatrixLayout_t& layout_d,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const bool& accumulate) {
|
||||
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
|
||||
cudaDataType_t scale_type = CUDA_R_32F;
|
||||
const int& math_sms = device_runtime->get_num_sms();
|
||||
bool fp8_fast_accumulate = false;
|
||||
|
||||
// Operation description
|
||||
cublasLtMatmulDesc_t desc;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
|
||||
if (a.scalar_type() == torch::kFloat8_e4m3fn)
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));
|
||||
|
||||
// Get cuBLASLt handle, workspace, and stream
|
||||
const auto& handle = device_runtime->get_cublaslt_handle();
|
||||
const auto& workspace = device_runtime->get_cublaslt_workspace();
|
||||
const auto& workspace_bytes = workspace.nbytes();
|
||||
const auto& stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
// Algorithm selection
|
||||
cublasLtMatmulPreference_t pref;
|
||||
cublasLtMatmulHeuristicResult_t heuristic;
|
||||
int num_heuristic_results = 0;
|
||||
uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
|
||||
&workspace_bytes, sizeof(workspace_bytes)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
|
||||
&reduction_scheme_mask, sizeof(reduction_scheme_mask)));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d,
|
||||
pref, 1, &heuristic, &num_heuristic_results));
|
||||
DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");
|
||||
|
||||
// Call: D = alpha * (A @ B) + beta * C
|
||||
const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0;
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle
|
||||
desc, // Operation description
|
||||
&alpha, // Alpha
|
||||
b.data_ptr(), layout_a, // A
|
||||
a.data_ptr(), layout_b, // B
|
||||
&beta, // Beta
|
||||
d.data_ptr(), layout_d, // C
|
||||
d.data_ptr(), layout_d, // D
|
||||
&heuristic.algo, // Algorithm
|
||||
workspace.data_ptr(), workspace_bytes, // Workspace
|
||||
stream)); // Stream
|
||||
|
||||
// Free memory
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d));
|
||||
DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
|
||||
}
|
||||
|
||||
// Runs `out = lhs @ rhs` (plus `acc` when provided) through cuBLASLt.
//
// - `lhs` / `rhs`: input tensors with majors `a_major` / `b_major`.
// - `acc`: optional accumulator. If it does not alias `out` it is copied into `out` first;
//   if it does alias `out`, sizes and strides must match.
// - `out`: output tensor; its stride along dim 0 is used as the leading dimension.
// - `m`, `n`, `k`: problem shape used for the layouts.
//
// The operands are handed to cuBLASLt in swapped order (`rhs` takes the "A" slot and
// `lhs` the "B" slot), and the output layout is described as `n` x `m`.
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
                          const std::optional<torch::Tensor>& acc,
                          const torch::Tensor& out,
                          const int& m, const int& n, const int& k,
                          const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) {
    const bool lhs_is_k_major = (a_major == cute::UMMA::Major::K);
    const bool rhs_is_k_major = (b_major == cute::UMMA::Major::K);
    const cublasOperation_t trans_a = rhs_is_k_major ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t trans_b = lhs_is_k_major ? CUBLAS_OP_N : CUBLAS_OP_T;

    // Duplicate the accumulator if necessary
    // TODO: remove this
    if (acc.has_value()) {
        if (acc->data_ptr() != out.data_ptr()) {
            out.copy_(acc.value());
        } else {
            DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides());
        }
    }

    // Element types, resolved in the same order as the layouts are created
    const auto& cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
    const auto& cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
    const auto& cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());

    // Matrix layouts: the leading dimension comes from the stride of the non-contiguous dim
    cublasLtMatrixLayout_t layout_a;
    if (rhs_is_k_major)
        layout_a = get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0));
    else
        layout_a = get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1));

    cublasLtMatrixLayout_t layout_b;
    if (lhs_is_k_major)
        layout_b = get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0));
    else
        layout_b = get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1));

    const cublasLtMatrixLayout_t layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0));

    const bool accumulate = acc.has_value();
    call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, accumulate);
}
|
||||
|
||||
|
||||
// Batched BF16 matmul for the `bhr, hdr -> bhd` contraction via cuBLASLt.
// The batch count is `h`; per batch the GEMM is `m = d`, `n = b`, `k = r`. The layouts use
// `rhs.stride(1)`, `lhs.stride(0)` and `out.stride(0)` as leading dimensions, and
// `rhs.stride(0)`, `lhs.stride(1)` and `out.stride(1)` as batch offsets. No accumulation.
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    // GEMM shape per head
    const int gemm_m = d;
    const int gemm_n = b;
    const int gemm_k = r;

    // Matrix layouts (`rhs` fills the "A" slot and is transposed, `lhs` fills the "B" slot)
    const auto& rhs_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_k, gemm_m, rhs.stride(1), h, rhs.stride(0));
    const auto& lhs_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_k, gemm_n, lhs.stride(0), h, lhs.stride(1));
    const auto& out_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_m, gemm_n, out.stride(0), h, out.stride(1));

    call_cublaslt_api(CUBLAS_OP_T, CUBLAS_OP_N, rhs_layout, lhs_layout, out_layout, lhs, rhs, out, false);
}
|
||||
|
||||
|
||||
// Batched BF16 matmul for the `bhd, hdr -> bhr` contraction via cuBLASLt.
// The batch count is `h`; per batch the GEMM is `m = r`, `n = b`, `k = d`. The layouts use
// `rhs.stride(1)`, `lhs.stride(0)` and `out.stride(0)` as leading dimensions, and
// `rhs.stride(0)`, `lhs.stride(1)` and `out.stride(1)` as batch offsets. No accumulation.
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    // GEMM shape per head
    const int gemm_m = r;
    const int gemm_n = b;
    const int gemm_k = d;

    // Matrix layouts (`rhs` fills the "A" slot, `lhs` the "B" slot; neither is transposed)
    const auto& rhs_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_m, gemm_k, rhs.stride(1), h, rhs.stride(0));
    const auto& lhs_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_k, gemm_n, lhs.stride(0), h, lhs.stride(1));
    const auto& out_layout = get_cublaslt_layout(CUDA_R_16BF, gemm_m, gemm_n, out.stride(0), h, out.stride(1));

    call_cublaslt_api(CUBLAS_OP_N, CUBLAS_OP_N, rhs_layout, lhs_layout, out_layout, lhs, rhs, out, false);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,6 +1,8 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "apis/attention.hpp"
|
||||
#include "apis/einsum.hpp"
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
@@ -13,6 +15,8 @@
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "DeepGEMM C++ library";
|
||||
|
||||
deep_gemm::attention::register_apis(m);
|
||||
deep_gemm::einsum::register_apis(m);
|
||||
deep_gemm::gemm::register_apis(m);
|
||||
deep_gemm::layout::register_apis(m);
|
||||
deep_gemm::runtime::register_apis(m);
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cublasLt.h>
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
@@ -72,4 +73,16 @@ do { \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
// Checks the status returned by a cuBLAS/cuBLASLt call.
// `cmd` is evaluated exactly once. If the result is not `CUBLAS_STATUS_SUCCESS`, a
// `DGException` tagged "cuBLASLt" is thrown with the current file and line and a message
// of the form "<numeric status> (<cublasGetStatusString text>)".
// The `#ifndef` guard keeps any definition made before this point.
#ifndef DG_CUBLASLT_CHECK
|
||||
#define DG_CUBLASLT_CHECK(cmd) \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != CUBLAS_STATUS_SUCCESS) { \
|
||||
std::ostringstream ss; \
|
||||
ss << static_cast<int>(e) << " (" << cublasGetStatusString(e) << ")"; \
|
||||
throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <filesystem>
|
||||
#include <functional>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <unistd.h>
|
||||
|
||||
#include "exception.hpp"
|
||||
#include "format.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -65,7 +69,10 @@ static std::filesystem::path make_dirs(const std::filesystem::path& path) {
|
||||
// OK if existed
|
||||
std::error_code capture;
|
||||
const bool& created = std::filesystem::create_directories(path, capture);
|
||||
DG_HOST_ASSERT(created or capture.value() == 0);
|
||||
if (not (created or capture.value() == 0)) {
|
||||
DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}",
|
||||
path.c_str(), created, capture.value()));
|
||||
}
|
||||
if (created and get_env<int>("DG_JIT_DEBUG"))
|
||||
printf("Create directory: %s\n", path.c_str());
|
||||
return path;
|
||||
|
||||
Reference in New Issue
Block a user