Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100 - Refactor Python APIs - Other fixes and code refactoring
This commit is contained in:
471
csrc/apis/gemm.hpp
Normal file
471
csrc/apis/gemm.hpp
Normal file
@@ -0,0 +1,471 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
|
||||
|
||||
#include "layout.hpp"
|
||||
|
||||
namespace deep_gemm::gemm {
|
||||
|
||||
// FP8 GEMM in NT form. `a` and `b` are `(data, scaling factors)` pairs, `d` is the output
// and `c` is an optional FP32 tensor that is forwarded to the kernel together with `d`.
// The function validates layouts/shapes/dtypes, converts the scaling factors through
// `layout::transform_sf_into_required_layout`, and dispatches on the device architecture
// and on the dtype of the transformed scaling factors.
static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
                        const std::pair<torch::Tensor, torch::Tensor>& b,
                        const torch::Tensor& d,
                        const std::optional<torch::Tensor>& c,
                        std::optional<std::tuple<int, int, int>> recipe,
                        const std::string& compiled_dims,
                        const bool& disable_ue8m0_cast) {
    // Operands are laid out as `[M, K] @ [N, K].T`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // The output (and `c`, checked below) has to be N-major
    check_major_type_cd(d);

    // Shapes must agree across A, B and D; only M is allowed to be zero
    const auto [m , k ] = get_shape<2>(a.first);
    const auto [n , k_] = get_shape<2>(b.first);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Supplying `c` requires both `c` and `d` to be FP32
    if (c.has_value()) {
        check_major_type_cd(c.value());
        DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
    }

    // Nothing to compute for an empty problem
    if (m == 0)
        return;

    // Fall back to the default recipe (derived from the SF dtypes) when none was given,
    // then bring SFA and SFB into the layout required by the kernels
    if (not recipe.has_value())
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    const auto sfa = layout::transform_sf_into_required_layout(
        a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
    const auto sfb = layout::transform_sf_into_required_layout(
        b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);

    // Pick the implementation from the architecture and the transformed SF dtype
    const auto arch_major = device_runtime->get_arch_major();
    const auto sf_dtype = sfa.scalar_type();
    if (arch_major == 9 and sf_dtype == torch::kFloat) {
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sf_dtype == torch::kInt) {
        sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sf_dtype == torch::kFloat) {
        sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
|
||||
|
||||
// FP8 GEMM in NN form: transposes dims 0 and 1 of both tensors of `b` and forwards
// everything else unchanged to `fp8_gemm_nt`.
static void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
                        const std::pair<torch::Tensor, torch::Tensor>& b,
                        const torch::Tensor& d,
                        const std::optional<torch::Tensor>& c,
                        const std::optional<std::tuple<int, int, int>>& recipe,
                        const std::string& compiled_dims,
                        const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> b_nt{b.first.transpose(0, 1), b.second.transpose(0, 1)};
    fp8_gemm_nt(a, b_nt, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// FP8 GEMM in TN form: transposes dims 0 and 1 of both tensors of `a` and of `b`,
// then forwards everything else unchanged to `fp8_gemm_nt`.
static void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
                        const std::pair<torch::Tensor, torch::Tensor>& b,
                        const torch::Tensor& d,
                        const std::optional<torch::Tensor>& c,
                        const std::optional<std::tuple<int, int, int>>& recipe,
                        const std::string& compiled_dims,
                        const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> a_nt{a.first.transpose(0, 1), a.second.transpose(0, 1)};
    const std::pair<torch::Tensor, torch::Tensor> b_nt{b.first.transpose(0, 1), b.second.transpose(0, 1)};
    fp8_gemm_nt(a_nt, b_nt, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// FP8 GEMM in TT form: transposes dims 0 and 1 of both tensors of `a`, leaves `b` as is,
// and forwards everything else unchanged to `fp8_gemm_nt`.
static void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
                        const std::pair<torch::Tensor, torch::Tensor>& b,
                        const torch::Tensor& d,
                        const std::optional<torch::Tensor>& c,
                        const std::optional<std::tuple<int, int, int>>& recipe,
                        const std::string& compiled_dims,
                        const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> a_nt{a.first.transpose(0, 1), a.second.transpose(0, 1)};
    fp8_gemm_nt(a_nt, b, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// M-grouped FP8 GEMM (contiguous layout), NT form. `a` and `b` are `(data, scaling factors)`
// pairs, `d` is the BF16 output and `m_indices` is an int32 tensor with one entry per row of A.
// After validation the scaling factors are converted through
// `layout::transform_sf_into_required_layout` and the call is dispatched on the device
// architecture and the dtype of the transformed scaling factors.
static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const torch::Tensor& m_indices,
                                             std::optional<std::tuple<int, int, int>> recipe,
                                             const std::string& compiled_dims,
                                             const bool& disable_ue8m0_cast) {
    // Operands are laid out as `[M, K] @ [G, N, K].mT`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    // A is always required to be K-major; B only when `fp8_requires_k_major()` says so
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }
    DG_HOST_ASSERT(m_indices.is_contiguous());

    // Shapes must agree across A, B, D and `m_indices`; only M is allowed to be zero
    const auto [m, k] = get_shape<2>(a.first);
    const auto [num_groups, n, k_] = get_shape<3>(b.first);
    const auto [m_, n_] = get_shape<2>(d);
    const auto m__ = static_cast<int>(m_indices.numel());
    DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);

    // The output has to be N-major
    check_major_type_cd(d);

    // Nothing to compute for an empty problem
    if (m == 0)
        return;

    // Fall back to the default recipe when none was given, then transform SFA and SFB;
    // only SFB is transformed with a group count
    if (not recipe.has_value())
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    const auto sfa = layout::transform_sf_into_required_layout(
        a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
    const auto sfb = layout::transform_sf_into_required_layout(
        b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);

    // Pick the implementation from the architecture and the transformed SF dtype
    const auto arch_major = device_runtime->get_arch_major();
    const auto sf_dtype = sfa.scalar_type();
    if (arch_major == 9 and sf_dtype == torch::kFloat) {
        sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
                                                num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sf_dtype == torch::kInt) {
        sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
                                                 num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sf_dtype == torch::kFloat) {
        sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
                                                 num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
|
||||
|
||||
// M-grouped FP8 GEMM (contiguous layout), NN form: transposes dims 1 and 2 of both tensors
// of `b` and forwards everything else unchanged to `m_grouped_fp8_gemm_nt_contiguous`.
static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const torch::Tensor& m_indices,
                                             const std::optional<std::tuple<int, int, int>>& recipe,
                                             const std::string& compiled_dims,
                                             const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> b_nt{b.first.transpose(1, 2), b.second.transpose(1, 2)};
    m_grouped_fp8_gemm_nt_contiguous(a, b_nt, d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// M-grouped FP8 GEMM (masked layout), NT form. `a` and `b` are `(data, scaling factors)`
// pairs, `d` is the BF16 output, `masked_m` is an int32 tensor with one entry per group and
// `expected_m` is a positive value forwarded to the kernel. After validation the scaling
// factors are converted through `layout::transform_sf_into_required_layout` and the call is
// dispatched on the device architecture and the dtype of the transformed scaling factors.
static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                         const std::pair<torch::Tensor, torch::Tensor>& b,
                                         const torch::Tensor& d,
                                         const torch::Tensor& masked_m,
                                         const int& expected_m,
                                         std::optional<std::tuple<int, int, int>> recipe,
                                         const std::string& compiled_dims,
                                         const bool& disable_ue8m0_cast) {
    // Operands are laid out as `[G, M, K] @ [G, N, K].mT`, both K-major
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());

    // Group counts and shapes must agree across A, B, D and `masked_m`
    const auto [num_groups, m, k] = get_shape<3>(a.first);
    const auto [num_groups_, n, k_] = get_shape<3>(b.first);
    const auto [num_groups__, m_, n_] = get_shape<3>(d);
    const auto num_groups___ = static_cast<int>(masked_m.numel());
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

    // The output has to be N-major
    check_major_type_cd(d);

    // Fall back to the default recipe when none was given, then transform SFA and SFB;
    // both are transformed with the group count
    if (not recipe.has_value())
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    const auto sfa = layout::transform_sf_into_required_layout(
        a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
    const auto sfb = layout::transform_sf_into_required_layout(
        b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);

    // Pick the implementation from the architecture and the transformed SF dtype
    const auto arch_major = device_runtime->get_arch_major();
    const auto sf_dtype = sfa.scalar_type();
    if (arch_major == 9 and sf_dtype == torch::kFloat) {
        sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
                                            num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sf_dtype == torch::kInt) {
        sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
                                             num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sf_dtype == torch::kFloat) {
        sm100_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
                                             num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
|
||||
|
||||
// K-grouped FP8 GEMM (contiguous layout), TN form. `a` and `b` are `(data, scaling factors)`
// pairs, `ks` holds the per-group K sizes on the host and `ks_tensor` is passed alongside it
// to the SF transform and to the kernel. Only the `(1, 1, 128)` recipe is accepted, and only
// architecture major 10 has an implementation.
static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const std::vector<int>& ks,
                                             const torch::Tensor& ks_tensor,
                                             const std::optional<torch::Tensor>& c,
                                             const std::tuple<int, int, int>& recipe,
                                             const std::string& compiled_dims) {
    // Only the 1D1D kernel exists for this case
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

    // All data tensors must be contiguous; `c`, when given, must additionally be FP32
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().is_contiguous());
    }

    // Nothing to compute when the group sizes sum to zero
    const int total_k = std::accumulate(ks.begin(), ks.end(), 0);
    if (total_k == 0)
        return;

    // M and N are the second dimension of A and B; the first dimension is not used here
    const auto [a_dim0, m] = get_shape<2>(a.first);
    const auto [b_dim0, n] = get_shape<2>(b.first);

    // Transform the scaling factors into the K-grouped layout
    const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Dispatch implementation: both operands are passed as MN-major
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 10) {
        fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
                                cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
|
||||
|
||||
// BF16 GEMM in NT form. `a` and `b` are BF16 inputs, `d` is the BF16 or FP32 output and
// `c` is an optional FP32 tensor forwarded to the kernel together with `d`.
// Validates layouts, shapes and dtypes, then dispatches on the device architecture.
static void bf16_gemm_nt(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    // Operands are laid out as `[M, K] @ [N, K].T`; their majors are passed to the kernel
    const auto major_a = get_major_type_ab(a);
    const auto major_b = get_major_type_ab(b);

    // The output (and `c`, checked below) has to be N-major
    check_major_type_cd(d);

    // Shapes must agree across A, B and D; only M is allowed to be zero
    const auto [m , k ] = get_shape<2>(a);
    const auto [n , k_] = get_shape<2>(b);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Supplying `c` requires both `c` and `d` to be FP32
    if (c.has_value()) {
        check_major_type_cd(c.value());
        DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
    }

    // Nothing to compute for an empty problem
    if (m == 0)
        return;

    // Pick the implementation from the architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10) {
        sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
|
||||
|
||||
// BF16 GEMM in NN form: transposes dims 0 and 1 of `b` and forwards to `bf16_gemm_nt`.
static void bf16_gemm_nn(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    const torch::Tensor b_nt = b.transpose(0, 1);
    bf16_gemm_nt(a, b_nt, d, c, compiled_dims);
}
|
||||
|
||||
// BF16 GEMM in TN form: transposes dims 0 and 1 of both `a` and `b`, then forwards to
// `bf16_gemm_nt`.
static void bf16_gemm_tn(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    const torch::Tensor a_nt = a.transpose(0, 1);
    const torch::Tensor b_nt = b.transpose(0, 1);
    bf16_gemm_nt(a_nt, b_nt, d, c, compiled_dims);
}
|
||||
|
||||
// BF16 GEMM in TT form: transposes dims 0 and 1 of `a`, leaves `b` as is, and forwards to
// `bf16_gemm_nt`.
static void bf16_gemm_tt(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    const torch::Tensor a_nt = a.transpose(0, 1);
    bf16_gemm_nt(a_nt, b, d, c, compiled_dims);
}
|
||||
|
||||
// M-grouped BF16 GEMM (contiguous layout), NT form. `a`, `b` and `d` are BF16 tensors and
// `m_indices` is an int32 tensor with one entry per row of A. Only architecture major 9
// has an implementation.
static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b,
                                              const torch::Tensor& d, const torch::Tensor& m_indices,
                                              const std::string& compiled_dims) {
    // Operands are laid out as `[M, K] @ [G, N, K].mT`, both K-major
    const auto major_a = get_major_type_ab(a);
    const auto major_b = get_major_type_ab(b);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(m_indices.is_contiguous());

    // Shapes must agree across A, B, D and `m_indices`; only M is allowed to be zero
    const auto [m, k] = get_shape<2>(a);
    const auto [num_groups, n, k_] = get_shape<3>(b);
    const auto [m_, n_] = get_shape<2>(d);
    const auto m__ = static_cast<int>(m_indices.numel());
    DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);

    // The output has to be N-major
    check_major_type_cd(d);

    // Nothing to compute for an empty problem
    if (m == 0)
        return;

    // Pick the implementation from the architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices,
                                            num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
|
||||
|
||||
// M-grouped BF16 GEMM (masked layout), NT form. `a`, `b` and `d` are BF16 tensors,
// `masked_m` is an int32 tensor with one entry per group and `expected_m` is a positive
// value forwarded to the kernel. Only architecture major 9 has an implementation.
static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b,
                                          const torch::Tensor& d, const torch::Tensor& masked_m,
                                          const int& expected_m, const std::string& compiled_dims) {
    // Operands are laid out as `[G, M, K] @ [G, N, K].mT`, both K-major
    const auto major_a = get_major_type_ab(a);
    const auto major_b = get_major_type_ab(b);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());

    // Group counts and shapes must agree across A, B, D and `masked_m`
    const auto [num_groups, m, k] = get_shape<3>(a);
    const auto [num_groups_, n, k_] = get_shape<3>(b);
    const auto [num_groups__, m_, n_] = get_shape<3>(d);
    const auto num_groups___ = static_cast<int>(masked_m.numel());
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

    // The output has to be N-major
    check_major_type_cd(d);

    // Pick the implementation from the architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m,
                                        num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
|
||||
|
||||
// Exposes every GEMM entry point of this namespace on the given pybind11 module.
// Each binding uses keyword arguments; the defaults below are the only ones a Python
// caller gets. The default `compiled_dims` is "nk" for the NT/NN and M-grouped variants
// and "mn" for the TN/TT and K-grouped variants.
static void register_apis(pybind11::module_& m) {
    // ---- FP8: plain GEMMs ----
    m.def("fp8_gemm_nt", &fp8_gemm_nt,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);

    m.def("fp8_gemm_nn", &fp8_gemm_nn,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);

    m.def("fp8_gemm_tn", &fp8_gemm_tn,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "mn",
          py::arg("disable_ue8m0_cast") = false);

    m.def("fp8_gemm_tt", &fp8_gemm_tt,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "mn",
          py::arg("disable_ue8m0_cast") = false);

    // ---- FP8: grouped GEMMs ----
    m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("m_indices"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);

    m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("m_indices"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);

    m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("masked_m"),
          py::arg("expected_m"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);

    // The K-grouped variant has a concrete default recipe and no `disable_ue8m0_cast` flag
    m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("ks"),
          py::arg("ks_tensor"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::make_tuple(1, 1, 128),
          py::arg("compiled_dims") = "mn");

    // ---- BF16: plain GEMMs ----
    m.def("bf16_gemm_nt", &bf16_gemm_nt,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("compiled_dims") = "nk");

    m.def("bf16_gemm_nn", &bf16_gemm_nn,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("compiled_dims") = "nk");

    m.def("bf16_gemm_tn", &bf16_gemm_tn,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("compiled_dims") = "mn");

    m.def("bf16_gemm_tt", &bf16_gemm_tt,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("compiled_dims") = "mn");

    // ---- BF16: grouped GEMMs ----
    m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("m_indices"),
          py::arg("compiled_dims") = "nk");

    m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("masked_m"),
          py::arg("expected_m"),
          py::arg("compiled_dims") = "nk");
}
|
||||
|
||||
} // namespace deep_gemm::gemm
|
||||
85
csrc/apis/layout.hpp
Normal file
85
csrc/apis/layout.hpp
Normal file
@@ -0,0 +1,85 @@
|
||||
#pragma once
|
||||
|
||||
#include "../utils/layout.hpp"
|
||||
#include "../jit_kernels/impls/smxx_layout.hpp"
|
||||
|
||||
namespace deep_gemm::layout {
|
||||
|
||||
// Converts a scaling-factor tensor into the layout expected by the GEMM kernels.
// The MN granularity is `recipe[0]` for SFA (`is_sfa == true`) and `recipe[1]` otherwise;
// the K granularity is `recipe[2]`. The conversion depends on the SF dtype, the
// granularities, the device architecture and `disable_ue8m0_cast`; any combination not
// listed below ends in `DG_HOST_UNREACHABLE`.
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
                                                       const int& mn, const int& k,
                                                       const std::tuple<int, int, int>& recipe,
                                                       const std::optional<int>& num_groups,
                                                       const bool& is_sfa,
                                                       const bool& disable_ue8m0_cast) {
    const int gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
    const int gran_k = std::get<2>(recipe);
    const auto arch_major = device_runtime->get_arch_major();

    // Validate the incoming tensor before touching it
    check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);

    const auto sf_dtype = sf.scalar_type();
    // FP32 factors stay FP32 on SM90, or on any architecture when the UE8M0 cast is disabled
    const bool keep_fp32 = (arch_major == 9 or disable_ue8m0_cast);

    if (sf_dtype == torch::kFloat and gran_k == 128) {
        if (gran_mn == 1) {
            // (FP32, 1, 128), kept as FP32: transform to TMA-aligned and MN-major
            if (keep_fp32)
                return get_mn_major_tma_aligned_tensor(sf);

            // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
            if (arch_major == 10) {
                DG_HOST_ASSERT(not disable_ue8m0_cast);
                return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
            }
        } else if (gran_mn == 128) {
            // (FP32, 128, 128), kept as FP32: no transformation, only a layout check
            if (keep_fp32)
                return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);

            // (FP32, 128, 128) on SM100: repeat each row of factors for the 128 MN indices it
            // covers (index `i` selects row `i / 128` along dim -2), then pack as in the
            // (1, 128) case
            if (arch_major == 10) {
                DG_HOST_ASSERT(not disable_ue8m0_cast);
                const auto row_indices = torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128);
                const auto broadcasted = sf.index_select(-2, row_indices);
                return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
            }
        }
    }

    // (INT, 1, 128) on SM100: returned as produced by the layout check
    if (sf_dtype == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10)
        return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);

    DG_HOST_UNREACHABLE("Unknown SF transformation");
}
|
||||
|
||||
// Converts a 2D scaling-factor tensor of a K-grouped GEMM into the layout expected by
// the kernel. `ks` / `ks_tensor` carry the per-group K sizes and are forwarded to the
// packing routine. Only the `(1, 1, 128)` recipe is accepted.
//
// Supported combination:
//   - FP32 factors on SM100: packed through
//     `get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`.
// Recognized but unimplemented (reported as "Unimplemented"):
//   - FP32 factors on SM90;
//   - INT factors on SM100.
// Everything else is reported as "Unknown cases".
static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
                                                                 const std::vector<int>& ks,
                                                                 const torch::Tensor& ks_tensor,
                                                                 const std::tuple<int, int, int>& recipe) {
    DG_HOST_ASSERT(sf.dim() == 2);
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
    const auto& arch_major = device_runtime->get_arch_major();

    // FP32 on SM90
    if (sf.scalar_type() == torch::kFloat and arch_major == 9)
        DG_HOST_UNREACHABLE("Unimplemented");

    // FP32 on SM100
    if (sf.scalar_type() == torch::kFloat and arch_major == 10)
        return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);

    // INT on SM100
    // This branch previously re-tested `torch::kFloat`, which the branch above already
    // handles, so it could never fire; it must test the INT dtype its comment names.
    if (sf.scalar_type() == torch::kInt and arch_major == 10)
        DG_HOST_UNREACHABLE("Unimplemented");

    DG_HOST_UNREACHABLE("Unknown cases");
}
|
||||
|
||||
// Exposes the layout helpers on the given pybind11 module. Only
// `transform_sf_into_required_layout` is bound with named arguments and defaults;
// the remaining helpers are bound positionally.
static void register_apis(pybind11::module_& m) {
    m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
          py::arg("sf"),
          py::arg("mn"),
          py::arg("k"),
          py::arg("recipe"),
          py::arg("num_groups") = std::nullopt,
          py::arg("is_sfa") = false,
          py::arg("disable_ue8m0_cast") = false);

    // Alignment queries
    m.def("get_tma_aligned_size", &get_tma_aligned_size);
    m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);

    // Individual scaling-factor layout transforms
    m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
    m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
    m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
}
|
||||
|
||||
} // namespace deep_gemm::layout
|
||||
28
csrc/apis/runtime.hpp
Normal file
28
csrc/apis/runtime.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit/compiler.hpp"
|
||||
#include "../jit/device_runtime.hpp"
|
||||
|
||||
namespace deep_gemm::runtime {
|
||||
|
||||
// Exposes runtime controls on the given pybind11 module: getters/setters for the SM count
// and the `tc_util` setting held by `device_runtime`, plus `init`, which passes the library
// root and CUDA home paths to `Compiler` and `KernelRuntime`.
//
// The callbacks are stored by pybind11 and outlive this function, so they must not capture
// anything by reference. They only use non-local state, hence the empty capture lists.
static void register_apis(pybind11::module_& m) {
    m.def("set_num_sms", [](const int& new_num_sms) {
        device_runtime->set_num_sms(new_num_sms);
    });
    m.def("get_num_sms", []() {
        return device_runtime->get_num_sms();
    });
    m.def("set_tc_util", [](const int& new_tc_util) {
        device_runtime->set_tc_util(new_tc_util);
    });
    m.def("get_tc_util", []() {
        return device_runtime->get_tc_util();
    });

    m.def("init", [](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
        Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
        KernelRuntime::prepare_init(cuda_home_path_by_python);
    });
}
|
||||
|
||||
} // namespace deep_gemm::runtime
|
||||
@@ -35,10 +35,10 @@ public:
|
||||
}
|
||||
|
||||
static void prepare_init(const std::string& library_root_path,
|
||||
const std::string& cuda_home_path_by_torch) {
|
||||
const std::string& cuda_home_path_by_python) {
|
||||
Compiler::library_root_path = library_root_path;
|
||||
Compiler::library_include_path = Compiler::library_root_path / "include";
|
||||
Compiler::cuda_home = cuda_home_path_by_torch;
|
||||
Compiler::cuda_home = cuda_home_path_by_python;
|
||||
Compiler::library_version = get_library_version();
|
||||
}
|
||||
|
||||
@@ -64,6 +64,8 @@ public:
|
||||
get_env<int>("DG_JIT_CPP_STANDARD", 20));
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
|
||||
flags += " --ptxas-options=--verbose";
|
||||
if (get_env("DG_JIT_WITH_LINEINFO", 0))
|
||||
flags += " -Xcompiler -rdynamic -lineinfo";
|
||||
}
|
||||
|
||||
virtual ~Compiler() = default;
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
#if CUDART_VERSION >= 12080 or defined(DG_JIT_USE_DRIVER_API)
|
||||
#if CUDART_VERSION >= 12080 and not defined(DG_JIT_USE_DRIVER_API)
|
||||
|
||||
// Use CUDA runtime API
|
||||
using LibraryHandle = cudaLibrary_t;
|
||||
|
||||
@@ -64,8 +64,8 @@ public:
|
||||
kernel = load_kernel(cubin_path, symbol_names[0], &library);
|
||||
}
|
||||
|
||||
static void prepare_init(const std::string& cuda_home_path_by_torch) {
|
||||
cuda_home = cuda_home_path_by_torch;
|
||||
static void prepare_init(const std::string& cuda_home_path_by_python) {
|
||||
cuda_home = cuda_home_path_by_python;
|
||||
}
|
||||
|
||||
static bool check_validity(const std::filesystem::path& dir_path) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/layout.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -146,7 +147,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const bool& with_accumulation, const int& num_sms) {
|
||||
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
|
||||
|
||||
// Select M/N block sizes
|
||||
@@ -179,7 +180,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
for (const auto& block_n: block_ns) {
|
||||
const int& num_waves = get_num_waves(block_m, block_n);
|
||||
const auto& last_util = get_last_wave_util(block_m, block_n);
|
||||
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n))
|
||||
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k))
|
||||
continue;
|
||||
|
||||
bool success = false;
|
||||
|
||||
@@ -43,7 +43,7 @@ struct SM100ArchSpec {
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n) {
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// TODO: consider more carefully for BF16 GEMMs
|
||||
// 2SM BF16 UMMA does not support `N % 32 != 0`
|
||||
if (ab_dtype == torch::kBFloat16 and block_n % 32 != 0)
|
||||
@@ -106,7 +106,7 @@ struct SM100ArchSpec {
|
||||
const int& swizzle_cd_mode,
|
||||
const at::ScalarType& cd_dtype) {
|
||||
constexpr static int layout_ad_m = 128;
|
||||
return (kernel_type == KernelType::Kernel1D1D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2;
|
||||
return (kernel_type != KernelType::Kernel1D2D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2;
|
||||
}
|
||||
|
||||
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
|
||||
|
||||
@@ -30,15 +30,19 @@ struct SM90ArchSpec {
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n) {
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// FP32 output does not support `block_m == 256`
|
||||
if (cd_dtype == at::kFloat and block_m == 256)
|
||||
return false;
|
||||
|
||||
// TODO: more general block N selection
|
||||
// Must be some fixed block N selections
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 or block_n != 152))
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152))
|
||||
return false;
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 or block_n != 160))
|
||||
|
||||
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
|
||||
// Or too many register spills
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
|
||||
return false;
|
||||
|
||||
// Avoid bank conflicts for FP32 output
|
||||
|
||||
143
csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Normal file
143
csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Normal file
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the SM100 BF16 GEMM kernel: `generate_impl` renders the source that
// instantiates `sm100_bf16_gemm_impl<...>`, `launch_impl` forwards the runtime arguments.
class SM100BF16GemmRuntime final: public LaunchRuntime<SM100BF16GemmRuntime> {
public:
    // Arguments for generating and launching a single kernel.
    // NOTES: keep the member order stable, aggregate initialization depends on it.
    struct Args {
        // Problem shape and number of groups
        int m, n, k, num_groups;
        // Stored as a reference, so the referenced string must outlive this struct
        const std::string& compiled_dims;

        // Selected configuration and launch parameters
        GemmConfig gemm_config;
        LaunchArgs launch_args;

        // Kernel runtime arguments
        void* grouped_layout;
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_c;
        CUtensorMap tensor_map_d;
    };

    // Render the translation unit that instantiates the kernel template;
    // each `{}` below is filled, in order, by the arguments following the template string.
    static std::string generate_impl(const Args& args) {
        const auto& cfg = args.gemm_config;
        const auto& smem = cfg.smem_config;
        const auto& threads = cfg.thread_config;
        const auto& multicast = cfg.multicast_config;
        const auto& dims = args.compiled_dims;

        return fmt::format(R"(
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm100_bf16_gemm_impl<
        {}, {},
        {}, {}, {},
        {}, {}, {},
        {},
        {}, {}, {},
        {}, {},
        {}, {},
        {}, {},
        {},
        {}, {}, {},
        {}
    >);
}};
)",
        // Majors of A and B
        to_string(cfg.major_a), to_string(cfg.major_b),
        // M, N, K as resolved by `get_compiled_dim`
        get_compiled_dim(args.m, 'm', dims), get_compiled_dim(args.n, 'n', dims), get_compiled_dim(args.k, 'k', dims),
        // Block sizes
        cfg.block_m, cfg.block_n, cfg.block_k,
        // Number of groups
        args.num_groups,
        // Swizzle modes of A, B and C/D
        smem.swizzle_a_mode, smem.swizzle_b_mode, smem.swizzle_cd_mode,
        // Stages
        cfg.num_stages, cfg.num_last_stages,
        // Threads
        threads.num_non_epilogue_threads, threads.num_epilogue_threads,
        // Multicast
        multicast.num_multicast, multicast.is_multicast_on_a,
        // Number of SMs
        cfg.num_sms,
        // GEMM type, accumulation flag and C/D dtype
        to_string(cfg.gemm_type), cfg.with_accumulation, to_string(cfg.cd_dtype),
        // Value of `tc_util` from the config
        cfg.tc_util);
    }

    // Launch the built kernel with the runtime arguments stored in `args`.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
            args.grouped_layout, args.m, args.n, args.k,
            args.tensor_map_a, args.tensor_map_b,
            args.tensor_map_c, args.tensor_map_d));
    }
};
|
||||
|
||||
// Host entry of the normal (non-grouped) SM100 BF16 GEMM:
// pick a config, build the TMA descriptors of A/B/C/D, then JIT-build and launch the kernel.
// If `c` is given and does not alias `d`, `c` is copied into `d` before launching.
static void sm100_bf16_gemm(const torch::Tensor& a,
                            const torch::Tensor& b,
                            const std::optional<torch::Tensor>& c,
                            const torch::Tensor& d,
                            const int& m, const int& n, const int& k,
                            const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                            const std::string& compiled_dims) {
    // TODO: test other Ks
    DG_HOST_ASSERT(k % 64 == 0);

    // Heuristic config for a single-group BF16 problem without scaling factors
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::Normal, KernelType::KernelNoSF,
        m, n, k, 1, major_a, major_b,
        torch::kBFloat16, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());
    const auto& smem = config.smem_config;
    const auto& multicast = config.multicast_config;

    // The C descriptor falls back to `d` when no accumulator is passed
    const torch::Tensor cd = c.has_value() ? c.value() : d;

    // Strides of the non-contiguous dimension of A and B
    const int stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const int stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));

    // C and D share the same store box
    const auto store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m);
    const auto store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n);

    // TMA descriptors
    const auto& tensor_map_a = make_tma_a_desc(
        major_a, a, m, k,
        SM100ArchSpec::get_ab_load_block_m(multicast, config.block_m), config.block_k,
        stride_a, 1, smem.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(
        major_b, b, n, k,
        SM100ArchSpec::get_ab_load_block_n(multicast, config.block_n), config.block_k,
        stride_b, 1, smem.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(
        d, m, n, store_block_m, store_block_n,
        static_cast<int>(d.stride(-2)), 1, smem.swizzle_cd_mode);
    const auto& tensor_map_c = make_tma_cd_desc(
        cd, m, n, store_block_m, store_block_n,
        static_cast<int>(cd.stride(-2)), 1, smem.swizzle_cd_mode);

    // Duplicate the accumulator if necessary
    if (c.has_value()) {
        const bool is_aliased = c->data_ptr() == d.data_ptr();
        if (not is_aliased) {
            // ReSharper disable once CppExpressionWithoutSideEffects
            d.copy_(c.value());
        } else {
            // Same storage: the two views must then match exactly
            DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
        }
    }

    // Launch
    const LaunchArgs launch_args(config.num_sms, config.thread_config.num_threads,
                                 smem.smem_size, multicast.num_multicast);
    const SM100BF16GemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_c = tensor_map_c,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100BF16GemmRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_bf16_gemm", code);
    SM100BF16GemmRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -155,7 +156,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -202,7 +203,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -136,7 +137,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -179,7 +180,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
229
csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Normal file
229
csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Normal file
@@ -0,0 +1,229 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the SM90 BF16 GEMM kernel: `generate_impl` renders the source that
// instantiates `sm90_bf16_gemm_impl<...>`, `launch_impl` forwards the runtime arguments.
class SM90BF16GemmRuntime final: public LaunchRuntime<SM90BF16GemmRuntime> {
public:
    // Arguments for generating and launching a single kernel.
    // NOTES: keep the member order stable, aggregate initialization depends on it.
    struct Args {
        // Problem shape and number of groups
        int m, n, k, num_groups;
        // Stored as a reference, so the referenced string must outlive this struct
        const std::string& compiled_dims;

        // Selected configuration and launch parameters
        GemmConfig gemm_config;
        LaunchArgs launch_args;

        // Kernel runtime arguments (no C descriptor for this kernel)
        void *grouped_layout;
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_d;
    };

    // Render the translation unit that instantiates the kernel template;
    // each `{}` below is filled, in order, by the arguments following the template string.
    static std::string generate_impl(const Args& args) {
        const auto& cfg = args.gemm_config;
        const auto& threads = cfg.thread_config;
        const auto& multicast = cfg.multicast_config;
        const auto& dims = args.compiled_dims;

        return fmt::format(R"(
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm90_bf16_gemm_impl<
        {}, {}, {},
        {},
        {}, {}, {},
        {},
        {}, {},
        {}, {},
        {}, {},
        {}, {}
    >);
}};
)",
        // TODO: add CD dtype
        // M, N, K as resolved by `get_compiled_dim`
        get_compiled_dim(args.m, 'm', dims), get_compiled_dim(args.n, 'n', dims), get_compiled_dim(args.k, 'k', dims),
        // Number of groups
        args.num_groups,
        // Block sizes
        cfg.block_m, cfg.block_n, cfg.block_k,
        // Swizzle mode of C/D
        cfg.smem_config.swizzle_cd_mode,
        // Stages
        cfg.num_stages, cfg.num_last_stages,
        // TMA and math threads
        threads.num_tma_threads, threads.num_math_threads,
        // Multicast
        multicast.num_multicast, multicast.is_multicast_on_a,
        // Number of SMs and GEMM type
        cfg.num_sms, to_string(cfg.gemm_type));
    }

    // Launch the built kernel with the runtime arguments stored in `args`.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
            args.grouped_layout,
            args.m, args.n, args.k,
            args.tensor_map_a, args.tensor_map_b,
            args.tensor_map_d));
    }
};
|
||||
|
||||
// Host entry of the normal (non-grouped) SM90 BF16 GEMM.
// Accumulation is not supported here: `c` must be empty, `d` must be BF16,
// and both A and B must be K-major.
static void sm90_bf16_gemm(const torch::Tensor& a,
                           const torch::Tensor& b,
                           const std::optional<torch::Tensor>& c,
                           const torch::Tensor& d,
                           const int& m, const int& n, const int& k,
                           const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                           const std::string& compiled_dims) {
    DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(k % 64 == 0);

    // Heuristic config for a single-group BF16 problem without scaling factors
    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::Normal, KernelType::KernelNoSF,
        m, n, k, 1, major_a, major_b,
        torch::kBFloat16, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());
    const auto& smem = config.smem_config;
    const auto& multicast = config.multicast_config;

    // Strides of the non-contiguous dimension of A and B
    const int stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const int stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));

    // Requires no TMA splits
    const auto& tensor_map_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(multicast, config.block_m), config.block_k,
        stride_a, 1, smem.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(multicast, config.block_n), config.block_k,
        stride_b, 1, smem.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(config.block_m),
        SM90ArchSpec::get_cd_store_block_n(config.block_n),
        static_cast<int>(d.stride(-2)), 1, smem.swizzle_cd_mode);

    // Launch
    const LaunchArgs launch_args(config.num_sms, config.thread_config.num_threads,
                                 smem.smem_size, multicast.num_multicast);
    const SM90BF16GemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
    };
    const auto& code = SM90BF16GemmRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_bf16_gemm", code);
    SM90BF16GemmRuntime::launch(runtime, args);
}
|
||||
|
||||
// Host entry of the M-grouped contiguous SM90 BF16 GEMM.
// `d` must be BF16 and both A and B must be K-major; `m_indices` is handed to the
// kernel as its grouped layout. Only the descriptor of B is built with `num_groups`,
// those of A and D are built with a single group.
static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
                                                const torch::Tensor& b,
                                                const torch::Tensor& d,
                                                const torch::Tensor& m_indices,
                                                const int& num_groups, const int& m, const int& n, const int& k,
                                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                const std::string& compiled_dims) {
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(k % 64 == 0);

    // NOTES: the heuristics are queried with a single group
    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::MGroupedContiguous, KernelType::KernelNoSF,
        m, n, k, 1, major_a, major_b,
        torch::kBFloat16, d.scalar_type(), false,
        device_runtime->get_num_sms());
    const auto& smem = config.smem_config;
    const auto& multicast = config.multicast_config;

    // Strides of the non-contiguous dimension of A and B
    const int stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const int stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));

    // Requires no TMA splits
    const auto& tensor_map_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(multicast, config.block_m), config.block_k,
        stride_a, 1, smem.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(multicast, config.block_n), config.block_k,
        stride_b, num_groups, smem.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(config.block_m),
        SM90ArchSpec::get_cd_store_block_n(config.block_n),
        static_cast<int>(d.stride(-2)), 1, smem.swizzle_cd_mode);

    // Launch
    const LaunchArgs launch_args(config.num_sms, config.thread_config.num_threads,
                                 smem.smem_size, multicast.num_multicast);
    const SM90BF16GemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
    };
    const auto& code = SM90BF16GemmRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
    SM90BF16GemmRuntime::launch(runtime, args);
}
|
||||
|
||||
// Host entry of the M-grouped masked SM90 BF16 GEMM.
// `d` must be BF16 and both A and B must be K-major; `masked_m` is handed to the
// kernel as its grouped layout. The heuristics use `expected_m` instead of `m`,
// and the descriptors of A, B and D are all built with `num_groups`.
static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
                                            const torch::Tensor& b,
                                            const torch::Tensor& d,
                                            const torch::Tensor& masked_m,
                                            const int& num_groups, const int& m, const int& n, const int& k,
                                            const int& expected_m,
                                            const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                            const std::string& compiled_dims) {
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(k % 64 == 0);

    const auto& config = get_best_config<SM90ArchSpec>(
        GemmType::MGroupedMasked, KernelType::KernelNoSF,
        expected_m, n, k, num_groups, major_a, major_b,
        torch::kBFloat16, d.scalar_type(), false,
        device_runtime->get_num_sms());
    const auto& smem = config.smem_config;
    const auto& multicast = config.multicast_config;

    // Strides of the non-contiguous dimension of A and B
    const int stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const int stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));

    // Requires no TMA splits
    const auto& tensor_map_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(multicast, config.block_m), config.block_k,
        stride_a, num_groups, smem.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(multicast, config.block_n), config.block_k,
        stride_b, num_groups, smem.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(config.block_m),
        SM90ArchSpec::get_cd_store_block_n(config.block_n),
        static_cast<int>(d.stride(-2)), num_groups, smem.swizzle_cd_mode);

    // Launch
    const LaunchArgs launch_args(config.num_sms, config.thread_config.num_threads,
                                 smem.smem_size, multicast.num_multicast);
    const SM90BF16GemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
    };
    const auto& code = SM90BF16GemmRuntime::generate(args);
    const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
    SM90BF16GemmRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -139,7 +140,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -185,7 +186,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
@@ -10,6 +10,35 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the `transpose_fp32` kernel from `smxx_layout.cuh`.
class TransposeFP32Runtime final: public LaunchRuntime<TransposeFP32Runtime> {
public:
    // Arguments for generating and launching a single kernel.
    // NOTES: keep the member order stable, aggregate initialization depends on it.
    struct Args {
        int mn, sf_k;
        int block_mn;
        // Input and output buffers, passed to the kernel as raw pointers
        void *sf, *out;

        LaunchArgs launch_args;
    };

    // Render the translation unit that instantiates `transpose_fp32`;
    // the template arguments are, in order: thread count, `block_mn` and `sf_k`.
    static std::string generate_impl(const Args& args) {
        const auto& num_threads = args.launch_args.num_threads;

        return fmt::format(R"(
#include <deep_gemm/impls/smxx_layout.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&transpose_fp32<
        {}, {}, {}
    >);
}};
)", num_threads, args.block_mn, args.sf_k);
    }

    // Launch the built kernel; `mn` is passed as an unsigned 32-bit value.
    static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
        DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast<uint32_t>(args.mn)));
    }
};
|
||||
|
||||
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
|
||||
public:
|
||||
struct Args {
|
||||
@@ -88,10 +117,32 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
|
||||
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
|
||||
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
|
||||
|
||||
// Normal layout requires transposing
|
||||
auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options());
|
||||
aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf);
|
||||
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
|
||||
const auto& out = torch::empty_strided({num_groups, mn, sf_k},
|
||||
{tma_aligned_mn * sf_k, 1, tma_aligned_mn},
|
||||
batched_sf.options());
|
||||
|
||||
if (not batched_sf.is_contiguous()) {
|
||||
// Fallback to PyTorch's slow copy if not contiguous
|
||||
// ReSharper disable once CppExpressionWithoutSideEffects
|
||||
out.copy_(batched_sf);
|
||||
} else {
|
||||
constexpr int block_mn = 64;
|
||||
constexpr int num_threads = 512;
|
||||
const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast<int>(sizeof(float));
|
||||
const TransposeFP32Runtime::Args& args = {
|
||||
.mn = mn,
|
||||
.sf_k = sf_k,
|
||||
.block_mn = block_mn,
|
||||
.sf = batched_sf.data_ptr(),
|
||||
.out = out.data_ptr(),
|
||||
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size)
|
||||
};
|
||||
|
||||
const auto& code = TransposeFP32Runtime::generate(args);
|
||||
const auto& runtime = compiler->build("transpose_fp32", code);
|
||||
TransposeFP32Runtime::launch(runtime, args);
|
||||
}
|
||||
return (dim == 2) ? out.squeeze(0) : out;
|
||||
}
|
||||
|
||||
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) {
|
||||
@@ -127,7 +178,6 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
|
||||
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
|
||||
// Launch the kernel
|
||||
if (batched_sf.is_contiguous()) {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if ((mn * sf_k) % 4 != 0 and num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
@@ -146,11 +196,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
|
||||
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
|
||||
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
|
||||
} else {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if (mn % 4 != 0 or num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
|
||||
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
|
||||
|
||||
constexpr int block_mn = 128;
|
||||
|
||||
@@ -1,412 +1,19 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "jit/compiler.hpp"
|
||||
#include "jit/device_runtime.hpp"
|
||||
#include "utils/layout.hpp"
|
||||
|
||||
#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/smxx_layout.hpp"
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME deep_gemm_cpp
|
||||
#endif
|
||||
|
||||
namespace deep_gemm {
|
||||
// Convert a scaling-factor tensor into the layout required by the compute kernels.
// The transformation is selected by the SF dtype, the recipe granularity (the M/N
// granularity is taken from `recipe[0]` for SFA and `recipe[1]` for SFB, the K
// granularity from `recipe[2]`), the device architecture and `disable_ue8m0_cast`.
// Any unsupported combination ends in `DG_HOST_UNREACHABLE`.
torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
                                                const int& mn, const int& k,
                                                const std::tuple<int, int, int>& recipe,
                                                const std::optional<int>& num_groups,
                                                const bool& is_sfa,
                                                const bool& disable_ue8m0_cast) {
    const int gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
    const int gran_k = std::get<2>(recipe);
    const auto& arch_major = device_runtime->get_arch_major();

    // Pre-transform checks
    check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);

    // FP32 inputs stay FP32 on SM90, or whenever the UE8M0 cast is disabled
    const bool is_fp32 = sf.scalar_type() == torch::kFloat;
    const bool keep_fp32 = arch_major == 9 or disable_ue8m0_cast;

    if (is_fp32 and gran_k == 128) {
        if (gran_mn == 1) {
            // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major
            if (keep_fp32)
                return get_mn_major_tma_aligned_tensor(sf);

            // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
            if (arch_major == 10) {
                DG_HOST_ASSERT(not disable_ue8m0_cast);
                return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
            }
        } else if (gran_mn == 128) {
            // (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous
            if (keep_fp32)
                return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);

            // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
            if (arch_major == 10) {
                DG_HOST_ASSERT(not disable_ue8m0_cast);
                // Repeat each SF row for all rows of its 128-row block (row `i` reads source row `i / 128`)
                const auto& row_indices = torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128);
                const auto& broadcasted = sf.index_select(-2, row_indices);
                return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
            }
        }
    }

    // (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major
    const bool is_int = sf.scalar_type() == torch::kInt;
    if (is_int and gran_mn == 1 and gran_k == 128 and arch_major == 10)
        return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);

    DG_HOST_UNREACHABLE("Unknown SF transformation");
}
|
||||
|
||||
// Convert the scaling factors of a K-grouped GEMM into the layout required by the kernels.
// `sf` must be 2D and the recipe must be exactly `(1, 1, 128)`.
// Only FP32 scaling factors on SM100 are implemented: they are converted by
// `get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`. FP32 on SM90 and INT on SM100
// are reported as "Unimplemented", anything else as "Unknown cases".
torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
                                                          const std::vector<int>& ks,
                                                          const torch::Tensor& ks_tensor,
                                                          const std::tuple<int, int, int>& recipe) {
    DG_HOST_ASSERT(sf.dim() == 2);
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
    const auto& arch_major = device_runtime->get_arch_major();

    // FP32 on SM90
    if (sf.scalar_type() == torch::kFloat and arch_major == 9)
        DG_HOST_UNREACHABLE("Unimplemented");

    // FP32 on SM100
    if (sf.scalar_type() == torch::kFloat and arch_major == 10)
        return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);

    // INT on SM100
    // NOTES: this used to test `torch::kFloat` again, which made the branch unreachable
    // (the FP32 case has already returned above)
    if (sf.scalar_type() == torch::kInt and arch_major == 10)
        DG_HOST_UNREACHABLE("Unimplemented");

    DG_HOST_UNREACHABLE("Unknown cases");
}
|
||||
|
||||
// FP8 GEMM in NT form: `a` and `b` are `(data, scaling factors)` pairs and `d` is the output.
// Validates majors, shapes and dtypes, transforms both scaling-factor tensors into the
// required layout, and dispatches on the device architecture and the transformed SF dtype.
// Returns without launching anything when `m == 0`.
void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 std::optional<std::tuple<int, int, int>> recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    const auto& [a_data, a_sf] = a;
    const auto& [b_data, b_sf] = b;

    // Shape must be `[M, K] @ [N, K].T`
    const auto& major_a = get_major_type_ab(a_data);
    const auto& major_b = get_major_type_ab(b_data);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // C/D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto& [m , k ] = get_shape<2>(a_data);
    const auto& [n , k_] = get_shape<2>(b_data);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check C as well
    if (c.has_value()) {
        check_major_type_cd(c.value());
        DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
    }

    // Do nothing if the problem is empty
    if (m == 0)
        return;

    // Fall back to the default recipe of the two SF dtypes if none is given
    if (not recipe.has_value())
        recipe = get_default_recipe(a_sf.scalar_type(), b_sf.scalar_type());

    // Transform SFA and SFB into compute-required layout
    const auto& sfa = transform_sf_into_required_layout(a_sf, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
    const auto& sfb = transform_sf_into_required_layout(b_sf, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);

    // Dispatch into different implements
    const auto& arch_major = device_runtime->get_arch_major();
    const bool is_sf_fp32 = sfa.scalar_type() == torch::kFloat;
    const bool is_sf_int = sfa.scalar_type() == torch::kInt;
    if (arch_major == 9 and is_sf_fp32) {
        sm90_fp8_gemm_1d2d(a_data, sfa, b_data, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
    } else if (arch_major == 10 and is_sf_int) {
        sm100_fp8_gemm_1d1d(a_data, sfa, b_data, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
    } else if (arch_major == 10 and is_sf_fp32) {
        sm100_fp8_gemm_1d2d(a_data, sfa, b_data, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
}
|
||||
|
||||
// NN variant: transposes dims 0 and 1 of both B's data and B's scaling factors,
// then forwards everything to `fp8_gemm_nt`.
void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> b_transposed{
        b.first.transpose(0, 1), b.second.transpose(0, 1)};
    fp8_gemm_nt(a, b_transposed, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// FP8 GEMM, "TN" variant.
//
// Swaps dims 0 and 1 of the data and scaling-factor tensors of BOTH operands and then
// defers to `fp8_gemm_nt`, which performs all checks and the kernel dispatch.
void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    const auto& [a_data, a_sf] = a;
    const auto& [b_data, b_sf] = b;
    const std::pair<torch::Tensor, torch::Tensor> a_swapped(a_data.transpose(0, 1),
                                                            a_sf.transpose(0, 1));
    const std::pair<torch::Tensor, torch::Tensor> b_swapped(b_data.transpose(0, 1),
                                                            b_sf.transpose(0, 1));
    fp8_gemm_nt(a_swapped, b_swapped, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// FP8 GEMM, "TT" variant.
//
// Only operand `a` needs adjusting: dims 0 and 1 of its data and scaling-factor tensors
// are swapped, `b` is passed through as-is, and `fp8_gemm_nt` does the actual work.
void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    const auto& [a_data, a_sf] = a;
    const std::pair<torch::Tensor, torch::Tensor> a_swapped(a_data.transpose(0, 1),
                                                            a_sf.transpose(0, 1));
    fp8_gemm_nt(a_swapped, b, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// M-grouped FP8 GEMM over a contiguous layout, shaped `[M, K] @ [G, N, K].mT`.
//
// `a` and `b` each pair an FP8 (e4m3fn) data tensor with its scaling-factor tensor.
// `d` is the BF16 `[M, N]` output. `m_indices` must be a contiguous `torch::kInt` tensor
// with exactly M elements (presumably the group of each row; confirm against the kernels).
// An empty `recipe` is filled in from the scaling-factor dtypes via `get_default_recipe`.
// The call returns early when M is zero; otherwise it picks a kernel from the device
// architecture and the dtype of the transformed SFA, and reports unsupported combinations
// through `DG_HOST_UNREACHABLE`.
void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const torch::Tensor& m_indices,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Majorness: A is always required to be K-major, B only if `fp8_requires_k_major()` holds
    const auto& major_a = get_major_type_ab(a.first);
    const auto& major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }
    DG_HOST_ASSERT(m_indices.is_contiguous());

    // Shapes of A, B, D and the index tensor must agree; dtypes are fixed
    const auto& [m, k] = get_shape<2>(a.first);
    const auto& [num_groups, n, k_] = get_shape<3>(b.first);
    const auto& [m_, n_] = get_shape<2>(d);
    const auto& m__ = static_cast<int>(m_indices.numel());
    DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);

    // The output has to be N-major
    check_major_type_cd(d);

    // An empty problem needs no kernel launch
    if (m == 0) {
        return;
    }

    // Bring both scaling-factor tensors into the layout the kernels expect
    if (!recipe.has_value()) {
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    }
    const auto sf_a = transform_sf_into_required_layout(a.second, m, k, recipe.value(),
                                                        std::nullopt, true, disable_ue8m0_cast);
    const auto sf_b = transform_sf_into_required_layout(b.second, n, k, recipe.value(),
                                                        num_groups, false, disable_ue8m0_cast);

    // Select the implementation from (architecture, SFA dtype)
    const auto arch = device_runtime->get_arch_major();
    const auto sf_dtype = sf_a.scalar_type();
    if (arch == 9 and sf_dtype == torch::kFloat) {
        sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sf_a, b.first, sf_b, d, m_indices,
                                                num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch == 10 and sf_dtype == torch::kInt) {
        sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sf_a, b.first, sf_b, d, m_indices,
                                                 num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch == 10 and sf_dtype == torch::kFloat) {
        sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sf_a, b.first, sf_b, d, m_indices,
                                                 num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
|
||||
|
||||
// M-grouped contiguous FP8 GEMM, "NN" variant.
//
// Swaps dims 1 and 2 of the data and scaling-factor tensors of `b` and forwards
// everything to `m_grouped_fp8_gemm_nt_contiguous`, which validates and dispatches.
void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const torch::Tensor& m_indices,
                                      const std::optional<std::tuple<int, int, int>>& recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    const auto& [b_data, b_sf] = b;
    const std::pair<torch::Tensor, torch::Tensor> b_swapped(b_data.transpose(1, 2),
                                                            b_sf.transpose(1, 2));
    m_grouped_fp8_gemm_nt_contiguous(a, b_swapped, d, m_indices,
                                     recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// Masked M-grouped FP8 GEMM, shaped `[G, M, K] @ [G, N, K].mT`.
//
// `a` and `b` each pair an FP8 (e4m3fn) data tensor with its scaling-factor tensor, and
// both data tensors must be K-major. `d` is the BF16 `[G, M, N]` output. `masked_m` must be
// a contiguous `torch::kInt` tensor with one element per group. `expected_m` must be
// positive and is handed through to the kernel unchanged.
// An empty `recipe` is filled in from the scaling-factor dtypes via `get_default_recipe`.
// The kernel is chosen from the device architecture and the dtype of the transformed SFA;
// unsupported combinations end in `DG_HOST_UNREACHABLE`.
void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                  const std::pair<torch::Tensor, torch::Tensor>& b,
                                  const torch::Tensor& d,
                                  const torch::Tensor& masked_m,
                                  const int& expected_m,
                                  std::optional<std::tuple<int, int, int>> recipe,
                                  const std::string& compiled_dims,
                                  const bool& disable_ue8m0_cast) {
    // Both operands have to be K-major here
    const auto& major_a = get_major_type_ab(a.first);
    const auto& major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());

    // Group counts and matrix extents must line up across A, B, D and the mask
    const auto& [num_groups, m, k] = get_shape<3>(a.first);
    const auto& [num_groups_, n, k_] = get_shape<3>(b.first);
    const auto& [num_groups__, m_, n_] = get_shape<3>(d);
    const auto& num_groups___ = static_cast<int>(masked_m.numel());
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

    // The output has to be N-major
    check_major_type_cd(d);

    // Bring both scaling-factor tensors into the layout the kernels expect
    if (!recipe.has_value()) {
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    }
    const auto sf_a = transform_sf_into_required_layout(a.second, m, k, recipe.value(),
                                                        num_groups, true, disable_ue8m0_cast);
    const auto sf_b = transform_sf_into_required_layout(b.second, n, k, recipe.value(),
                                                        num_groups, false, disable_ue8m0_cast);

    // Select the implementation from (architecture, SFA dtype)
    const auto arch = device_runtime->get_arch_major();
    const auto sf_dtype = sf_a.scalar_type();
    if (arch == 9 and sf_dtype == torch::kFloat) {
        sm90_fp8_m_grouped_gemm_masked_1d2d(a.first, sf_a, b.first, sf_b, d, masked_m,
                                            num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    if (arch == 10 and sf_dtype == torch::kInt) {
        sm100_fp8_m_grouped_gemm_masked_1d1d(a.first, sf_a, b.first, sf_b, d, masked_m,
                                             num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    if (arch == 10 and sf_dtype == torch::kFloat) {
        sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sf_a, b.first, sf_b, d, masked_m,
                                             num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
|
||||
|
||||
// K-grouped contiguous FP8 GEMM, "TN" variant.
//
// `a` and `b` each pair a data tensor with its scaling-factor tensor; `ks` lists the
// per-group K sizes on the host and `ks_tensor` is passed alongside it to the layout
// transform and the kernel (presumably the same sizes on device; confirm with callers).
// `c`, when present, must be a contiguous FP32 tensor. Only the `(1, 1, 128)` recipe is
// accepted. The call is a no-op when the entries of `ks` sum to zero, and only
// architecture major 10 is supported; anything else ends in `DG_HOST_UNREACHABLE`.
void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::vector<int>& ks,
                                      const torch::Tensor& ks_tensor,
                                      const std::optional<torch::Tensor>& c,
                                      const std::tuple<int, int, int>& recipe,
                                      const std::string& compiled_dims) {
    // This entry point only serves the 1D1D kernel
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

    // Every data tensor must be contiguous; C additionally has to be FP32
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().is_contiguous());
    }

    // Nothing to compute if the K sizes add up to zero
    const int total_k = std::accumulate(ks.begin(), ks.end(), 0);
    if (total_k == 0) {
        return;
    }

    // M and N are the second dimension of A and B; the first dimension is not used here
    const auto& [a_dim0, m] = get_shape<2>(a.first);
    const auto& [b_dim0, n] = get_shape<2>(b.first);

    // Scaling factors go through the K-grouped layout transform
    const auto sf_a = transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto sf_b = transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Only one architecture has an implementation
    const auto arch = device_runtime->get_arch_major();
    if (arch != 10) {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    } else {
        fp8_k_grouped_gemm_1d1d(a.first, sf_a, b.first, sf_b, c, d, m, n, ks, ks_tensor,
                                cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
    }
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
// ReSharper disable once CppParameterMayBeConstPtrOrRef
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Python module definition: registers runtime setters/getters, JIT initialization,
    // the GEMM entry points, layout helpers, and finally the per-area `register_apis` hooks.
    using namespace deep_gemm;

    m.doc() = "DeepGEMM C++ library";

    // ---- Runtime: accessors forwarded to `device_runtime` ----
    m.def("set_num_sms", [&](const int& new_num_sms) { device_runtime->set_num_sms(new_num_sms); });
    m.def("get_num_sms", [&]() { return device_runtime->get_num_sms(); });
    m.def("set_tc_util", [&](const int& new_tc_util) { device_runtime->set_tc_util(new_tc_util); });
    m.def("get_tc_util", [&]() { return device_runtime->get_tc_util(); });

    // ---- JIT: hand the library root and CUDA home paths to the compiler and kernel runtime ----
    m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) {
        Compiler::prepare_init(library_root_path, cuda_home_path_by_torch);
        KernelRuntime::prepare_init(cuda_home_path_by_torch);
    });

    // ---- GEMM entry points (architecture/layout dispatch happens inside each function) ----
    m.def("fp8_gemm_nt", &fp8_gemm_nt,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_gemm_nn", &fp8_gemm_nn,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_gemm_tn", &fp8_gemm_tn,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "mn",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_gemm_tt", &fp8_gemm_tt,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "mn",
          py::arg("disable_ue8m0_cast") = false);
    m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("m_indices"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("m_indices"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_m_grouped_gemm_nt_masked", &fp8_m_grouped_gemm_nt_masked,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("masked_m"),
          py::arg("expected_m"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("ks"),
          py::arg("ks_tensor"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::make_tuple(1, 1, 128),
          py::arg("compiled_dims") = "mn");

    // ---- Layout: scaling-factor transform ----
    m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
          py::arg("sf"),
          py::arg("mn"),
          py::arg("k"),
          py::arg("recipe"),
          py::arg("num_groups") = std::nullopt,
          py::arg("is_sfa") = false,
          py::arg("disable_ue8m0_cast") = false);

    // ---- Raw helpers, exposed without keyword-argument names ----
    m.def("get_tma_aligned_size", &get_tma_aligned_size);
    m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
    m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
    m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
    m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);

    // ---- Per-area registration hooks ----
    // NOTE(review): if these hooks register the same names as the direct definitions above,
    // the bindings end up defined twice; confirm which of the two sets is meant to stay.
    deep_gemm::gemm::register_apis(m);
    deep_gemm::layout::register_apis(m);
    deep_gemm::runtime::register_apis(m);
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -10,7 +11,7 @@ class DGException final : public std::exception {
|
||||
|
||||
public:
|
||||
explicit DGException(const char *name, const char* file, const int line, const std::string& error) {
|
||||
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
|
||||
message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error;
|
||||
}
|
||||
|
||||
const char *what() const noexcept override {
|
||||
@@ -50,7 +51,11 @@ do { \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != CUDA_SUCCESS) { \
|
||||
throw DGException("CUDA driver", __FILE__, __LINE__, ""); \
|
||||
std::stringstream ss; \
|
||||
const char *name, *info; \
|
||||
cuGetErrorName(e, &name), cuGetErrorString(e, &info); \
|
||||
ss << static_cast<int>(e) << " (" << name << ", " << info << ")"; \
|
||||
throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
@@ -60,7 +65,9 @@ do { \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != cudaSuccess) { \
|
||||
throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
|
||||
std::stringstream ss; \
|
||||
ss << static_cast<int>(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \
|
||||
throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user