Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100 - Refactor Python APIs - Other fixes and code refactoring
This commit is contained in:
@@ -105,7 +105,7 @@ We also provide a K-axis-grouped API for MoE weight backward (with M and N must
|
||||
|
||||
During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions.
|
||||
|
||||
Use `fp8_m_grouped_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input.
|
||||
Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input.
|
||||
|
||||
#### Utilities
|
||||
|
||||
|
||||
471
csrc/apis/gemm.hpp
Normal file
471
csrc/apis/gemm.hpp
Normal file
@@ -0,0 +1,471 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
|
||||
|
||||
#include "layout.hpp"
|
||||
|
||||
namespace deep_gemm::gemm {
|
||||
|
||||
static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[M, K] @ [N, K].T`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
if (fp8_requires_k_major()) {
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
||||
}
|
||||
|
||||
// C/D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m , k ] = get_shape<2>(a.first);
|
||||
const auto& [n , k_] = get_shape<2>(b.first);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
||||
|
||||
// Check C as well
|
||||
if (c.has_value()) {
|
||||
check_major_type_cd(c.value());
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
|
||||
}
|
||||
|
||||
// Do nothing if the problem is empty
|
||||
if (m == 0)
|
||||
return;
|
||||
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
static void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
static void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
|
||||
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
static void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
||||
if (fp8_requires_k_major())
|
||||
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(m_indices.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m, k] = get_shape<2>(a.first);
|
||||
const auto& [num_groups, n, k_] = get_shape<3>(b.first);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
const auto& m__ = static_cast<int>(m_indices.numel());
|
||||
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Do nothing if empty
|
||||
if (m == 0)
|
||||
return;
|
||||
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
|
||||
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& expected_m,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[G, M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(masked_m.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [num_groups, m, k] = get_shape<3>(a.first);
|
||||
const auto& [num_groups_, n, k_] = get_shape<3>(b.first);
|
||||
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
|
||||
const auto& num_groups___ = static_cast<int>(masked_m.numel());
|
||||
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Transform scaling factors
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
|
||||
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::vector<int>& ks,
|
||||
const torch::Tensor& ks_tensor,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::string& compiled_dims) {
|
||||
// Must be 1D1D kernel
|
||||
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
|
||||
|
||||
// Contiguity checks
|
||||
DG_HOST_ASSERT(a.first.is_contiguous());
|
||||
DG_HOST_ASSERT(b.first.is_contiguous());
|
||||
DG_HOST_ASSERT(d.is_contiguous());
|
||||
if (c.has_value()) {
|
||||
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(c.value().is_contiguous());
|
||||
}
|
||||
|
||||
// Do nothing if empty
|
||||
if (std::accumulate(ks.begin(), ks.end(), 0) == 0)
|
||||
return;
|
||||
|
||||
// Transform SF with padding
|
||||
const auto& [_, m] = get_shape<2>(a.first);
|
||||
const auto& [__, n] = get_shape<2>(b.first);
|
||||
const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
|
||||
const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 10) {
|
||||
fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
}
|
||||
|
||||
static void bf16_gemm_nt(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::string& compiled_dims) {
|
||||
// Shape must be `[M, K] @ [N, K].T`
|
||||
const auto& major_a = get_major_type_ab(a);
|
||||
const auto& major_b = get_major_type_ab(b);
|
||||
|
||||
// C/D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m , k ] = get_shape<2>(a);
|
||||
const auto& [n , k_] = get_shape<2>(b);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0);
|
||||
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
||||
|
||||
// Check C as well
|
||||
if (c.has_value()) {
|
||||
check_major_type_cd(c.value());
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
|
||||
}
|
||||
|
||||
// Do nothing if the problem is empty
|
||||
if (m == 0)
|
||||
return;
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10) {
|
||||
sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
}
|
||||
|
||||
static void bf16_gemm_nn(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::string& compiled_dims) {
|
||||
bf16_gemm_nt(a, b.transpose(0, 1), d, c, compiled_dims);
|
||||
}
|
||||
|
||||
static void bf16_gemm_tn(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::string& compiled_dims) {
|
||||
bf16_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c, compiled_dims);
|
||||
}
|
||||
|
||||
static void bf16_gemm_tt(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::string& compiled_dims) {
|
||||
bf16_gemm_nt(a.transpose(0, 1), b, d, c, compiled_dims);
|
||||
}
|
||||
|
||||
static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& d, const torch::Tensor& m_indices,
|
||||
const std::string& compiled_dims) {
|
||||
// Shape must be `[M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a);
|
||||
const auto& major_b = get_major_type_ab(b);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(m_indices.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m, k] = get_shape<2>(a);
|
||||
const auto& [num_groups, n, k_] = get_shape<3>(b);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
const auto& m__ = static_cast<int>(m_indices.numel());
|
||||
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Do nothing if empty
|
||||
if (m == 0)
|
||||
return;
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
}
|
||||
|
||||
static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& d, const torch::Tensor& masked_m,
|
||||
const int& expected_m, const std::string& compiled_dims) {
|
||||
// Shape must be `[G, M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a);
|
||||
const auto& major_b = get_major_type_ab(b);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(masked_m.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [num_groups, m, k] = get_shape<3>(a);
|
||||
const auto& [num_groups_, n, k_] = get_shape<3>(b);
|
||||
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
|
||||
const auto& num_groups___ = static_cast<int>(masked_m.numel());
|
||||
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
}
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
// FP8 GEMMs
|
||||
m.def("fp8_gemm_nt", &fp8_gemm_nt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_nn", &fp8_gemm_nn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_tn", &fp8_gemm_tn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_tt", &fp8_gemm_tt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
||||
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
|
||||
// BF16 GEMMs
|
||||
m.def("bf16_gemm_nt", &bf16_gemm_nt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk");
|
||||
m.def("bf16_gemm_nn", &bf16_gemm_nn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk");
|
||||
m.def("bf16_gemm_tn", &bf16_gemm_tn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn");
|
||||
m.def("bf16_gemm_tt", &bf16_gemm_tt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn");
|
||||
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("compiled_dims") = "nk");
|
||||
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
||||
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::gemm
|
||||
85
csrc/apis/layout.hpp
Normal file
85
csrc/apis/layout.hpp
Normal file
@@ -0,0 +1,85 @@
|
||||
#pragma once
|
||||
|
||||
#include "../utils/layout.hpp"
|
||||
#include "../jit_kernels/impls/smxx_layout.hpp"
|
||||
|
||||
namespace deep_gemm::layout {
|
||||
|
||||
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const int& mn, const int& k,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::optional<int>& num_groups,
|
||||
const bool& is_sfa,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
|
||||
const auto& gran_k = std::get<2>(recipe);
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
|
||||
// Pre-transform checks
|
||||
check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);
|
||||
|
||||
// (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
|
||||
return get_mn_major_tma_aligned_tensor(sf);
|
||||
|
||||
// (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) {
|
||||
DG_HOST_ASSERT(not disable_ue8m0_cast);
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
|
||||
}
|
||||
|
||||
// (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);
|
||||
|
||||
// (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) {
|
||||
DG_HOST_ASSERT(not disable_ue8m0_cast);
|
||||
const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128));
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
|
||||
}
|
||||
|
||||
// (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10)
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown SF transformation");
|
||||
}
|
||||
|
||||
static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const std::vector<int>& ks,
|
||||
const torch::Tensor& ks_tensor,
|
||||
const std::tuple<int, int, int>& recipe) {
|
||||
DG_HOST_ASSERT(sf.dim() == 2);
|
||||
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
|
||||
// FP32 on SM90
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 9)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
|
||||
// FP32 on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);
|
||||
|
||||
// INT on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown cases");
|
||||
}
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
|
||||
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
|
||||
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
|
||||
m.def("get_tma_aligned_size", &get_tma_aligned_size);
|
||||
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
|
||||
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
|
||||
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
|
||||
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::layout
|
||||
28
csrc/apis/runtime.hpp
Normal file
28
csrc/apis/runtime.hpp
Normal file
@@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#include "../jit/compiler.hpp"
|
||||
#include "../jit/device_runtime.hpp"
|
||||
|
||||
namespace deep_gemm::runtime {
|
||||
|
||||
static void register_apis(pybind11::module_& m) {
|
||||
m.def("set_num_sms", [&](const int& new_num_sms) {
|
||||
device_runtime->set_num_sms(new_num_sms);
|
||||
});
|
||||
m.def("get_num_sms", [&]() {
|
||||
return device_runtime->get_num_sms();
|
||||
});
|
||||
m.def("set_tc_util", [&](const int& new_tc_util) {
|
||||
device_runtime->set_tc_util(new_tc_util);
|
||||
});
|
||||
m.def("get_tc_util", [&]() {
|
||||
return device_runtime->get_tc_util();
|
||||
});
|
||||
|
||||
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
|
||||
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
|
||||
KernelRuntime::prepare_init(cuda_home_path_by_python);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace deep_gemm::runtime
|
||||
@@ -35,10 +35,10 @@ public:
|
||||
}
|
||||
|
||||
static void prepare_init(const std::string& library_root_path,
|
||||
const std::string& cuda_home_path_by_torch) {
|
||||
const std::string& cuda_home_path_by_python) {
|
||||
Compiler::library_root_path = library_root_path;
|
||||
Compiler::library_include_path = Compiler::library_root_path / "include";
|
||||
Compiler::cuda_home = cuda_home_path_by_torch;
|
||||
Compiler::cuda_home = cuda_home_path_by_python;
|
||||
Compiler::library_version = get_library_version();
|
||||
}
|
||||
|
||||
@@ -64,6 +64,8 @@ public:
|
||||
get_env<int>("DG_JIT_CPP_STANDARD", 20));
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
|
||||
flags += " --ptxas-options=--verbose";
|
||||
if (get_env("DG_JIT_WITH_LINEINFO", 0))
|
||||
flags += " -Xcompiler -rdynamic -lineinfo";
|
||||
}
|
||||
|
||||
virtual ~Compiler() = default;
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
#if CUDART_VERSION >= 12080 or defined(DG_JIT_USE_DRIVER_API)
|
||||
#if CUDART_VERSION >= 12080 and not defined(DG_JIT_USE_DRIVER_API)
|
||||
|
||||
// Use CUDA runtime API
|
||||
using LibraryHandle = cudaLibrary_t;
|
||||
|
||||
@@ -64,8 +64,8 @@ public:
|
||||
kernel = load_kernel(cubin_path, symbol_names[0], &library);
|
||||
}
|
||||
|
||||
static void prepare_init(const std::string& cuda_home_path_by_torch) {
|
||||
cuda_home = cuda_home_path_by_torch;
|
||||
static void prepare_init(const std::string& cuda_home_path_by_python) {
|
||||
cuda_home = cuda_home_path_by_python;
|
||||
}
|
||||
|
||||
static bool check_validity(const std::filesystem::path& dir_path) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/layout.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -146,7 +147,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const bool& with_accumulation, const int& num_sms) {
|
||||
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
|
||||
|
||||
// Select M/N block sizes
|
||||
@@ -179,7 +180,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
for (const auto& block_n: block_ns) {
|
||||
const int& num_waves = get_num_waves(block_m, block_n);
|
||||
const auto& last_util = get_last_wave_util(block_m, block_n);
|
||||
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n))
|
||||
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n, block_k))
|
||||
continue;
|
||||
|
||||
bool success = false;
|
||||
|
||||
@@ -43,7 +43,7 @@ struct SM100ArchSpec {
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n) {
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// TODO: consider more carefully for BF16 GEMMs
|
||||
// 2SM BF16 UMMA does not support `N % 32 != 0`
|
||||
if (ab_dtype == torch::kBFloat16 and block_n % 32 != 0)
|
||||
@@ -106,7 +106,7 @@ struct SM100ArchSpec {
|
||||
const int& swizzle_cd_mode,
|
||||
const at::ScalarType& cd_dtype) {
|
||||
constexpr static int layout_ad_m = 128;
|
||||
return (kernel_type == KernelType::Kernel1D1D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2;
|
||||
return (kernel_type != KernelType::Kernel1D2D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2;
|
||||
}
|
||||
|
||||
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
|
||||
|
||||
@@ -30,15 +30,19 @@ struct SM90ArchSpec {
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n) {
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// FP32 output does not support `block_m == 256`
|
||||
if (cd_dtype == at::kFloat and block_m == 256)
|
||||
return false;
|
||||
|
||||
// TODO: more general block N selection
|
||||
// Must be some fixed block N selections
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 or block_n != 152))
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 and block_n != 152))
|
||||
return false;
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 or block_n != 160))
|
||||
|
||||
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
|
||||
// Or too many register spills
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
|
||||
return false;
|
||||
|
||||
// Avoid bank conflicts for FP32 output
|
||||
|
||||
143
csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Normal file
143
csrc/jit_kernels/impls/sm100_bf16_gemm.hpp
Normal file
@@ -0,0 +1,143 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM100BF16GemmRuntime final: public LaunchRuntime<SM100BF16GemmRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void* grouped_layout;
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_c;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm100_bf16_gemm_impl<
|
||||
{}, {},
|
||||
{}, {}, {},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.num_groups,
|
||||
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
|
||||
args.gemm_config.tc_util);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.grouped_layout, args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_c, args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm100_bf16_gemm(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
// TODO: test other Ks
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
const auto& cd = c.value_or(d);
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
|
||||
SM100ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM100ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(cd.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Duplicate the accumulator if necessary
|
||||
if (c.has_value()) {
|
||||
if (c->data_ptr() == d.data_ptr()) {
|
||||
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
|
||||
} else {
|
||||
// ReSharper disable once CppExpressionWithoutSideEffects
|
||||
d.copy_(c.value());
|
||||
}
|
||||
}
|
||||
|
||||
// Launch
|
||||
const SM100BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_c = tensor_map_c,
|
||||
.tensor_map_d = tensor_map_d
|
||||
};
|
||||
const auto& code = SM100BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm100_bf16_gemm", code);
|
||||
SM100BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -155,7 +156,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -202,7 +203,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
|
||||
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -136,7 +137,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -179,7 +180,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
|
||||
SM100FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
229
csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Normal file
229
csrc/jit_kernels/impls/sm90_bf16_gemm.hpp
Normal file
@@ -0,0 +1,229 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class SM90BF16GemmRuntime final: public LaunchRuntime<SM90BF16GemmRuntime> {
|
||||
public:
|
||||
struct Args {
|
||||
int m, n, k, num_groups;
|
||||
const std::string& compiled_dims;
|
||||
|
||||
GemmConfig gemm_config;
|
||||
LaunchArgs launch_args;
|
||||
|
||||
void *grouped_layout;
|
||||
CUtensorMap tensor_map_a;
|
||||
CUtensorMap tensor_map_b;
|
||||
CUtensorMap tensor_map_d;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&sm90_bf16_gemm_impl<
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {}, {},
|
||||
{},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
// TODO: add CD dtype
|
||||
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
|
||||
args.num_groups,
|
||||
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
|
||||
args.gemm_config.smem_config.swizzle_cd_mode,
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.grouped_layout,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_d));
|
||||
}
|
||||
};
|
||||
|
||||
static void sm90_bf16_gemm(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const torch::Tensor& d,
|
||||
const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::Normal, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), c.has_value(),
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = 1,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = nullptr,
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_bf16_gemm", code);
|
||||
SM90BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::KernelNoSF,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), 1,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = m_indices.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
|
||||
SM90BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& num_groups, const int& m, const int& n, const int& k,
|
||||
const int& expected_m,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const std::string& compiled_dims) {
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(k % 64 == 0);
|
||||
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedMasked, KernelType::KernelNoSF,
|
||||
expected_m, n, k, num_groups, major_a, major_b,
|
||||
torch::kBFloat16, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
// Requires no TMA splits
|
||||
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
|
||||
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
|
||||
config.block_k,
|
||||
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
|
||||
config.smem_config.swizzle_a_mode);
|
||||
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
|
||||
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
|
||||
config.block_k,
|
||||
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
|
||||
config.smem_config.swizzle_b_mode);
|
||||
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
|
||||
SM90ArchSpec::get_cd_store_block_m(config.block_m),
|
||||
SM90ArchSpec::get_cd_store_block_n(config.block_n),
|
||||
static_cast<int>(d.stride(-2)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
|
||||
// Launch
|
||||
const SM90BF16GemmRuntime::Args& args = {
|
||||
.m = m, .n = n, .k = k,
|
||||
.num_groups = num_groups,
|
||||
.compiled_dims = compiled_dims,
|
||||
.gemm_config = config,
|
||||
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
|
||||
config.smem_config.smem_size,
|
||||
config.multicast_config.num_multicast),
|
||||
.grouped_layout = masked_m.data_ptr(),
|
||||
.tensor_map_a = tensor_map_a,
|
||||
.tensor_map_b = tensor_map_b,
|
||||
.tensor_map_d = tensor_map_d,
|
||||
};
|
||||
const auto& code = SM90BF16GemmRuntime::generate(args);
|
||||
const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
|
||||
SM90BF16GemmRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -3,6 +3,7 @@
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/device_runtime.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
@@ -139,7 +140,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
const auto& aligned_k = align(k, 128);
|
||||
const auto& config = get_best_config<SM90ArchSpec>(
|
||||
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
|
||||
m, n, k, num_groups, major_a, major_b,
|
||||
m, n, k, 1, major_a, major_b,
|
||||
torch::kFloat8_e4m3fn, d.scalar_type(), false,
|
||||
device_runtime->get_num_sms());
|
||||
|
||||
@@ -185,7 +186,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
|
||||
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
|
||||
}
|
||||
|
||||
static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
|
||||
const torch::Tensor& b, const torch::Tensor& sfb,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
|
||||
@@ -10,6 +10,35 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class TransposeFP32Runtime final: public LaunchRuntime<TransposeFP32Runtime> {
|
||||
public:
|
||||
struct Args {
|
||||
int mn, sf_k;
|
||||
int block_mn;
|
||||
void *sf, *out;
|
||||
|
||||
LaunchArgs launch_args;
|
||||
};
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
static void __instantiate_kernel() {{
|
||||
auto ptr = reinterpret_cast<void*>(&transpose_fp32<
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
|
||||
}
|
||||
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast<uint32_t>(args.mn)));
|
||||
}
|
||||
};
|
||||
|
||||
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
|
||||
public:
|
||||
struct Args {
|
||||
@@ -88,10 +117,32 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
|
||||
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
|
||||
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
|
||||
|
||||
// Normal layout requires transposing
|
||||
auto aligned_sf = torch::empty_strided({num_groups, tma_aligned_mn, sf_k}, {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, batched_sf.options());
|
||||
aligned_sf = aligned_sf.slice(1, 0, mn).copy_(batched_sf);
|
||||
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
|
||||
const auto& out = torch::empty_strided({num_groups, mn, sf_k},
|
||||
{tma_aligned_mn * sf_k, 1, tma_aligned_mn},
|
||||
batched_sf.options());
|
||||
|
||||
if (not batched_sf.is_contiguous()) {
|
||||
// Fallback to PyTorch's slow copy if not contiguous
|
||||
// ReSharper disable once CppExpressionWithoutSideEffects
|
||||
out.copy_(batched_sf);
|
||||
} else {
|
||||
constexpr int block_mn = 64;
|
||||
constexpr int num_threads = 512;
|
||||
const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast<int>(sizeof(float));
|
||||
const TransposeFP32Runtime::Args& args = {
|
||||
.mn = mn,
|
||||
.sf_k = sf_k,
|
||||
.block_mn = block_mn,
|
||||
.sf = batched_sf.data_ptr(),
|
||||
.out = out.data_ptr(),
|
||||
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size)
|
||||
};
|
||||
|
||||
const auto& code = TransposeFP32Runtime::generate(args);
|
||||
const auto& runtime = compiler->build("transpose_fp32", code);
|
||||
TransposeFP32Runtime::launch(runtime, args);
|
||||
}
|
||||
return (dim == 2) ? out.squeeze(0) : out;
|
||||
}
|
||||
|
||||
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) {
|
||||
@@ -127,7 +178,6 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
|
||||
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
|
||||
// Launch the kernel
|
||||
if (batched_sf.is_contiguous()) {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if ((mn * sf_k) % 4 != 0 and num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
@@ -146,11 +196,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
|
||||
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
|
||||
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
|
||||
} else {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if (mn % 4 != 0 or num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
|
||||
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
|
||||
|
||||
constexpr int block_mn = 128;
|
||||
|
||||
@@ -1,412 +1,19 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "jit/compiler.hpp"
|
||||
#include "jit/device_runtime.hpp"
|
||||
#include "utils/layout.hpp"
|
||||
|
||||
#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/smxx_layout.hpp"
|
||||
#include "apis/gemm.hpp"
|
||||
#include "apis/layout.hpp"
|
||||
#include "apis/runtime.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME deep_gemm_cpp
|
||||
#endif
|
||||
|
||||
namespace deep_gemm {
|
||||
torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const int& mn, const int& k,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::optional<int>& num_groups,
|
||||
const bool& is_sfa,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
|
||||
const auto& gran_k = std::get<2>(recipe);
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
|
||||
// Pre-transform checks
|
||||
check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);
|
||||
|
||||
// (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
|
||||
return get_mn_major_tma_aligned_tensor(sf);
|
||||
|
||||
// (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) {
|
||||
DG_HOST_ASSERT(not disable_ue8m0_cast);
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
|
||||
}
|
||||
|
||||
// (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);
|
||||
|
||||
// (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) {
|
||||
DG_HOST_ASSERT(not disable_ue8m0_cast);
|
||||
const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128));
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
|
||||
}
|
||||
|
||||
// (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major
|
||||
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10)
|
||||
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown SF transformation");
|
||||
}
|
||||
|
||||
torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const std::vector<int>& ks,
|
||||
const torch::Tensor& ks_tensor,
|
||||
const std::tuple<int, int, int>& recipe) {
|
||||
DG_HOST_ASSERT(sf.dim() == 2);
|
||||
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
|
||||
// FP32 on SM90
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 9)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
|
||||
// FP32 on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);
|
||||
|
||||
// INT on SM100
|
||||
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
|
||||
DG_HOST_UNREACHABLE("Unimplemented");
|
||||
|
||||
DG_HOST_UNREACHABLE("Unknown cases");
|
||||
}
|
||||
|
||||
void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[M, K] @ [N, K].T`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
if (fp8_requires_k_major()) {
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
||||
}
|
||||
|
||||
// C/D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m , k ] = get_shape<2>(a.first);
|
||||
const auto& [n , k_] = get_shape<2>(b.first);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
|
||||
|
||||
// Check C as well
|
||||
if (c.has_value()) {
|
||||
check_major_type_cd(c.value());
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
|
||||
}
|
||||
|
||||
// Do nothing if the problem is empty
|
||||
if (m == 0)
|
||||
return;
|
||||
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
|
||||
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
|
||||
d, c, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
|
||||
if (fp8_requires_k_major())
|
||||
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(m_indices.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [m, k] = get_shape<2>(a.first);
|
||||
const auto& [num_groups, n, k_] = get_shape<3>(b.first);
|
||||
const auto& [m_, n_] = get_shape<2>(d);
|
||||
const auto& m__ = static_cast<int>(m_indices.numel());
|
||||
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Do nothing if empty
|
||||
if (m == 0)
|
||||
return;
|
||||
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& m_indices,
|
||||
const std::optional<std::tuple<int, int, int>>& recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
|
||||
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
|
||||
}
|
||||
|
||||
void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const torch::Tensor& masked_m,
|
||||
const int& expected_m,
|
||||
std::optional<std::tuple<int, int, int>> recipe,
|
||||
const std::string& compiled_dims,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
// Shape must be `[G, M, K] @ [G, N, K].mT`
|
||||
const auto& major_a = get_major_type_ab(a.first);
|
||||
const auto& major_b = get_major_type_ab(b.first);
|
||||
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
|
||||
DG_HOST_ASSERT(masked_m.is_contiguous());
|
||||
|
||||
// Type and shape checks
|
||||
const auto& [num_groups, m, k] = get_shape<3>(a.first);
|
||||
const auto& [num_groups_, n, k_] = get_shape<3>(b.first);
|
||||
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
|
||||
const auto& num_groups___ = static_cast<int>(masked_m.numel());
|
||||
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
|
||||
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
|
||||
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
|
||||
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
|
||||
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
|
||||
|
||||
// D must be N-major
|
||||
check_major_type_cd(d);
|
||||
|
||||
// Transform scaling factors
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm90_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
|
||||
sm100_fp8_m_grouped_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
const std::pair<torch::Tensor, torch::Tensor>& b,
|
||||
const torch::Tensor& d,
|
||||
const std::vector<int>& ks,
|
||||
const torch::Tensor& ks_tensor,
|
||||
const std::optional<torch::Tensor>& c,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::string& compiled_dims) {
|
||||
// Must be 1D1D kernel
|
||||
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
|
||||
|
||||
// Contiguity checks
|
||||
DG_HOST_ASSERT(a.first.is_contiguous());
|
||||
DG_HOST_ASSERT(b.first.is_contiguous());
|
||||
DG_HOST_ASSERT(d.is_contiguous());
|
||||
if (c.has_value()) {
|
||||
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
|
||||
DG_HOST_ASSERT(c.value().is_contiguous());
|
||||
}
|
||||
|
||||
// Do nothing if empty
|
||||
if (std::accumulate(ks.begin(), ks.end(), 0) == 0)
|
||||
return;
|
||||
|
||||
// Transform SF with padding
|
||||
const auto& [_, m] = get_shape<2>(a.first);
|
||||
const auto& [__, n] = get_shape<2>(b.first);
|
||||
const auto& sfa = transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
|
||||
const auto& sfb = transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 10) {
|
||||
fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
|
||||
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
// ReSharper disable once CppParameterMayBeConstPtrOrRef
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
using namespace deep_gemm;
|
||||
|
||||
m.doc() = "DeepGEMM C++ library";
|
||||
|
||||
// Runtime
|
||||
m.def("set_num_sms", [&](const int& new_num_sms) {
|
||||
device_runtime->set_num_sms(new_num_sms);
|
||||
});
|
||||
m.def("get_num_sms", [&]() {
|
||||
return device_runtime->get_num_sms();
|
||||
});
|
||||
m.def("set_tc_util", [&](const int& new_tc_util) {
|
||||
device_runtime->set_tc_util(new_tc_util);
|
||||
});
|
||||
m.def("get_tc_util", [&]() {
|
||||
return device_runtime->get_tc_util();
|
||||
});
|
||||
|
||||
// JIT
|
||||
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) {
|
||||
Compiler::prepare_init(library_root_path, cuda_home_path_by_torch);
|
||||
KernelRuntime::prepare_init(cuda_home_path_by_torch);
|
||||
});
|
||||
|
||||
// Stable kernel APIs with automatic arch/layout dispatch
|
||||
m.def("fp8_gemm_nt", &fp8_gemm_nt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_nn", &fp8_gemm_nn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_tn", &fp8_gemm_tn,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_gemm_tt", &fp8_gemm_tt,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"),
|
||||
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "mn",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
|
||||
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("fp8_m_grouped_gemm_nt_masked", &fp8_m_grouped_gemm_nt_masked,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
|
||||
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
|
||||
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
|
||||
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
|
||||
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
|
||||
// Layout kernels
|
||||
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
|
||||
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
|
||||
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
|
||||
// Raw kernels or functions
|
||||
m.def("get_tma_aligned_size", &get_tma_aligned_size);
|
||||
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
|
||||
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
|
||||
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
|
||||
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
|
||||
deep_gemm::gemm::register_apis(m);
|
||||
deep_gemm::layout::register_apis(m);
|
||||
deep_gemm::runtime::register_apis(m);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -10,7 +11,7 @@ class DGException final : public std::exception {
|
||||
|
||||
public:
|
||||
explicit DGException(const char *name, const char* file, const int line, const std::string& error) {
|
||||
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
|
||||
message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error;
|
||||
}
|
||||
|
||||
const char *what() const noexcept override {
|
||||
@@ -50,7 +51,11 @@ do { \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != CUDA_SUCCESS) { \
|
||||
throw DGException("CUDA driver", __FILE__, __LINE__, ""); \
|
||||
std::stringstream ss; \
|
||||
const char *name, *info; \
|
||||
cuGetErrorName(e, &name), cuGetErrorString(e, &info); \
|
||||
ss << static_cast<int>(e) << " (" << name << ", " << info << ")"; \
|
||||
throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
@@ -60,7 +65,9 @@ do { \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != cudaSuccess) { \
|
||||
throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
|
||||
std::stringstream ss; \
|
||||
ss << static_cast<int>(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \
|
||||
throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
import subprocess
|
||||
|
||||
# Set some default environment provided at setup
|
||||
try:
|
||||
@@ -12,14 +11,8 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Import functions from the CPP module
|
||||
import deep_gemm_cpp
|
||||
deep_gemm_cpp.init(
|
||||
os.path.dirname(os.path.abspath(__file__)), # Library root directory path
|
||||
torch.utils.cpp_extension.CUDA_HOME # CUDA home
|
||||
)
|
||||
|
||||
# Configs
|
||||
import deep_gemm_cpp
|
||||
from deep_gemm_cpp import (
|
||||
set_num_sms,
|
||||
get_num_sms,
|
||||
@@ -34,13 +27,48 @@ from deep_gemm_cpp import (
|
||||
fp8_gemm_tn, fp8_gemm_tt,
|
||||
m_grouped_fp8_gemm_nt_contiguous,
|
||||
m_grouped_fp8_gemm_nn_contiguous,
|
||||
fp8_m_grouped_gemm_nt_masked,
|
||||
m_grouped_fp8_gemm_nt_masked,
|
||||
k_grouped_fp8_gemm_tn_contiguous,
|
||||
# BF16 GEMMs
|
||||
bf16_gemm_nt, bf16_gemm_nn,
|
||||
bf16_gemm_tn, bf16_gemm_tt,
|
||||
m_grouped_bf16_gemm_nt_contiguous,
|
||||
m_grouped_bf16_gemm_nt_masked,
|
||||
# Layout kernels
|
||||
transform_sf_into_required_layout
|
||||
)
|
||||
|
||||
# Some alias for legacy supports
|
||||
# TODO: remove these later
|
||||
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked
|
||||
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
|
||||
|
||||
# Some utils
|
||||
from . import testing
|
||||
from . import utils
|
||||
from .utils import *
|
||||
|
||||
|
||||
# Initialize CPP modules
|
||||
def _find_cuda_home() -> str:
|
||||
# TODO: reuse PyTorch API later
|
||||
# For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks
|
||||
cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
|
||||
if cuda_home is None:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
with open(os.devnull, 'w') as devnull:
|
||||
nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n')
|
||||
cuda_home = os.path.dirname(os.path.dirname(nvcc))
|
||||
except Exception:
|
||||
cuda_home = '/usr/local/cuda'
|
||||
if not os.path.exists(cuda_home):
|
||||
cuda_home = None
|
||||
assert cuda_home is not None
|
||||
return cuda_home
|
||||
|
||||
|
||||
deep_gemm_cpp.init(
|
||||
os.path.dirname(os.path.abspath(__file__)), # Library root directory path
|
||||
_find_cuda_home() # CUDA home
|
||||
)
|
||||
|
||||
@@ -11,16 +11,22 @@ enum class KGroupedIndexType {
|
||||
SF_K,
|
||||
};
|
||||
|
||||
template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool isMulticastOnA>
|
||||
template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
|
||||
static constexpr uint32_t get_num_1d_blocks_per_group() {
|
||||
// Select the best from candidates
|
||||
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
|
||||
for (const auto& candidate: {8u, 16u}) {
|
||||
const auto& usage = isMulticastOnA ?
|
||||
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
||||
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
if constexpr (kGemmType == GemmType::MGroupedContiguous or
|
||||
kGemmType == GemmType::MGroupedMasked) {
|
||||
// For grouped GEMMs, let weights always stay in the L2 cache and read activations by once
|
||||
num_best_blocks = kNumSMs;
|
||||
} else {
|
||||
for (const auto& candidate: {8u, 16u}) {
|
||||
const auto& usage = kIsMulticastOnA ?
|
||||
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
||||
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
}
|
||||
}
|
||||
return num_best_blocks;
|
||||
}
|
||||
@@ -32,7 +38,7 @@ template <GemmType kGemmType,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
|
||||
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
|
||||
|
||||
@@ -48,7 +48,18 @@ struct FP8MMASelector {
|
||||
if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN();
|
||||
if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN();
|
||||
}
|
||||
|
||||
static constexpr auto select_type() {
|
||||
@@ -58,6 +69,71 @@ struct FP8MMASelector {
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
template <int N_, typename MMA>
|
||||
struct BF16MMA {
|
||||
|
||||
template <size_t ...Idx>
|
||||
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
|
||||
using namespace cute::SM90::GMMA;
|
||||
MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
static constexpr int N = N_;
|
||||
static constexpr int K = 16;
|
||||
static constexpr int kNumAccum = M * N / 128;
|
||||
};
|
||||
|
||||
template <int N>
|
||||
struct BF16MMASelector {
|
||||
|
||||
static constexpr auto select_mma() {
|
||||
using namespace cute::SM90::GMMA;
|
||||
if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<Major::K, Major::K>();
|
||||
}
|
||||
|
||||
static constexpr auto select_type() {
|
||||
return BF16MMA<N, decltype(select_mma())>();
|
||||
}
|
||||
|
||||
using type = decltype(select_type());
|
||||
};
|
||||
|
||||
|
||||
template <typename dtype_t>
|
||||
struct SM90_U32x2_STSM_N {
|
||||
__device__ __forceinline__ static void
|
||||
|
||||
@@ -144,4 +144,22 @@ __device__ __forceinline__ void prefetch_l1(void *ptr) {
|
||||
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
|
||||
}
|
||||
|
||||
template <uint32_t kNumBytes>
|
||||
struct Vectorized {
|
||||
static auto zeros() {
|
||||
// TODO: add `ulonglong4` for SM100 once `__ldg` support this
|
||||
if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) {
|
||||
return make_uint4(0, 0, 0, 0);
|
||||
} else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) {
|
||||
return make_uint2(0, 0);
|
||||
} else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) {
|
||||
return 0;
|
||||
} else {
|
||||
DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization");
|
||||
}
|
||||
}
|
||||
|
||||
using vec_t = decltype(zeros());
|
||||
};
|
||||
|
||||
} // namespace `deep_gemm`
|
||||
|
||||
@@ -1,3 +1,498 @@
|
||||
#pragma once
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
// TODO: add implement
|
||||
#include <cutlass/arch/barrier.h>
|
||||
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/sm100_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm100;
|
||||
|
||||
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
|
||||
uint64_t kTensorCoreUtilControl>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_bf16_gemm_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_c,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// GEMM with accumulation must have FP32 output
|
||||
if constexpr (kWithAccumulation)
|
||||
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t LAYOUT_AD_M = 128;
|
||||
constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M;
|
||||
constexpr uint32_t kNumTMAStoreStages = 2;
|
||||
DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
|
||||
DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
|
||||
// Utils
|
||||
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
|
||||
const auto warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
const auto lane_idx = get_lane_idx();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
|
||||
// 2-CTA MMA
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
||||
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
|
||||
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
|
||||
|
||||
// Share memory sizes
|
||||
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
|
||||
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
|
||||
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
||||
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
|
||||
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
|
||||
|
||||
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
|
||||
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
|
||||
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2;
|
||||
|
||||
// Real tensor memory size and offsets
|
||||
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
|
||||
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
if constexpr (kWithAccumulation)
|
||||
cute::prefetch_tma_descriptor(&tensor_map_c);
|
||||
}
|
||||
|
||||
// Data on shared memory (layout as ordered below)
|
||||
cd_dtype_t* smem_cd[kNumTMAStoreStages];
|
||||
cutlass::bfloat16_t* smem_a[kNumStages];
|
||||
cutlass::bfloat16_t* smem_b[kNumStages];
|
||||
|
||||
// Fill D/A/B pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
|
||||
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
|
||||
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
|
||||
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
|
||||
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
|
||||
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
|
||||
|
||||
// Fill the tensor memory pointer
|
||||
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
|
||||
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == 0) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive only at the leader CTA
|
||||
full_barriers[i]->init(kNumMulticast);
|
||||
// Arrive at all CTAs
|
||||
empty_barriers[i]->init(1);
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
tmem_full_barriers[i]->init(1);
|
||||
// Arrive only at the leader CTA
|
||||
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
|
||||
// Allocate tensor memory
|
||||
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
|
||||
}
|
||||
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
uint32_t phase = 0;
|
||||
auto launch_k_iterations = [&](const auto& func) {
|
||||
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
|
||||
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
|
||||
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
|
||||
|
||||
// TODO: refactor here
|
||||
if (num_last_stages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
|
||||
func(k_iter, DivisibleK{}, false, num_last_stages);
|
||||
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
|
||||
}
|
||||
};
|
||||
|
||||
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
|
||||
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
|
||||
"Too many epilogue stages, please modify the Python heuristic as well");
|
||||
accum_stage_idx == 0 ? func(0) : func(1);
|
||||
};
|
||||
|
||||
// Dispatch warps into different roles
|
||||
if (warp_idx == 0) {
|
||||
// TMA load warp
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
|
||||
// Compute offsets
|
||||
// NOTES: the group is always concatenated with the outer dimension
|
||||
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
|
||||
shape_m, BLOCK_M, m_block_idx);
|
||||
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
|
||||
shape_n, BLOCK_N, n_block_idx, m_block_idx);
|
||||
|
||||
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
|
||||
// And for all m-grouped GEMMs, A must be K-majored
|
||||
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
|
||||
uint32_t k_block_idx = k_iter * kNumStages + s;
|
||||
uint32_t k_idx = k_block_idx * BLOCK_K;
|
||||
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
|
||||
shape_k, BLOCK_K, k_block_idx, m_block_idx);
|
||||
|
||||
// Add 2 CTA offsets
|
||||
if constexpr (kNumMulticast > 1) {
|
||||
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
|
||||
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
|
||||
}
|
||||
|
||||
// Issue TMAs
|
||||
if (cute::elect_one_sync()) {
|
||||
if constexpr (kMajorA == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
|
||||
if constexpr (kMajorA == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::K)
|
||||
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
|
||||
if constexpr (kMajorB == cute::UMMA::Major::MN)
|
||||
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
|
||||
}
|
||||
// Arrive at full barriers
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait(phase ^ 1);
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
});
|
||||
}
|
||||
} else if (warp_idx == 1 and is_leader_cta) {
|
||||
// MMA issue warp
|
||||
// NOTES: only the leader CTA will do this
|
||||
// Make instruction descriptor
|
||||
// TODO: refactor `UMMA_M` calculation
|
||||
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
|
||||
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
|
||||
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>();
|
||||
|
||||
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
|
||||
auto a_desc = make_umma_desc<kMajorA, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
|
||||
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
|
||||
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
|
||||
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
|
||||
|
||||
// Checks for MMA instructions
|
||||
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
|
||||
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
|
||||
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
|
||||
"Invalid MMA instruction shape");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
// Wait tensor memory empty barrier arrival
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
|
||||
auto umma_arrive = [](const uint64_t* barrier) {
|
||||
if constexpr (kNumMulticast == 1) {
|
||||
cutlass::arch::umma_arrive(barrier);
|
||||
} else {
|
||||
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
|
||||
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
|
||||
}
|
||||
};
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
|
||||
|
||||
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
|
||||
if (do_tmem_full_arrive)
|
||||
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
|
||||
};
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrival
|
||||
full_barriers[s]->wait(phase);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Let tensor cores relax for lower possibility of frequency drop
|
||||
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
|
||||
if constexpr (kTensorCoreUtilControl < 100) {
|
||||
constexpr static uint64_t kNumUMMACycles = (2ull * BLOCK_M * BLOCK_N * BLOCK_K) / 8192ull;
|
||||
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
|
||||
const auto& start_clock = clock64();
|
||||
if (cute::elect_one_sync())
|
||||
while (clock64() - start_clock < kNumDummyCycles) {}
|
||||
__syncwarp();
|
||||
}
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_F16BF16_SS <cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>,
|
||||
cute::SM100_MMA_F16BF16_2x1SM_SS<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
|
||||
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
|
||||
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
|
||||
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
|
||||
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
|
||||
cute_mma_t::fma(a_desc, b_desc,
|
||||
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
|
||||
k_iter > 0 or s > 0 or k > 0,
|
||||
runtime_instr_desc);
|
||||
}
|
||||
}
|
||||
|
||||
// Commit to the mbarrier object
|
||||
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
|
||||
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait(phase);
|
||||
empty_barrier_arrive(s, false);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
|
||||
// Epilogue warp groups
|
||||
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
|
||||
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
|
||||
|
||||
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
|
||||
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
|
||||
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
|
||||
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
|
||||
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
|
||||
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
|
||||
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
|
||||
|
||||
// Flush TMA stores
|
||||
// NOTES: for the first store, we have to flush all previous TMA,
|
||||
// as we don't share pipeline stages between two blocks
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
|
||||
// Wait UMMA arrival
|
||||
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
|
||||
tcgen05_after_thread_sync();
|
||||
|
||||
// Load from tensor memory into registers, and write shared memory with STSM
|
||||
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
|
||||
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
|
||||
|
||||
// Iterate over M waves
|
||||
#pragma unroll
|
||||
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
|
||||
// Issue every swizzled atom and pipeline STSM and TMA store
|
||||
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStores; ++ s) {
|
||||
// Wait shared memory to be released
|
||||
const uint32_t iter_idx = w * kNumStores + s;
|
||||
if (iter_idx >= kNumTMAStoreStages) {
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<kNumTMAStoreStages - 1>();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
}
|
||||
|
||||
// The pipeline stage
|
||||
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
|
||||
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
|
||||
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
|
||||
|
||||
// Store into shared memory
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
|
||||
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
|
||||
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
|
||||
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleCDMode / 16);
|
||||
|
||||
// Source and destination memory address
|
||||
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
|
||||
w * BLOCK_N + // Wave offset
|
||||
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
|
||||
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
|
||||
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
cutlass::arch::fence_view_async_tmem_load();
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(values[0], values[1]),
|
||||
cast_into_bf16_and_pack(values[2], values[3]),
|
||||
cast_into_bf16_and_pack(values[4], values[5]),
|
||||
cast_into_bf16_and_pack(values[6], values[7]));
|
||||
}
|
||||
}
|
||||
|
||||
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
|
||||
// NOTES: only the last stage needs to do this
|
||||
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
|
||||
tcgen05_before_thread_sync();
|
||||
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Synchronize all threads and issue TMA
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
if (epilogue_thread_idx == 0) {
|
||||
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
|
||||
// TODO: do we actually need this?
|
||||
if (epilogue_thread_idx == 0)
|
||||
cute::tma_store_wait<0>();
|
||||
|
||||
// Deallocate tensor memory by warp 1
|
||||
// NOTES: warp 0 is waiting TMA store
|
||||
// TODO: do we need 2 SM allocation?
|
||||
if (epilogue_warp_idx == 1)
|
||||
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
|
||||
}
|
||||
|
||||
// To safely deconstruct all barriers, we need a cluster sync
|
||||
// TODO: optimize it by another round of barrier waits
|
||||
if constexpr (kNumMulticast > 1)
|
||||
cute::cluster_sync();
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -136,7 +136,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
// Arrive at all CTAs
|
||||
full_barriers[i]->init(1);
|
||||
full_barriers[i]->init(kNumMulticast);
|
||||
empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32);
|
||||
}
|
||||
#pragma unroll
|
||||
@@ -241,6 +241,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE;
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
@@ -249,6 +251,8 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
if (is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive();
|
||||
if (not is_leader_cta and cute::elect_one_sync())
|
||||
full_barriers[s]->arrive(0u);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,3 +1,343 @@
|
||||
#pragma once
|
||||
|
||||
// TODO: add implement
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wunknown-attributes"
|
||||
|
||||
#include <cutlass/arch/barrier.h>
|
||||
#include <cutlass/arch/reg_reconfig.h>
|
||||
|
||||
#include <cute/arch/cluster_sm90.hpp>
|
||||
#include <cute/arch/copy_sm90_desc.hpp>
|
||||
#include <cute/arch/copy_sm90_tma.hpp>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
#include <deep_gemm/common/scheduler.cuh>
|
||||
#include <deep_gemm/common/sm90_utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
using namespace deep_gemm::sm90;
|
||||
|
||||
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
|
||||
uint32_t kSwizzleDMode,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
uint32_t kNumSMs, GemmType kGemmType>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
||||
sm90_bf16_gemm_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Types
|
||||
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
|
||||
|
||||
// Overwrite shape constants if the compiler gives
|
||||
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
|
||||
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
|
||||
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
|
||||
|
||||
// Shared memory
|
||||
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
|
||||
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
|
||||
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
const uint32_t lane_idx = get_lane_idx();
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
// Align to 1024 bytes for swizzle-128B
|
||||
extern __shared__ __align__(1024) uint8_t smem_buffer[];
|
||||
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
|
||||
|
||||
// Data on shared memory
|
||||
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
|
||||
__nv_bfloat16* smem_a[kNumStages];
|
||||
__nv_bfloat16* smem_b[kNumStages];
|
||||
|
||||
// TMA Barrier for both divisible and non-divisible cases
|
||||
Barrier* full_barriers[kNumStages];
|
||||
Barrier* empty_barriers[kNumStages];
|
||||
|
||||
// Fill shared memory pointers
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
smem_a[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
|
||||
smem_b[i] = reinterpret_cast<__nv_bfloat16*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
// Fill barriers
|
||||
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i] = barrier_start_ptr + i;
|
||||
empty_barriers[i] = barrier_start_ptr + kNumStages + i;
|
||||
}
|
||||
|
||||
// Initialize barriers
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < kNumStages; ++ i) {
|
||||
full_barriers[i]->init(1);
|
||||
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
|
||||
}
|
||||
|
||||
// Make initialized barrier visible in async proxy
|
||||
cutlass::arch::fence_view_async_shared();
|
||||
cutlass::arch::fence_barrier_init();
|
||||
}
|
||||
|
||||
// Synchronize all threads to make barrier visible in normal memory model
|
||||
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
|
||||
|
||||
struct DivisibleK {};
|
||||
struct NotDivisibleK {};
|
||||
auto launch_k_iterations = [=](const auto& func) {
|
||||
if constexpr (kNumLastStages == 0) {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
} else {
|
||||
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
|
||||
func(k_iter, DivisibleK{});
|
||||
func(num_iterations - 1, NotDivisibleK{});
|
||||
}
|
||||
};
|
||||
|
||||
// Register reconfigurations
|
||||
constexpr uint32_t kNumTMARegisters = 48;
|
||||
constexpr uint32_t kNumMathRegisters = 224;
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
|
||||
// Assign TMA multicast number into A and B
|
||||
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
|
||||
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
|
||||
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
|
||||
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
|
||||
|
||||
// NOTES: unrolling and `kNumInnerStages` are vital for performance, NVCC will try to eliminate all
|
||||
// shared memory pointers, e.g. `full_barriers` registers, if all the access indices are constant
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait consumer release
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
|
||||
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
|
||||
auto& full_barrier = *full_barriers[s];
|
||||
uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
|
||||
|
||||
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_a[s], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
|
||||
num_tma_multicast_a);
|
||||
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
|
||||
smem_b[s], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
|
||||
num_tma_multicast_b);
|
||||
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
|
||||
full_barriers[s]->arrive();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// To safely deconstruct distributed shared barriers, we need another round of empty waits
|
||||
if constexpr (kNumTMAMulticast > 1) {
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumStages; ++ s)
|
||||
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Math warp-groups for WGMMA
|
||||
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
|
||||
|
||||
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
|
||||
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
|
||||
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2);
|
||||
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
|
||||
float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
|
||||
|
||||
// Empty barrier arrival
|
||||
auto empty_barrier_arrive = [&](uint32_t s) {
|
||||
if constexpr (kNumTMAMulticast == 1) {
|
||||
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
|
||||
} else {
|
||||
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
|
||||
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
|
||||
}
|
||||
};
|
||||
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type) {
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
|
||||
// TODO: remove some useless computation for unaligned Ms
|
||||
#pragma unroll
|
||||
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
|
||||
// Wait TMA arrivals
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
||||
|
||||
// Commit WGMMA instructions
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_arrive();
|
||||
#pragma unroll
|
||||
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
|
||||
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
|
||||
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
|
||||
WGMMA::wgmma(desc_a, desc_b, shifted_accum, 1);
|
||||
}
|
||||
warpgroup_commit_batch();
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
|
||||
warpgroup_fence_operand(accum[i]);
|
||||
warpgroup_wait<0>();
|
||||
|
||||
// Notify barrier arrival at the last warpgroup wave
|
||||
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
}
|
||||
|
||||
// Wait unaligned cases
|
||||
#pragma unroll
|
||||
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
|
||||
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
|
||||
empty_barrier_arrive(s);
|
||||
}
|
||||
});
|
||||
|
||||
// TMA checks
|
||||
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
|
||||
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
|
||||
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
|
||||
DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
|
||||
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
|
||||
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
|
||||
"Unaligned TMA store or too many TMA store instructions");
|
||||
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
|
||||
|
||||
// Wait last TMA store to be finished
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
|
||||
cute::tma_store_wait<0>();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Write back to shared memory using STSM and issue TMA stores
|
||||
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
|
||||
#pragma unroll
|
||||
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
|
||||
auto m_offset = local_idx * WAVE_BLOCK_M;
|
||||
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
|
||||
#pragma unroll
|
||||
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
|
||||
// Swizzle or padding into the correct address
|
||||
uint8_t* smem_ptr = nullptr;
|
||||
if constexpr (kSwizzleDMode > 0) {
|
||||
// Calculate the swizzling atom offset and in-atom offset
|
||||
constexpr uint32_t kNumBankGroupBytes = 16;
|
||||
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
|
||||
|
||||
// Calculate the index of the bank group to be written in the atom
|
||||
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
|
||||
|
||||
// Reshape the atom in another view and swizzle
|
||||
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
|
||||
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
|
||||
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
|
||||
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
|
||||
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
|
||||
col ^= row % (kSwizzleDMode / 16);
|
||||
|
||||
// Add back into the base pointer
|
||||
// NOTES: think twice before modifying this, as changes may affect the number of instructions
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
|
||||
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
|
||||
m_offset * kSwizzleDMode + // Wave offset
|
||||
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
|
||||
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
|
||||
} else {
|
||||
// No swizzling, just padding
|
||||
// TODO: support more cases
|
||||
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
|
||||
}
|
||||
|
||||
// NOTES: only 16 lanes' addresses are used
|
||||
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
|
||||
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
|
||||
smem_ptr
|
||||
);
|
||||
}
|
||||
}
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
|
||||
|
||||
// Use TMA store to write back to global memory
|
||||
// TODO: compatible with FP32 output
|
||||
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
|
||||
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
|
||||
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
|
||||
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
|
||||
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
|
||||
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
|
||||
n_block_idx * BLOCK_N + in_block_n_offset,
|
||||
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
|
||||
cute::tma_store_arrive();
|
||||
}
|
||||
__syncwarp();
|
||||
}
|
||||
}
|
||||
#else
|
||||
if (blockIdx.x == 0 and threadIdx.x == 0)
|
||||
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
|
||||
#endif
|
||||
}
|
||||
|
||||
}; // namespace deep_gemm
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
@@ -175,7 +175,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
|
||||
|
||||
// NOTES: only one thread (or warp) will be used
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
if (threadIdx.x < kNumMathThreads + 32 and cute::elect_one_sync()) {
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
|
||||
|
||||
@@ -4,6 +4,45 @@
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
|
||||
uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
|
||||
__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
|
||||
typedef typename Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
|
||||
constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
|
||||
constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;
|
||||
|
||||
// Shapes and strides
|
||||
extern __shared__ float smem_buffer[];
|
||||
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
|
||||
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
|
||||
const auto tma_aligned_mn = align<uint32_t>(mn, kNumTMAAlignedElems);
|
||||
|
||||
// Shift into the block
|
||||
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
|
||||
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
|
||||
const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
|
||||
|
||||
// Load
|
||||
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
|
||||
auto in_vec = __ldg(local_sf + i);
|
||||
const auto& in_values = reinterpret_cast<float*>(&in_vec);
|
||||
|
||||
const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < kNumElemsPerVec; ++ j)
|
||||
smem_buffer[row * PADDED_SF_K + col + j] = in_values[j];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Store
|
||||
#pragma unroll
|
||||
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
|
||||
const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
|
||||
const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
|
||||
out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// NOTES: the two kernels below always pack the K dimension
|
||||
|
||||
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
[build-system]
|
||||
requires = ["torch>=2.1.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
4
setup.py
4
setup.py
@@ -27,6 +27,10 @@ third_party_include_dirs = [
|
||||
'third-party/cutlass/include/cutlass',
|
||||
]
|
||||
|
||||
# Use driver API for older CUDA compatibility
|
||||
if int(os.environ.get('DG_JIT_USE_DRIVER_API', '0')):
|
||||
cxx_flags.append('-DDG_JIT_USE_DRIVER_API')
|
||||
|
||||
|
||||
class CustomBuildPy(build_py):
|
||||
def run(self):
|
||||
|
||||
@@ -59,6 +59,7 @@ def get_out_dtype() -> tuple:
|
||||
|
||||
|
||||
def get_major_ab(freeze_a: bool) -> tuple:
|
||||
# TODO: test other major-ness for SM90 BF16 GEMMs
|
||||
if get_arch_major() == 9:
|
||||
return ((MajorTypeAB.KMajor, MajorTypeAB.KMajor), )
|
||||
if freeze_a:
|
||||
@@ -70,15 +71,15 @@ def get_major_ab(freeze_a: bool) -> tuple:
|
||||
def enumerate_normal(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for m in (128, 4096):
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048), (129280, 7168)]:
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
|
||||
for major_a, major_b in get_major_ab(False):
|
||||
for out_dtype in get_out_dtype():
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or not kernel_type.is_1d1d() else (False, True):
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
|
||||
yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
|
||||
|
||||
|
||||
def enumerate_m_grouped_contiguous() -> Generator:
|
||||
for kernel_type in get_kernel_types():
|
||||
def enumerate_m_grouped_contiguous(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
|
||||
for major_a, major_b in get_major_ab(True):
|
||||
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
|
||||
@@ -106,15 +107,12 @@ def enumerate_k_grouped_contiguous():
|
||||
|
||||
|
||||
def enumerate_sf_layout():
|
||||
for with_transpose in (True, False):
|
||||
for mn in (4096, 4097, 8192):
|
||||
for k in (128, 7168, 7296):
|
||||
for num_groups in (1, 2, 4) if with_transpose else (1, ):
|
||||
if num_groups > 1 and (mn * ceil_div(k, 128)) % 4 != 0:
|
||||
continue
|
||||
if not with_transpose and mn % 4 != 0:
|
||||
continue
|
||||
yield mn, k, with_transpose, num_groups
|
||||
for use_ue8m0 in (False, True):
|
||||
for with_transpose in (True, False):
|
||||
for mn in (4096, 4097, 8192):
|
||||
for k in (128, 7168, 7296):
|
||||
for num_groups in (1, 2, 4):
|
||||
yield mn, k, with_transpose, use_ue8m0, num_groups
|
||||
|
||||
|
||||
def enumerate_k_grouped_sf_layout():
|
||||
@@ -126,6 +124,13 @@ def enumerate_k_grouped_sf_layout():
|
||||
yield mn, ks, num_groups
|
||||
|
||||
|
||||
def enumerate_transpose():
|
||||
for mn in (64, 4096, 16384):
|
||||
for delta in (0, 101, 202, 303):
|
||||
for k in (128, 1024, 4096, 9984, 16384):
|
||||
yield mn + delta, k
|
||||
|
||||
|
||||
def generate_normal(m: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
accumulate: bool, out_dtype: torch.dtype,
|
||||
@@ -149,8 +154,8 @@ def generate_normal(m: int, n: int, k: int,
|
||||
|
||||
|
||||
def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB, use_ue8m0: bool) -> \
|
||||
Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
|
||||
aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms]
|
||||
m = sum(aligned_ms)
|
||||
@@ -171,6 +176,10 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
|
||||
start = aligned_end
|
||||
ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d)
|
||||
|
||||
if use_bf16:
|
||||
b = b if major_b.is_k_major() else b.mT.contiguous().mT
|
||||
return m, a, b, m_indices, d, ref_d
|
||||
|
||||
assert major_a.is_k_major()
|
||||
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
|
||||
b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn),
|
||||
@@ -181,24 +190,27 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
|
||||
return m, a_fp8, b_fp8, m_indices, d, ref_d
|
||||
|
||||
|
||||
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, use_ue8m0: bool) -> \
|
||||
Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
|
||||
d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
|
||||
ref_d = torch.einsum('gmk,gnk->gmn', a, b)
|
||||
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
|
||||
assert masked_m.amax().item() <= max_m
|
||||
|
||||
if use_bf16:
|
||||
return a, b, masked_m, d, ref_d
|
||||
|
||||
a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float))
|
||||
b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float))
|
||||
for i in range(num_groups):
|
||||
a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0)
|
||||
b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0)
|
||||
|
||||
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
|
||||
for j in range(num_groups):
|
||||
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
|
||||
assert masked_m.amax().item() <= max_m
|
||||
|
||||
return a_fp8, b_fp8, masked_m, d, ref_d
|
||||
|
||||
|
||||
|
||||
125
tests/test_bf16.py
Normal file
125
tests/test_bf16.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
import random
|
||||
|
||||
import deep_gemm
|
||||
from deep_gemm.testing import (
|
||||
bench_kineto,
|
||||
calc_diff, count_bytes
|
||||
)
|
||||
from generators import (
|
||||
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, generate_normal,
|
||||
generate_m_grouped_contiguous, generate_m_grouped_masked
|
||||
)
|
||||
|
||||
|
||||
def test_gemm() -> None:
|
||||
print('Testing GEMM:')
|
||||
for _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(use_bf16=True):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
|
||||
acc_opt = f'acc={int(accumulate)}'
|
||||
|
||||
for test_alias in (False, True):
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
|
||||
func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}'
|
||||
if test_alias:
|
||||
a = a if major_a.is_k_major() else a.T
|
||||
b = b if major_b.is_k_major() else b.T
|
||||
assert a.is_contiguous() and b.is_contiguous()
|
||||
getattr(deep_gemm, func_name)(a, b, d, c=c)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.0001, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, '
|
||||
f'{diff:.5f}, alias={test_alias}')
|
||||
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_bf16=True)
|
||||
|
||||
cublas_t = 0
|
||||
t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True)
|
||||
if accumulate == 0 and out_dtype == torch.bfloat16:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
cublas_t = bench_kineto(lambda: a @ b.T, 'nvjet', suppress_kineto_output=True)
|
||||
except Exception:
|
||||
pass
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
|
||||
f'{cublas_t / t:.2f}x cuBLAS')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_contiguous() -> None:
|
||||
print('Testing m-grouped contiguous GEMM:')
|
||||
|
||||
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(use_bf16=True):
|
||||
major_opt = 'N' if major_a.is_k_major() else 'T'
|
||||
major_opt += 'T' if major_b.is_k_major() else 'N'
|
||||
|
||||
for test_alias in (False, True):
|
||||
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True)
|
||||
func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous"
|
||||
if test_alias:
|
||||
assert major_a.is_k_major()
|
||||
b = b if major_b.is_k_major() else b.mT
|
||||
assert a[0].is_contiguous() and b[0].is_contiguous()
|
||||
getattr(deep_gemm, func_name)(a, b, d, m_indices)
|
||||
d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d)
|
||||
diff = calc_diff(d, ref_d)
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'
|
||||
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices)
|
||||
|
||||
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
def test_m_grouped_gemm_masked() -> None:
|
||||
print('Testing m-grouped masked GEMM:')
|
||||
|
||||
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
|
||||
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
|
||||
# Test correctness
|
||||
for i in range(10):
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
|
||||
assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
|
||||
|
||||
# Construct full cases
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
|
||||
|
||||
# Test performance with fixed shapes
|
||||
valid_m = masked_m.sum().item()
|
||||
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): '
|
||||
f'{t * 1e6:4.0f} us | '
|
||||
f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
torch.manual_seed(0)
|
||||
random.seed(0)
|
||||
|
||||
print('Library path:')
|
||||
print(f' > {deep_gemm.__path__}\n')
|
||||
|
||||
test_gemm()
|
||||
test_m_grouped_gemm_contiguous()
|
||||
test_m_grouped_gemm_masked()
|
||||
@@ -105,7 +105,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
# Test correctness
|
||||
for i in range(10):
|
||||
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
|
||||
deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
for j in range(num_groups):
|
||||
diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
|
||||
assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
|
||||
@@ -115,7 +115,7 @@ def test_m_grouped_gemm_masked() -> None:
|
||||
|
||||
# noinspection PyShadowingNames
|
||||
def test_func():
|
||||
deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
|
||||
# Test performance with fixed shapes
|
||||
valid_m = masked_m.sum().item()
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
from deep_gemm.testing import bench_kineto, count_bytes
|
||||
from deep_gemm.testing import bench_kineto, count_bytes, calc_diff
|
||||
from deep_gemm.utils import (
|
||||
align, ceil_div,
|
||||
per_token_cast_to_fp8, per_channel_cast_to_fp8,
|
||||
get_tma_aligned_size,
|
||||
get_mn_major_tma_aligned_tensor,
|
||||
get_mn_major_tma_aligned_packed_ue8m0_tensor,
|
||||
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
|
||||
)
|
||||
|
||||
from generators import (
|
||||
enumerate_transpose,
|
||||
enumerate_sf_layout,
|
||||
enumerate_k_grouped_sf_layout
|
||||
)
|
||||
@@ -43,29 +45,39 @@ def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) ->
|
||||
|
||||
def test_sf_layout_kernels() -> None:
|
||||
print('Testing SF layout kernels:')
|
||||
for mn, k, with_transpose, num_groups in enumerate_sf_layout():
|
||||
for mn, k, with_transpose, use_ue8m0, num_groups in enumerate_sf_layout():
|
||||
x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda')
|
||||
x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=True)
|
||||
x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0)
|
||||
fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1)
|
||||
fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2)
|
||||
|
||||
# Correctness
|
||||
packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
|
||||
ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
|
||||
assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
|
||||
assert packed_sf.shape == ref_packed_sf.shape
|
||||
assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())])
|
||||
|
||||
# Test launch overhead
|
||||
launch_start_t = time.time_ns()
|
||||
get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
|
||||
launch_end_t = time.time_ns()
|
||||
if use_ue8m0:
|
||||
impl, name = get_mn_major_tma_aligned_packed_ue8m0_tensor, 'pack_fp32_into_ue8m0'
|
||||
packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
|
||||
ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
|
||||
assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
|
||||
assert packed_sf.shape == ref_packed_sf.shape
|
||||
assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())])
|
||||
else:
|
||||
impl, name = get_mn_major_tma_aligned_tensor, 'transpose'
|
||||
transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf)
|
||||
tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, 128)
|
||||
if num_groups > 1:
|
||||
assert transposed_sf.size(0) == num_groups
|
||||
assert transposed_sf.stride(0) == tma_aligned_mn * sf_k
|
||||
assert transposed_sf.shape[-2:] == (mn, sf_k)
|
||||
assert transposed_sf.stride()[-2:] == (1, tma_aligned_mn)
|
||||
assert torch.equal(fp32_sf, transposed_sf)
|
||||
|
||||
# Performance
|
||||
t = bench_kineto(lambda: get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf), 'pack_fp32_into_ue8m0')
|
||||
print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}): '
|
||||
f'launch {(launch_end_t - launch_start_t) / 1e3:3.0f} us | {t * 1e6:4.0f} us | '
|
||||
f'{count_bytes(fp32_sf, packed_sf) / 1e9 / t:4.0f} GB/s')
|
||||
try:
|
||||
t = bench_kineto(lambda: impl(fp32_sf), name)
|
||||
except AssertionError as e:
|
||||
# Some cases may fallback to PyTorch impl
|
||||
t = 0
|
||||
print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}): '
|
||||
f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
|
||||
15
tests/test_lazy_init.py
Normal file
15
tests/test_lazy_init.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import deep_gemm
|
||||
|
||||
|
||||
def main(local_rank: int):
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
procs = [mp.Process(target=main, args=(i, ), ) for i in range(8)]
|
||||
for p in procs:
|
||||
p.start()
|
||||
for p in procs:
|
||||
p.join()
|
||||
Reference in New Issue
Block a user