* Merge with private repo * Update README * Update README * Update README * Add PyTorch requirements * Fix sync scopes for MQA logits (#256) * Update README
716 lines
35 KiB
C++
716 lines
35 KiB
C++
#pragma once
|
|
|
|
#include "../utils/compatibility.hpp"
|
|
|
|
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
|
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
|
|
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
|
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
|
|
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
|
|
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
|
|
#endif
|
|
|
|
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
|
|
|
|
#include "layout.hpp"
|
|
|
|
namespace deep_gemm::gemm {
|
|
|
|
static bool early_return(const int& m, const int &n, const int& k,
|
|
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
|
|
// Do nothing if the problem is empty
|
|
if (m == 0 or n == 0)
|
|
return true;
|
|
|
|
// Checks
|
|
const bool is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr();
|
|
if (is_cd_same)
|
|
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
|
|
if (c.has_value()) {
|
|
check_major_type_cd(c.value());
|
|
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
|
|
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
|
|
}
|
|
|
|
// No accumulation
|
|
if (k == 0) {
|
|
if (not is_cd_same)
|
|
c.has_value() ? d.copy_(c.value()) : d.zero_();
|
|
return true;
|
|
}
|
|
|
|
// With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D)
|
|
if (c.has_value() and not is_cd_same)
|
|
d.copy_(c.value());
|
|
return false;
|
|
}
|
|
|
|
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
|
|
|
static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
std::optional<std::tuple<int, int, int>> recipe,
|
|
std::optional<std::tuple<int, int>> recipe_a,
|
|
std::optional<std::tuple<int, int>> recipe_b,
|
|
const std::string& compiled_dims,
|
|
const bool& disable_ue8m0_cast) {
|
|
// Shape must be `[M, K] @ [N, K].T`
|
|
const auto major_a = get_major_type_ab(a.first);
|
|
const auto major_b = get_major_type_ab(b.first);
|
|
if (fp8_requires_k_major()) {
|
|
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
|
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
|
}
|
|
|
|
// C/D must be N-major
|
|
check_major_type_cd(d);
|
|
|
|
// Type and shape checks
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
|
|
const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major);
|
|
const auto [m_, n_] = get_shape<2>(d);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
|
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
|
|
|
// Early return for trivial cases
|
|
if (early_return(m, n, k, d, c))
|
|
return;
|
|
|
|
// Transform SFA and SFB into compute-required layout
|
|
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
|
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast);
|
|
|
|
// Dispatch into different implements
|
|
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
|
const int gran_n = recipe.has_value() ? std::get<1>(recipe.value()) : std::get<0>(recipe_b.value());
|
|
if (gran_n == 1) {
|
|
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
|
} else {
|
|
const auto major_sfb = get_major_type_ab(sfb);
|
|
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
|
|
}
|
|
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
|
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b,
|
|
major_a, major_b, compiled_dims);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
|
}
|
|
}
|
|
|
|
static void fp8_fp4_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::optional<std::tuple<int, int, int>>& recipe,
|
|
const std::optional<std::tuple<int, int>>& recipe_a,
|
|
const std::optional<std::tuple<int, int>>& recipe_b,
|
|
const std::string& compiled_dims,
|
|
const bool& disable_ue8m0_cast) {
|
|
fp8_fp4_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
|
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
|
|
}
|
|
|
|
static void fp8_fp4_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::optional<std::tuple<int, int, int>>& recipe,
|
|
const std::optional<std::tuple<int, int>>& recipe_a,
|
|
const std::optional<std::tuple<int, int>>& recipe_b,
|
|
const std::string& compiled_dims,
|
|
const bool& disable_ue8m0_cast) {
|
|
fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
|
|
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
|
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
|
|
}
|
|
|
|
static void fp8_fp4_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::optional<std::tuple<int, int, int>>& recipe,
|
|
const std::optional<std::tuple<int, int>>& recipe_a,
|
|
const std::optional<std::tuple<int, int>>& recipe_b,
|
|
const std::string& compiled_dims,
|
|
const bool& disable_ue8m0_cast) {
|
|
fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
|
|
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
|
|
}
|
|
|
|
static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const torch::Tensor& grouped_layout,
|
|
std::optional<std::tuple<int, int, int>> recipe,
|
|
std::optional<std::tuple<int, int>> recipe_a,
|
|
std::optional<std::tuple<int, int>> recipe_b,
|
|
const std::string& compiled_dims,
|
|
const bool& disable_ue8m0_cast,
|
|
const bool& use_psum_layout,
|
|
const std::optional<int>& expected_m_for_psum_layout) {
|
|
// Shape must be `[M, K] @ [G, N, K].mT`
|
|
const auto major_a = get_major_type_ab(a.first);
|
|
const auto major_b = get_major_type_ab(b.first);
|
|
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
|
if (fp8_requires_k_major())
|
|
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
|
DG_HOST_ASSERT(grouped_layout.is_contiguous());
|
|
|
|
// Type and shape checks
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
|
|
const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
|
|
const auto [m_, n_] = get_shape<2>(d);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
|
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
|
|
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);
|
|
|
|
// Layout checks
|
|
if (use_psum_layout) {
|
|
const auto [num_groups_] = get_shape<1>(grouped_layout);
|
|
DG_HOST_ASSERT(num_groups == num_groups_);
|
|
} else {
|
|
const auto [m__] = get_shape<1>(grouped_layout);
|
|
DG_HOST_ASSERT(m == m__);
|
|
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
|
|
}
|
|
|
|
// D must be N-major
|
|
check_major_type_cd(d);
|
|
|
|
// Do nothing if empty
|
|
if (m == 0)
|
|
return;
|
|
|
|
// Transform SFA and SFB into compute-required layout
|
|
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
|
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, num_groups, disable_ue8m0_cast);
|
|
|
|
// Dispatch implementation
|
|
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
|
const auto major_sfb = get_major_type_ab(sfb);
|
|
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout,
|
|
num_groups, m, n, k, major_a, major_b, major_sfb,
|
|
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
|
|
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
|
sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
|
|
num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b,
|
|
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
|
}
|
|
}
|
|
|
|
static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const torch::Tensor& grouped_layout,
|
|
const std::optional<std::tuple<int, int, int>>& recipe,
|
|
const std::optional<std::tuple<int, int>>& recipe_a,
|
|
const std::optional<std::tuple<int, int>>& recipe_b,
|
|
const std::string& compiled_dims,
|
|
const bool& disable_ue8m0_cast,
|
|
const bool& use_psum_layout) {
|
|
m_grouped_fp8_fp4_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
|
|
d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt);
|
|
}
|
|
|
|
static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const torch::Tensor& masked_m,
|
|
const int& expected_m,
|
|
std::optional<std::tuple<int, int, int>> recipe,
|
|
std::optional<std::tuple<int, int>> recipe_a,
|
|
std::optional<std::tuple<int, int>> recipe_b,
|
|
const std::string& compiled_dims,
|
|
const bool& disable_ue8m0_cast) {
|
|
// Shape must be `[G, M, K] @ [G, N, K].mT`
|
|
const auto major_a = get_major_type_ab(a.first);
|
|
const auto major_b = get_major_type_ab(b.first);
|
|
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
|
DG_HOST_ASSERT(masked_m.is_contiguous());
|
|
|
|
// Type and shape checks
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
|
|
const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
|
|
const auto [num_groups__, m_, n_] = get_shape<3>(d);
|
|
const auto num_groups___ = static_cast<int>(masked_m.numel());
|
|
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
|
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
|
|
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
|
|
|
|
// D must be N-major
|
|
check_major_type_cd(d);
|
|
|
|
// Transform scaling factors
|
|
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
|
|
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast);
|
|
|
|
// Dispatch implementation
|
|
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
|
const auto major_sfb = get_major_type_ab(sfb);
|
|
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
|
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
|
|
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
|
sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
|
|
num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
|
|
major_a, major_b, compiled_dims);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
|
}
|
|
}
|
|
|
|
static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const std::vector<int>& ks,
|
|
const torch::Tensor& ks_tensor,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::tuple<int, int, int>& recipe,
|
|
const std::string& compiled_dims) {
|
|
// Must be 1D1D kernel
|
|
DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1);
|
|
|
|
const int gran_k = std::get<2>(recipe);
|
|
DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);
|
|
|
|
// Shape checks
|
|
const auto [num_groups, m, n] = get_shape<3>(d);
|
|
const auto [sum_k_ , m_] = get_shape<2>(a.first);
|
|
const auto [sum_k__, n_] = get_shape<2>(b.first);
|
|
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
|
|
|
|
// Contiguity checks
|
|
DG_HOST_ASSERT(a.first.is_contiguous());
|
|
DG_HOST_ASSERT(b.first.is_contiguous());
|
|
DG_HOST_ASSERT(d.is_contiguous());
|
|
DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous());
|
|
|
|
// Early return for trivial cases
|
|
if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c))
|
|
return;
|
|
|
|
// Transform SF with padding
|
|
const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
|
|
const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
|
|
|
|
// Dispatch implementation
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
if (arch_major == 10) {
|
|
sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, gran_k,
|
|
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture");
|
|
}
|
|
}
|
|
|
|
static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
|
const std::pair<torch::Tensor, torch::Tensor>& b,
|
|
const torch::Tensor& d,
|
|
const std::vector<int>& ks,
|
|
const torch::Tensor& ks_tensor,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::tuple<int, int, int>& recipe,
|
|
const std::string& compiled_dims) {
|
|
// Must be 1D1D kernel
|
|
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
|
|
|
|
// Shape checks
|
|
const auto [num_groups, m, n] = get_shape<3>(d);
|
|
const auto sum_mk = a.first.numel();
|
|
const auto sum_nk = b.first.numel();
|
|
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
|
|
DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(sum_k) * m);
|
|
DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(sum_k) * n);
|
|
|
|
// Contiguity checks
|
|
DG_HOST_ASSERT(a.first.is_contiguous());
|
|
DG_HOST_ASSERT(b.first.is_contiguous());
|
|
DG_HOST_ASSERT(d.is_contiguous());
|
|
DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous());
|
|
|
|
// Early return for trivial cases
|
|
if (early_return(m, n, accumulate(ks.begin(), ks.end(), 0), d, c))
|
|
return;
|
|
|
|
// Transform SF with padding
|
|
const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
|
|
const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
|
|
|
|
// Allocate tensormap buffer
|
|
// `4` means the double buffering for both A and B operands (2 * 2)
|
|
const auto num_sms = device_runtime->get_num_sms();
|
|
const auto tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
|
|
a.first.options().dtype(torch::kByte));
|
|
|
|
// Dispatch implementation
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
if (arch_major == 9) {
|
|
sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
|
|
cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture");
|
|
}
|
|
}
|
|
#endif
|
|
|
|
#if DG_TENSORMAP_COMPATIBLE
|
|
static void bf16_gemm_nt(const torch::Tensor& a,
|
|
const torch::Tensor& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::string& compiled_dims) {
|
|
// Shape must be `[M, K] @ [N, K].T`
|
|
const auto major_a = get_major_type_ab(a);
|
|
const auto major_b = get_major_type_ab(b);
|
|
|
|
// C/D must be N-major
|
|
check_major_type_cd(d);
|
|
|
|
// Type and shape checks
|
|
const auto [m , k ] = get_shape<2>(a);
|
|
const auto [n , k_] = get_shape<2>(b);
|
|
const auto [m_, n_] = get_shape<2>(d);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
|
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
|
|
|
// Early return for trivial cases
|
|
if (early_return(m, n, k, d, c))
|
|
return;
|
|
|
|
// Dispatch into different implements
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
if (arch_major == 9) {
|
|
sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
|
|
} else if (arch_major == 10) {
|
|
sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture");
|
|
}
|
|
}
|
|
|
|
static void bf16_gemm_nn(const torch::Tensor& a,
|
|
const torch::Tensor& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::string& compiled_dims) {
|
|
bf16_gemm_nt(a, b.transpose(0, 1), d, c, compiled_dims);
|
|
}
|
|
|
|
static void bf16_gemm_tn(const torch::Tensor& a,
|
|
const torch::Tensor& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::string& compiled_dims) {
|
|
bf16_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c, compiled_dims);
|
|
}
|
|
|
|
static void bf16_gemm_tt(const torch::Tensor& a,
|
|
const torch::Tensor& b,
|
|
const torch::Tensor& d,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::string& compiled_dims) {
|
|
bf16_gemm_nt(a.transpose(0, 1), b, d, c, compiled_dims);
|
|
}
|
|
|
|
static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b,
|
|
const torch::Tensor& d, const torch::Tensor& grouped_layout,
|
|
const std::string& compiled_dims,
|
|
const bool& use_psum_layout,
|
|
const std::optional<int>& expected_m_for_psum_layout) {
|
|
// Shape must be `[M, K] @ [G, N, K].mT`
|
|
const auto major_a = get_major_type_ab(a);
|
|
const auto major_b = get_major_type_ab(b);
|
|
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
|
DG_HOST_ASSERT(grouped_layout.is_contiguous());
|
|
|
|
// Type and shape checks
|
|
const auto [m, k] = get_shape<2>(a);
|
|
const auto [num_groups, n, k_] = get_shape<3>(b);
|
|
const auto [m_, n_] = get_shape<2>(d);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
|
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
|
|
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);
|
|
|
|
// Layout checks
|
|
if (use_psum_layout) {
|
|
const auto [num_groups_] = get_shape<1>(grouped_layout);
|
|
DG_HOST_ASSERT(num_groups == num_groups_);
|
|
} else {
|
|
const auto [m__] = get_shape<1>(grouped_layout);
|
|
DG_HOST_ASSERT(m == m__);
|
|
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
|
|
}
|
|
|
|
// D must be N-major
|
|
check_major_type_cd(d);
|
|
|
|
// Do nothing if empty
|
|
if (m == 0)
|
|
return;
|
|
|
|
// Dispatch implementation
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
if (arch_major == 9) {
|
|
sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
|
|
num_groups, m, n, k, major_a, major_b, compiled_dims,
|
|
use_psum_layout, expected_m_for_psum_layout);
|
|
} else if (arch_major == 10) {
|
|
sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
|
|
num_groups, m, n, k, major_a, major_b, compiled_dims,
|
|
use_psum_layout, expected_m_for_psum_layout);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture");
|
|
}
|
|
}
|
|
|
|
static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b,
|
|
const torch::Tensor& d, const torch::Tensor& grouped_layout,
|
|
const std::string& compiled_dims,
|
|
const bool& use_psum_layout) {
|
|
m_grouped_bf16_gemm_nt_contiguous(a, b.transpose(1, 2),
|
|
d, grouped_layout, compiled_dims, use_psum_layout, std::nullopt);
|
|
}
|
|
|
|
static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b,
|
|
const torch::Tensor& d, const torch::Tensor& masked_m,
|
|
const int& expected_m, const std::string& compiled_dims) {
|
|
// Shape must be `[G, M, K] @ [G, N, K].mT`
|
|
const auto major_a = get_major_type_ab(a);
|
|
const auto major_b = get_major_type_ab(b);
|
|
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
|
DG_HOST_ASSERT(masked_m.is_contiguous());
|
|
|
|
// Type and shape checks
|
|
const auto [num_groups, m, k] = get_shape<3>(a);
|
|
const auto [num_groups_, n, k_] = get_shape<3>(b);
|
|
const auto [num_groups__, m_, n_] = get_shape<3>(d);
|
|
const auto num_groups___ = static_cast<int>(masked_m.numel());
|
|
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
|
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
|
|
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
|
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
|
|
|
|
// D must be N-major
|
|
check_major_type_cd(d);
|
|
|
|
// Dispatch implementation
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
if (arch_major == 9) {
|
|
sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m,
|
|
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
|
} else if (arch_major == 10) {
|
|
sm100_m_grouped_bf16_gemm_masked(a, b, d, masked_m,
|
|
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture");
|
|
}
|
|
}
|
|
|
|
static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
|
|
const torch::Tensor& b,
|
|
const torch::Tensor& d,
|
|
const std::vector<int>& ks,
|
|
const torch::Tensor& ks_tensor,
|
|
const std::optional<torch::Tensor>& c,
|
|
const std::string& compiled_dims) {
|
|
// Shape checks
|
|
const auto [num_groups, m, n] = get_shape<3>(d);
|
|
const auto [sum_k_ , m_] = get_shape<2>(a);
|
|
const auto [sum_k__, n_] = get_shape<2>(b);
|
|
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
|
|
|
|
// Contiguity checks
|
|
DG_HOST_ASSERT(a.is_contiguous());
|
|
DG_HOST_ASSERT(b.is_contiguous());
|
|
DG_HOST_ASSERT(d.is_contiguous());
|
|
DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous());
|
|
|
|
// Early return for trivial cases
|
|
if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c))
|
|
return;
|
|
|
|
// Dispatch implementation
|
|
const auto arch_major = device_runtime->get_arch_major();
|
|
if (arch_major == 9) {
|
|
sm90_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor,
|
|
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
|
|
} else if (arch_major == 10) {
|
|
sm100_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor,
|
|
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
|
|
} else {
|
|
DG_HOST_UNREACHABLE("Unsupported architecture");
|
|
}
|
|
}
|
|
#endif
|
|
|
|
static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
|
|
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
|
|
// Shape must be `[M, K] @ [N, K].T`
|
|
const auto major_a = get_major_type_ab(a);
|
|
const auto major_b = get_major_type_ab(b);
|
|
|
|
// Type and shape checks
|
|
const auto [m , k ] = get_shape<2>(a);
|
|
const auto [n , k_] = get_shape<2>(b);
|
|
const auto [m_, n_] = get_shape<2>(d);
|
|
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
|
|
|
// Early return for trivial cases
|
|
if (early_return(m, n, k, d, c))
|
|
return;
|
|
|
|
cublaslt_gemm(a, b, d, m, n, k, major_a, major_b, c.has_value());
|
|
}
|
|
|
|
static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,
|
|
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
|
|
cublaslt_gemm_nt(a, b.transpose(0, 1), d, c);
|
|
}
|
|
|
|
static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b,
|
|
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
|
|
cublaslt_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c);
|
|
}
|
|
|
|
static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
|
|
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
|
|
cublaslt_gemm_nt(a.transpose(0, 1), b, d, c);
|
|
}
|
|
|
|
static void register_apis(pybind11::module_& m) {
|
|
|
|
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
|
// FP8 FP4 GEMMs
|
|
m.def("fp8_fp4_gemm_nt", &fp8_fp4_gemm_nt,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
|
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
|
py::arg("compiled_dims") = "nk",
|
|
py::arg("disable_ue8m0_cast") = false);
|
|
m.def("fp8_fp4_gemm_nn", &fp8_fp4_gemm_nn,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
|
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
|
py::arg("compiled_dims") = "nk",
|
|
py::arg("disable_ue8m0_cast") = false);
|
|
m.def("fp8_fp4_gemm_tn", &fp8_fp4_gemm_tn,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
|
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
|
py::arg("compiled_dims") = "mn",
|
|
py::arg("disable_ue8m0_cast") = false);
|
|
m.def("fp8_fp4_gemm_tt", &fp8_fp4_gemm_tt,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
|
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
|
py::arg("compiled_dims") = "mn",
|
|
py::arg("disable_ue8m0_cast") = false);
|
|
m.def("m_grouped_fp8_fp4_gemm_nt_contiguous", &m_grouped_fp8_fp4_gemm_nt_contiguous,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
|
py::arg("recipe") = std::nullopt,
|
|
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
|
py::arg("compiled_dims") = "nk",
|
|
py::arg("disable_ue8m0_cast") = false,
|
|
py::arg("use_psum_layout") = false,
|
|
py::arg("expected_m_for_psum_layout") = std::nullopt);
|
|
m.def("m_grouped_fp8_fp4_gemm_nn_contiguous", &m_grouped_fp8_fp4_gemm_nn_contiguous,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
|
py::arg("recipe") = std::nullopt,
|
|
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
|
py::arg("compiled_dims") = "nk",
|
|
py::arg("disable_ue8m0_cast") = false,
|
|
py::arg("use_psum_layout") = false);
|
|
m.def("m_grouped_fp8_fp4_gemm_nt_masked", &m_grouped_fp8_fp4_gemm_nt_masked,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
|
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
|
|
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
|
|
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
|
|
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
|
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
|
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
|
py::arg("compiled_dims") = "mn");
|
|
m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
|
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
|
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
|
py::arg("compiled_dims") = "mn");
|
|
|
|
// FP8 GEMM alias names
|
|
m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt");
|
|
m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn");
|
|
m.attr("fp8_gemm_tn") = m.attr("fp8_fp4_gemm_tn");
|
|
m.attr("fp8_gemm_tt") = m.attr("fp8_fp4_gemm_tt");
|
|
m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous");
|
|
m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous");
|
|
m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked");
|
|
#endif
|
|
|
|
#if DG_TENSORMAP_COMPATIBLE
|
|
// BF16 GEMMs
|
|
m.def("bf16_gemm_nt", &bf16_gemm_nt,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt,
|
|
py::arg("compiled_dims") = "nk");
|
|
m.def("bf16_gemm_nn", &bf16_gemm_nn,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt,
|
|
py::arg("compiled_dims") = "nk");
|
|
m.def("bf16_gemm_tn", &bf16_gemm_tn,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt,
|
|
py::arg("compiled_dims") = "mn");
|
|
m.def("bf16_gemm_tt", &bf16_gemm_tt,
|
|
py::arg("a"), py::arg("b"), py::arg("d"),
|
|
py::arg("c") = std::nullopt,
|
|
py::arg("compiled_dims") = "mn");
|
|
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
|
py::arg("compiled_dims") = "nk",
|
|
py::arg("use_psum_layout") = false,
|
|
py::arg("expected_m_for_psum_layout") = std::nullopt);
|
|
m.def("m_grouped_bf16_gemm_nn_contiguous", &m_grouped_bf16_gemm_nn_contiguous,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
|
|
py::arg("compiled_dims") = "nk",
|
|
py::arg("use_psum_layout") = false);
|
|
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
|
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
|
|
m.def("k_grouped_bf16_gemm_tn_contiguous", &k_grouped_bf16_gemm_tn_contiguous,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
|
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
|
py::arg("compiled_dims") = "mn");
|
|
#endif
|
|
|
|
// cuBLASLt GEMMs
|
|
m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
|
m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
|
m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
|
m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
|
|
}
|
|
|
|
} // namespace deep_gemm::gemm
|