Multiple updates and refactorings (#280)

This commit is contained in:
Zhean Xu
2026-01-16 17:06:52 +08:00
committed by GitHub
parent 3ccf40c53a
commit 0f5f266202
55 changed files with 2706 additions and 891 deletions

View File

@@ -53,11 +53,10 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
return;
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
const auto& [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, disable_ue8m0_cast);
DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
@@ -66,7 +65,9 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
const auto& major_sfb = get_major_type_ab(sfb);
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
// NOTES: Only granularity 128 and FP8 are exposed in the API
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k,
128, 128, major_a, major_b, compiled_dims, epilogue_type);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
@@ -229,8 +230,8 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
);
// Allocate output
constexpr int num_math_warp_groups = 4;
const auto& aligned_max_context_len = align(max_context_len, num_math_warp_groups * block_kv);
constexpr int split_kv = 256;
const auto& aligned_max_context_len = align(max_context_len, split_kv);
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat));
logits = logits.slice(-1, 0, max_context_len);
@@ -239,7 +240,7 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
if (arch_major == 9 or arch_major == 10) {
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, num_math_warp_groups);
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
@@ -251,7 +252,8 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
}
return logits;
}
#endif
#endif
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE

View File

@@ -139,7 +139,7 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::tuple<int, int, int>& recipe,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims) {
// Shape must be `[B, M, K] @ [B, N, K].T`
const auto& major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
@@ -163,15 +163,16 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
return;
// Transform scaling factors
const auto& transformed_sfa = layout::transform_sf_into_required_layout(sfa, m, k, recipe, batch_size, true, false);
const auto& transformed_sfb = layout::transform_sf_into_required_layout(sfb, n, k, recipe, batch_size, false, false);
const auto& [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 10) {
sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
const auto& major_sfb = get_major_type_ab(sfb);
sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims);
}
}
@@ -182,6 +183,7 @@ static void fp8_einsum(const std::string& expr,
const std::optional<torch::Tensor>& c,
const std::tuple<int, int, int>& recipe) {
// Some hardcoded Einstein sum kernels
const auto arch_major = device_runtime->get_arch_major();
if (expr == "bhr,hdr->bhd") {
// Permute dims to satisfy the order of (batch_size, m, n, k)
// (batch_size, m, n, k): (h, b, d, r)
@@ -190,7 +192,7 @@ static void fp8_einsum(const std::string& expr,
const auto& perm_d = d.permute({1, 0, 2});
const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk");
} else if (expr == "bhd,hdr->bhr") {
} else if (expr == "bhd,hdr->bhr" and arch_major == 10) {
// (batch_size, m, n, k): (h, b, r, d)
const auto& perm_a = a.first.permute({1, 0, 2});
const auto& perm_sfa = a.second.permute({1, 0, 2});
@@ -199,7 +201,7 @@ static void fp8_einsum(const std::string& expr,
const auto& perm_d = d.permute({1, 0, 2});
const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk");
} else if (expr == "bhd,bhr->hdr") {
} else if (expr == "bhd,bhr->hdr" and arch_major == 10) {
// (batch_size, m, n, k): (h, d, r, b)
const auto& perm_a = a.first.permute({1, 2, 0});
const auto& perm_sfa = a.second.permute({1, 2, 0});

View File

@@ -46,13 +46,16 @@ static bool early_return(const int& m, const int &n, const int& k,
}
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
std::optional<std::tuple<int, int, int>> recipe,
std::optional<std::tuple<int, int>> recipe_a,
std::optional<std::tuple<int, int>> recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
@@ -65,12 +68,11 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
check_major_type_cd(d);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a.first);
const auto& [n , k_] = get_shape<2>(b.first);
const auto& [m_, n_] = get_shape<2>(d);
const auto arch_major = device_runtime->get_arch_major();
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
// Early return for trivial cases
@@ -78,88 +80,104 @@ static void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
return;
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128));
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast);
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
if (std::get<1>(recipe.value()) == 1) {
const int gran_n = recipe.has_value() ? std::get<1>(recipe.value()) : std::get<0>(recipe_b.value());
if (gran_n == 1) {
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else {
const auto& major_sfb = get_major_type_ab(sfb);
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
}
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b,
major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
}
static void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, compiled_dims, disable_ue8m0_cast);
static void fp8_fp4_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::optional<std::tuple<int, int>>& recipe_a,
const std::optional<std::tuple<int, int>>& recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_fp4_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
}
static void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, compiled_dims, disable_ue8m0_cast);
static void fp8_fp4_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::optional<std::tuple<int, int>>& recipe_a,
const std::optional<std::tuple<int, int>>& recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)},
{b.first.transpose(0, 1), b.second.transpose(0, 1)},
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
}
static void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
d, c, recipe, compiled_dims, disable_ue8m0_cast);
static void fp8_fp4_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::optional<std::tuple<int, int>>& recipe_a,
const std::optional<std::tuple<int, int>>& recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b,
d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
}
static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& m_indices,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& grouped_layout,
std::optional<std::tuple<int, int, int>> recipe,
std::optional<std::tuple<int, int>> recipe_a,
std::optional<std::tuple<int, int>> recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
// Shape must be `[M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
if (fp8_requires_k_major())
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(m_indices.is_contiguous());
DG_HOST_ASSERT(grouped_layout.is_contiguous());
// Type and shape checks
const auto& [m, k] = get_shape<2>(a.first);
const auto& [num_groups, n, k_] = get_shape<3>(b.first);
const auto& [m_, n_] = get_shape<2>(d);
const auto& m__ = static_cast<int>(m_indices.numel());
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
const auto arch_major = device_runtime->get_arch_major();
const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);
// Layout checks
if (use_psum_layout) {
const auto& [num_groups_] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(num_groups == num_groups_);
} else {
const auto& [m__] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(m == m__);
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
}
// D must be N-major
check_major_type_cd(d);
@@ -169,44 +187,48 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
return;
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, num_groups, disable_ue8m0_cast);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
const auto& major_sfb = get_major_type_ab(sfb);
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
DG_HOST_ASSERT(not use_psum_layout);
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout,
num_groups, m, n, k, major_a, major_b, major_sfb, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
num_groups, m, n, k, major_a, major_b, compiled_dims);
sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b,
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
}
static void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& m_indices,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& grouped_layout,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::optional<std::tuple<int, int>>& recipe_a,
const std::optional<std::tuple<int, int>>& recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast,
const bool& use_psum_layout) {
m_grouped_fp8_fp4_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)},
d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt);
}
static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& expected_m,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& expected_m,
std::optional<std::tuple<int, int, int>> recipe,
std::optional<std::tuple<int, int>> recipe_a,
std::optional<std::tuple<int, int>> recipe_b,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
@@ -214,15 +236,14 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
DG_HOST_ASSERT(masked_m.is_contiguous());
// Type and shape checks
const auto& [num_groups, m, k] = get_shape<3>(a.first);
const auto& [num_groups_, n, k_] = get_shape<3>(b.first);
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
const auto& num_groups___ = static_cast<int>(masked_m.numel());
const auto arch_major = device_runtime->get_arch_major();
const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
const auto [num_groups__, m_, n_] = get_shape<3>(d);
const auto num_groups___ = static_cast<int>(masked_m.numel());
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
@@ -230,20 +251,18 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair<torch::Tensor, torch::T
check_major_type_cd(d);
// Transform scaling factors
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
const auto& major_sfb = get_major_type_ab(sfb);
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
@@ -262,9 +281,10 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& [_, m_] = get_shape<2>(a.first);
const auto& [__, n_] = get_shape<2>(b.first);
DG_HOST_ASSERT(m == m_ and n == n_);
const auto& [sum_k_ , m_] = get_shape<2>(a.first);
const auto& [sum_k__, n_] = get_shape<2>(b.first);
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
// Contiguity checks
DG_HOST_ASSERT(a.first.is_contiguous());
@@ -283,8 +303,8 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 10) {
fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
@@ -305,9 +325,7 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& sum_mk = a.first.numel();
const auto& sum_nk = b.first.numel();
int sum_k = 0;
for (const auto& k: ks)
sum_k += k;
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(sum_k) * m);
DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(sum_k) * n);
@@ -334,7 +352,7 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
@@ -404,25 +422,36 @@ static void bf16_gemm_tt(const torch::Tensor& a,
}
static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const torch::Tensor& m_indices,
const std::string& compiled_dims) {
const torch::Tensor& d, const torch::Tensor& grouped_layout,
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
// Shape must be `[M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a);
const auto& major_b = get_major_type_ab(b);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(m_indices.is_contiguous());
DG_HOST_ASSERT(grouped_layout.is_contiguous());
// Type and shape checks
const auto& [m, k] = get_shape<2>(a);
const auto& [num_groups, n, k_] = get_shape<3>(b);
const auto& [m_, n_] = get_shape<2>(d);
const auto& m__ = static_cast<int>(m_indices.numel());
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);
// Layout checks
if (use_psum_layout) {
const auto& [num_groups_] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(num_groups == num_groups_);
} else {
const auto& [m__] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(m == m__);
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
}
// D must be N-major
check_major_type_cd(d);
@@ -434,21 +463,24 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices,
DG_HOST_ASSERT(not use_psum_layout);
sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
num_groups, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10) {
sm100_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices,
num_groups, m, n, k, major_a, major_b, compiled_dims);
sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
num_groups, m, n, k, major_a, major_b, compiled_dims,
use_psum_layout, expected_m_for_psum_layout);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
}
static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const torch::Tensor& m_indices,
const std::string& compiled_dims) {
const torch::Tensor& d, const torch::Tensor& grouped_layout,
const std::string& compiled_dims,
const bool& use_psum_layout) {
m_grouped_bf16_gemm_nt_contiguous(a, b.transpose(1, 2),
d, m_indices, compiled_dims);
d, grouped_layout, compiled_dims, use_psum_layout, std::nullopt);
}
static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b,
@@ -498,9 +530,10 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
const std::string& compiled_dims) {
// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& [_, m_] = get_shape<2>(a);
const auto& [__, n_] = get_shape<2>(b);
DG_HOST_ASSERT(m == m_ and n == n_);
const auto& [sum_k_ , m_] = get_shape<2>(a);
const auto& [sum_k__, n_] = get_shape<2>(b);
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
// Contiguity checks
DG_HOST_ASSERT(a.is_contiguous());
@@ -563,38 +596,50 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8 GEMMs
m.def("fp8_gemm_nt", &fp8_gemm_nt,
// FP8 FP4 GEMMs
m.def("fp8_fp4_gemm_nt", &fp8_fp4_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_nn", &fp8_gemm_nn,
m.def("fp8_fp4_gemm_nn", &fp8_fp4_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_tn", &fp8_gemm_tn,
m.def("fp8_fp4_gemm_tn", &fp8_fp4_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_tt", &fp8_gemm_tt,
m.def("fp8_fp4_gemm_tt", &fp8_fp4_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
m.def("m_grouped_fp8_fp4_gemm_nt_contiguous", &m_grouped_fp8_fp4_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false,
py::arg("use_psum_layout") = false,
py::arg("expected_m_for_psum_layout") = std::nullopt);
m.def("m_grouped_fp8_fp4_gemm_nn_contiguous", &m_grouped_fp8_fp4_gemm_nn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false,
py::arg("use_psum_layout") = false);
m.def("m_grouped_fp8_fp4_gemm_nt_masked", &m_grouped_fp8_fp4_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
@@ -606,6 +651,15 @@ static void register_apis(pybind11::module_& m) {
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
// FP8 GEMM alias names
m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt");
m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn");
m.attr("fp8_gemm_tn") = m.attr("fp8_fp4_gemm_tn");
m.attr("fp8_gemm_tt") = m.attr("fp8_fp4_gemm_tt");
m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous");
m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous");
m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked");
#endif
#if DG_TENSORMAP_COMPATIBLE
@@ -627,11 +681,14 @@ static void register_apis(pybind11::module_& m) {
py::arg("c") = std::nullopt,
py::arg("compiled_dims") = "mn");
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("compiled_dims") = "nk");
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("compiled_dims") = "nk",
py::arg("use_psum_layout") = false,
py::arg("expected_m_for_psum_layout") = std::nullopt);
m.def("m_grouped_bf16_gemm_nn_contiguous", &m_grouped_bf16_gemm_nn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("compiled_dims") = "nk");
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("compiled_dims") = "nk",
py::arg("use_psum_layout") = false);
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("compiled_dims") = "nk");

View File

@@ -0,0 +1,70 @@
#pragma once
#include "../utils/compatibility.hpp"
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp"
#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp"
#endif
namespace deep_gemm::hyperconnection {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const torch::Tensor& sqr_sum,
const std::optional<int>& num_splits) {
// A and B must be K-major, D must be N-major
DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K);
DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K);
check_major_type_cd(d);
// S must be contiguous
DG_HOST_ASSERT(sqr_sum.is_contiguous());
// Type and shape checks
const auto& [m, k ] = get_shape<2>(a);
const auto& [n, k_] = get_shape<2>(b);
if (num_splits.has_value()) {
const auto& [num_splits_, m_, n_] = get_shape<3>(d);
const auto& [num_splits__, m__] = get_shape<2>(sqr_sum);
DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1);
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
} else {
const auto& [m_, n_] = get_shape<2>(d);
const auto& [m__] = get_shape<1>(sqr_sum);
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
}
DG_HOST_ASSERT(n > 0 and k > 0);
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(b.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat);
// Do nothing if the problem is empty
if (m == 0)
return;
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1);
} else if (arch_major == 10) {
sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
}
#endif
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"),
py::arg("num_splits") = std::nullopt);
#endif
}
} // namespace deep_gemm::hyperconnection

View File

@@ -1,20 +1,34 @@
#pragma once
#include "../utils/layout.hpp"
#include "../utils/compatibility.hpp"
#if DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/smxx_layout.hpp"
#endif
namespace deep_gemm::layout {
#if DG_TENSORMAP_COMPATIBLE
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
const int& mn, const int& k,
const std::tuple<int, int, int>& recipe,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::optional<std::tuple<int, int>>& recipe_ab,
const std::optional<int>& num_groups,
const bool& is_sfa,
const bool& disable_ue8m0_cast) {
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
const auto& gran_k = std::get<2>(recipe);
const auto& arch_major = device_runtime->get_arch_major();
int gran_mn, gran_k;
if (recipe.has_value()) {
DG_HOST_ASSERT(not recipe_ab.has_value());
gran_mn = is_sfa ? std::get<0>(recipe.value()) : std::get<1>(recipe.value());
gran_k = std::get<2>(recipe.value());
} else {
DG_HOST_ASSERT(recipe_ab.has_value());
std::tie(gran_mn, gran_k) = recipe_ab.value();
}
// Pre-transform checks
check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);
@@ -22,30 +36,44 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
return get_mn_major_tma_aligned_tensor(sf);
// (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) {
DG_HOST_ASSERT(not disable_ue8m0_cast);
return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
}
// (FP32, 128, 128) on SM90: no need to transform, check SFB requirements
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);
// (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) {
// (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major
if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) {
DG_HOST_ASSERT(not disable_ue8m0_cast);
const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128));
const auto& broadcasted = gran_mn == 1 ? sf :
sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn));
return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
}
// (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10)
// (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10)
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
DG_HOST_UNREACHABLE("Unknown SF transformation");
}
static std::tuple<torch::Tensor, torch::Tensor, int, int> transform_sf_pair_into_required_layout(
const torch::Tensor& sfa, const torch::Tensor& sfb,
const int& m, const int& n, const int& k,
std::optional<std::tuple<int, int, int>>& recipe,
const std::optional<std::tuple<int, int>>& recipe_a,
const std::optional<std::tuple<int, int>>& recipe_b,
const std::optional<int>& num_groups_a,
const std::optional<int>& num_groups_b,
const bool& disable_ue8m0_cast = false) {
DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value());
if (not recipe_a.has_value() and not recipe.has_value())
recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type());
const auto transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe, recipe_a, num_groups_a, true, disable_ue8m0_cast);
const auto transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe, recipe_b, num_groups_b, false, disable_ue8m0_cast);
const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value());
const int gran_k_b = recipe_b.has_value() ? std::get<1>(recipe_b.value()) : std::get<2>(recipe.value());
return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b);
}
static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
const std::vector<int>& ks,
const torch::Tensor& ks_tensor,
@@ -69,17 +97,24 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
DG_HOST_UNREACHABLE("Unknown cases");
}
#endif
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
py::arg("sf"), py::arg("mn"), py::arg("k"),
py::arg("recipe") = std::nullopt, py::arg("recipe_ab") = std::nullopt,
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
py::arg("disable_ue8m0_cast") = false);
m.def("get_tma_aligned_size", &get_tma_aligned_size);
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
#endif
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
}
} // namespace deep_gemm::layout

View File

@@ -1,6 +1,8 @@
#pragma once
#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
namespace deep_gemm::runtime {
@@ -18,10 +20,11 @@ static void register_apis(pybind11::module_& m) {
m.def("get_tc_util", [&]() {
return device_runtime->get_tc_util();
});
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
#if DG_TENSORMAP_COMPATIBLE
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
KernelRuntime::prepare_init(cuda_home_path_by_python);
#endif
});
}

View File

@@ -15,6 +15,10 @@
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
// Hyperconnection kernels
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh>
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh>
// Layout kernels
#include <deep_gemm/impls/smxx_layout.cuh>
#include <deep_gemm/impls/smxx_clean_logits.cuh>

View File

@@ -24,6 +24,7 @@ public:
static std::filesystem::path library_include_path;
static std::filesystem::path cuda_home;
static std::string library_version;
static std::filesystem::path cuobjdump_path;
static std::string get_library_version() {
std::vector<char> buffer;
@@ -45,6 +46,7 @@ public:
Compiler::library_include_path = Compiler::library_root_path / "include";
Compiler::cuda_home = cuda_home_path_by_python;
Compiler::library_version = get_library_version();
Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump";
}
std::string signature, flags;
@@ -56,6 +58,7 @@ public:
DG_HOST_ASSERT(not library_include_path.empty());
DG_HOST_ASSERT(not cuda_home.empty());
DG_HOST_ASSERT(not library_version.empty());
DG_HOST_ASSERT(not cuobjdump_path.empty());
// Cache settings
cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
@@ -108,25 +111,57 @@ public:
// Compile into a temporary CUBIN
const auto tmp_cubin_path = get_tmp_file_path();
compile(code, dir_path, tmp_cubin_path);
if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_PTX")) {
// Dump PTX if needed
const auto tmp_ptx_path = get_tmp_file_path();
compile(code, dir_path, tmp_cubin_path, tmp_ptx_path);
// Replace into the cache directory
std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx");
} else {
compile(code, dir_path, tmp_cubin_path);
}
// Replace into the cache directory
make_dirs(dir_path);
std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin");
const auto cubin_path = dir_path / "kernel.cubin";
std::filesystem::rename(tmp_cubin_path, cubin_path);
// Disassemble if needed
if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_SASS")) {
// Dump into a temporary SASS
const auto tmp_sass_path = get_tmp_file_path();
disassemble(cubin_path, tmp_sass_path);
// Replace into the current directory
std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass");
}
// Put into the runtime cache
const auto& runtime = kernel_runtime_cache->get(dir_path);
const auto runtime = kernel_runtime_cache->get(dir_path);
DG_HOST_ASSERT(runtime != nullptr);
return runtime;
}
virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0;
static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) {
// Disassemble the CUBIN file to SASS
const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str());
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
printf("Running cuobjdump command: %s\n", command.c_str());
const auto [return_code, output] = call_external_command(command);
if (return_code != 0) {
printf("cuobjdump failed: %s\n", output.c_str());
DG_HOST_ASSERT(false and "cuobjdump failed");
}
}
virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional<std::filesystem::path> &ptx_path = std::nullopt) const = 0;
};
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path);
class NVCCCompiler final: public Compiler {
std::filesystem::path nvcc_path;
@@ -164,17 +199,19 @@ public:
const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9);
flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
"-O3 --expt-relaxed-constexpr --expt-extended-lambda",
flags, library_include_path.c_str(), arch);
}
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
void compile(const std::string &code, const std::filesystem::path& dir_path,
const std::filesystem::path &cubin_path,
const std::optional<std::filesystem::path> &ptx_path) const override {
// Write the code into the cache directory
const auto& code_path = dir_path / "kernel.cu";
put(code_path, code);
// Compile
const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
printf("Running NVCC command: %s\n", command.c_str());
const auto& [return_code, output] = call_external_command(command);
@@ -183,6 +220,18 @@ public:
DG_HOST_ASSERT(false and "NVCC compilation failed");
}
// Compile to PTX if needed
if (ptx_path.has_value()) {
const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags);
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
printf("Running NVCC PTX command: %s\n", ptx_command.c_str());
const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command);
if (ptx_return_code != 0) {
printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str());
DG_HOST_ASSERT(false and "NVCC PTX compilation failed");
}
}
// Check local memory usage
if (get_env("DG_JIT_PTXAS_CHECK", 0))
DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)")));
@@ -219,11 +268,13 @@ public:
// Override the compiler flags
// Only NVRTC >= 12.9 supports arch-specific family suffix
const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9);
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}",
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128",
flags, include_dirs, arch, pch_flags);
}
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
void compile(const std::string &code, const std::filesystem::path& dir_path,
const std::filesystem::path &cubin_path,
const std::optional<std::filesystem::path> &ptx_path) const override {
// Write the code into the cache directory
const auto& code_path = dir_path / "kernel.cu";
put(code_path, code);
@@ -266,6 +317,17 @@ public:
}
}
if (ptx_path.has_value()) {
// Get PTX size and data if needed
size_t ptx_size;
DG_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
std::string ptx_data(ptx_size, '\0');
DG_NVRTC_CHECK(nvrtcGetPTX(program, ptx_data.data()));
// Write into the file system
put(ptx_path.value(), ptx_data);
}
// Get CUBIN size and data
size_t cubin_size;
DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size));

View File

@@ -17,19 +17,7 @@ class DeviceRuntime {
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
public:
#if TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3)
// For PyTorch 2.3+, share the PyTorch cuBLASLt handle
DeviceRuntime() = default;
static cublasLtHandle_t get_cublaslt_handle() {
return at::cuda::getCurrentCUDABlasLtHandle();
}
static torch::Tensor get_cublaslt_workspace() {
return torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA));
}
#else
// Otherwise, create the cuBLASLt handle ourselves
// Create the cuBLASLt handle ourselves
cublasLtHandle_t cublaslt_handle{};
std::shared_ptr<torch::Tensor> cublaslt_workspace;
@@ -49,7 +37,6 @@ public:
torch::Tensor get_cublaslt_workspace() const {
return *cublaslt_workspace;
}
#endif
std::shared_ptr<cudaDeviceProp> get_prop() {
if (cached_prop == nullptr) {

View File

@@ -40,10 +40,7 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx);
#if DG_TENSORMAP_COMPATIBLE
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled);
#endif
#if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API)
@@ -166,7 +163,6 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle&
void *ptr_args[] = { &args... };
return lazy_cuLaunchKernelEx(&config, kernel, ptr_args, nullptr);
}
#endif
} // namespace deep_gemm

View File

@@ -59,7 +59,8 @@ struct GemmConfig {
// Templated configs
GemmType gemm_type;
KernelType kernel_type;
at::ScalarType ab_dtype, cd_dtype;
MmaKind mma_kind;
at::ScalarType a_dtype, b_dtype, cd_dtype;
cute::UMMA::Major major_a;
cute::UMMA::Major major_b;
bool with_accumulation;
@@ -99,9 +100,9 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& num_stages, const MulticastConfig& multicast_config) {
const int& ab_elem_size = static_cast<int>(c10::elementSize(ab_dtype));
const int& ab_elem_size = static_cast<int>(get_element_size(mma_kind));
const int& cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));
const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m);
@@ -119,7 +120,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne
// SF shared memory
const auto& [smem_sfa_per_stage, smem_sfb_per_stage] =
ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype);
ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype);
const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k);
// M-barriers and tensor memory pointers
@@ -151,21 +152,35 @@ template <typename ArchSpec>
static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k, const int& num_groups,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const at::ScalarType& a_dtype, const at::ScalarType& b_dtype,
const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16);
const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4);
if (mma_kind == MmaKind::BF16) {
DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16);
} else {
DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4);
DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4);
}
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
// Select M/N block sizes
auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m);
if (gemm_type == GemmType::MGroupedContiguous)
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance
block_ms = std::vector{64, 128};
const auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype);
if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout)
block_ms = std::vector{64, 128}; // Exclude 256 for performance
auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype);
// NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B
// TODO: Optimize it
if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN)
block_ms = std::vector{128};
if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN)
block_ns = std::vector{128};
// K block size is selected in a fixed manner
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
const auto& block_k = (mma_kind == MmaKind::BF16 ? 64 : 128);
// Some util functions
const auto& get_num_blocks = [=](const int& block_m, const int& block_n) {
@@ -186,7 +201,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
for (const auto& block_n: block_ns) {
const int& num_waves = get_num_waves(block_m, block_n);
const auto& last_util = get_last_wave_util(block_m, block_n);
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, m, n, k, block_m, block_n, block_k))
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, mma_kind, cd_dtype, m, n, k, block_m, block_n, block_k))
continue;
bool success = false;
@@ -218,8 +233,16 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
// Decide the number of TMA multicasts and whether broadcast on A
MulticastConfig best_multicast_config = {1, false};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
// NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B
// TODO: Optimize it
if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN)
is_legal_on_a = false;
if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN)
is_legal_on_b = false;
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
bool order[2] = {false, true};
if (best_block_m > best_block_n)
@@ -236,14 +259,14 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
int best_num_stages = 0;
SharedMemoryConfig best_smem_config;
for (int num_stages = 32; num_stages > 0; -- num_stages) {
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
if (not ArchSpec::is_num_stages_legal(mma_kind, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
continue;
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
m, n, k,
best_block_m, best_block_n, block_k,
major_a, major_b,
ab_dtype, cd_dtype,
mma_kind, cd_dtype,
num_stages, best_multicast_config);
if (best_smem_config.smem_size <= smem_capacity) {
best_num_stages = num_stages;
@@ -255,7 +278,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
// Recompute the minimal number of SMs required
// NOTES: less L2 cache usage and less GPU frequency drop
int num_min_sms = num_sms;
if (ArchSpec::should_minimize_num_sms()) {
if (get_env<int>("DG_JIT_MINIMIZE_NUM_SMS", 0)) {
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
DG_HOST_ASSERT(num_min_sms <= num_sms);
@@ -264,7 +287,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
const auto& config = GemmConfig {
.gemm_type = gemm_type,
.kernel_type = kernel_type,
.ab_dtype = ab_dtype,
.mma_kind = mma_kind,
.a_dtype = a_dtype,
.b_dtype = b_dtype,
.cd_dtype = cd_dtype,
.major_a = major_a,
.major_b = major_b,
@@ -284,21 +309,22 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
// Only SM100 BF16 kernels support tensor core control
if (config.tc_util < 100)
DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and ab_dtype == torch::kBFloat16);
DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and mma_kind == MmaKind::BF16);
// Print configs for the first time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b,
ab_dtype, cd_dtype, with_accumulation, num_sms);
mma_kind, a_dtype, b_dtype, cd_dtype, with_accumulation, num_sms);
static std::set<decltype(key)> printed;
if (printed.count(key) == 0) {
printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
"A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, "
"A major: %d, B major: %d, MMA kind: %d, A dtype: %s, B dtype: %s, CD dtype: %s, accumulation: %d, "
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
"swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n",
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
static_cast<int>(major_a), static_cast<int>(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype),
static_cast<int>(major_a), static_cast<int>(major_b), static_cast<int>(mma_kind),
c10::toString(a_dtype), c10::toString(b_dtype), c10::toString(cd_dtype),
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,
static_cast<int>(best_multicast_config.is_multicast_on_a),

View File

@@ -53,18 +53,18 @@ struct SM100ArchSpec {
}
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
const int& block_m, const int& block_n, const MmaKind& mma_kind) {
constexpr int num_utccp_aligned_elems = 128;
switch (ab_dtype) {
case torch::kBFloat16: return {0, 0};
case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
switch (mma_kind) {
case MmaKind::BF16: return {0, 0};
case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
default: DG_HOST_UNREACHABLE("Unknown dtype");
}
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
// Layout A/D does not support `block_n % 16 != 0`
@@ -82,7 +82,7 @@ struct SM100ArchSpec {
// Check tensor memory validity
int sf_block_m = 0, sf_block_n = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
sf_block_m = sf_block_m_, sf_block_n = sf_block_n_;
}
if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512)
@@ -90,19 +90,15 @@ struct SM100ArchSpec {
// NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
// since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA
return major_b == cute::UMMA::Major::K or (block_n * c10::elementSize(ab_dtype)) % 64 == 0;
return major_b == cute::UMMA::Major::K or (block_n * get_element_size(mma_kind)) % 64 == 0;
}
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& num_stages,
const int& block_m, const int& block_n, const int& block_k) {
return true;
}
static bool should_minimize_num_sms() {
return true;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
@@ -129,14 +125,14 @@ struct SM100ArchSpec {
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
const int& block_m, const int& block_n, const int& block_k,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
if (ab_dtype == torch::kBFloat16)
const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
if (mma_kind == MmaKind::BF16)
return {0, 0};
int smem_sfa_per_stage = 0;
int smem_sfb_per_stage = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
smem_sfa_per_stage = sf_block_m * 4;
smem_sfb_per_stage = sf_block_n * 4;
} else {

View File

@@ -60,7 +60,7 @@ struct SM90ArchSpec {
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
// SM90 FP32 output does not support `block_m == 256`
@@ -89,19 +89,15 @@ struct SM90ArchSpec {
return block_m <= 128 or block_n <= 128;
}
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& num_stages,
const int& block_m, const int& block_n, const int& block_k) {
// Unrolling both stages and `num_former_iters` will cause large code size
if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
if (mma_kind == MmaKind::MXFP8FP4 and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
return num_stages <= 4;
return true;
}
static bool should_minimize_num_sms() {
return true;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
@@ -134,8 +130,8 @@ struct SM90ArchSpec {
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
const int& block_m, const int& block_n, const int& block_k,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
if (ab_dtype == torch::kBFloat16)
const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
if (mma_kind == MmaKind::BF16)
return {0, 0};
// NOTES: 128 is for 2D TMA alignment requirement

View File

@@ -37,11 +37,12 @@ static std::string to_string(const cute::UMMA::Major& major) {
static std::string to_string(const GemmType& type) {
switch (type) {
case GemmType::Normal: return "GemmType::Normal";
case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous";
case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked";
case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous";
case GemmType::Batched: return "GemmType::Batched";
case GemmType::Normal: return "GemmType::Normal";
case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous";
case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked";
case GemmType::MGroupedContiguousWithPsumLayout: return "GemmType::MGroupedContiguousWithPsumLayout";
case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous";
case GemmType::Batched: return "GemmType::Batched";
}
DG_HOST_UNREACHABLE("Unknown GEMM type");
}
@@ -51,6 +52,8 @@ static std::string to_string(const at::ScalarType& dtype) {
case torch::kInt: return "int";
case torch::kFloat: return "float";
case torch::kBFloat16: return "cutlass::bfloat16_t";
case torch::kFloat8_e4m3fn: return "cutlass::float_e4m3_t";
case kPackedFP4: return "cutlass::detail::float_e2m1_unpacksmem_t";
default: DG_HOST_UNREACHABLE("Unsupported dtype");
}
}
@@ -65,6 +68,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
case kPackedFP4: return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;
default: DG_HOST_UNREACHABLE("Unsupported dtype");
}
}
@@ -98,6 +102,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
if (swizzle_mode != 0)
smem_inner_dim = swizzle_mode / elem_size;
// Inner dim must be a multiple of 64B for .b4x16_p64
if (t.scalar_type() == kPackedFP4)
DG_HOST_ASSERT(gmem_inner_dim % 128 == 0);
CUtensorMap tensor_map;
const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
@@ -126,6 +134,10 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
if (swizzle_mode != 0)
smem_dim_0 = swizzle_mode / elem_size;
// Inner dim must be a multiple of 64B for .b4x16_p64
if (t.scalar_type() == kPackedFP4)
DG_HOST_ASSERT(gmem_dim_0 % 128 == 0);
CUtensorMap tensor_map;
const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2),};
const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
@@ -204,7 +216,7 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
const torch::Tensor& t,
int shape_mn, int shape_k,
const int& block_mn, const int& block_k,
const int& block_mn, const int& gran_k,
const int& num_groups,
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
@@ -215,7 +227,7 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));
return make_tma_2d_desc(t,
shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
shape_mn, ceil_div(shape_k, gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups,
block_mn, 1,
shape_mn,
swizzle_mode, swizzle_base,

View File

@@ -79,11 +79,11 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 64);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
torch::kBFloat16, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
@@ -104,7 +104,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,
@@ -124,16 +124,25 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const torch::Tensor& m_indices,
const torch::Tensor& grouped_layout,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 64);
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
// NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
// Otherwise, treat the contiguous layout as a whole.
const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m;
const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1;
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedContiguous, KernelType::KernelNoSF,
gemm_type, KernelType::KernelNoSF,
// NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole
m, n, k, 1, major_a, major_b,
torch::kBFloat16, d.scalar_type(), false,
m_for_config, n, k, num_groups_for_config, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
@@ -154,14 +163,14 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = m_indices.data_ptr(),
.grouped_layout = grouped_layout.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd
@@ -179,11 +188,11 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 64);
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedMasked, KernelType::KernelNoSF,
expected_m, n, k, num_groups, major_a, major_b,
torch::kBFloat16, d.scalar_type(), false,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
@@ -204,7 +213,7 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_config = config,
@@ -243,7 +252,8 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a,
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::KGroupedContiguous, KernelType::KernelNoSF,
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
torch::kBFloat16, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Create tensor descriptors
@@ -290,7 +300,8 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K,
torch::kBFloat16, tensor_d.scalar_type(), false,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
@@ -337,7 +348,8 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN,
torch::kBFloat16, tensor_d.scalar_type(), false,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);

View File

@@ -15,10 +15,11 @@
namespace deep_gemm {
class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntime> {
class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8FP4Gemm1D1DRuntime> {
public:
struct Args {
int m, n, k, num_groups;
int gran_k_a, gran_k_b;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
@@ -41,6 +42,7 @@ using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
{}, {},
{}, {},
{}, {}, {},
{}, {}, {},
@@ -50,12 +52,14 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{},
{}, {},
{}, {}, {},
{}
>);
}};
)",
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
args.gran_k_a, args.gran_k_b,
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
@@ -64,7 +68,8 @@ static void __instantiate_kernel() {{
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation,
to_string(args.gemm_config.a_dtype), to_string(args.gemm_config.b_dtype), to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
@@ -78,19 +83,20 @@ static void __instantiate_kernel() {{
}
};
static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
const auto& aligned_k = align(k, 128);
static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto& cd = c.value_or(d);
@@ -110,14 +116,16 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
config.block_m, gran_k_a, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, 1, 0);
config.block_n, gran_k_b, 1, 0);
// Launch
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = 1,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_config = config,
@@ -131,24 +139,33 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& grouped_layout,
const int& num_groups, const int& m, const int& n, const int& k,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
// NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
// Otherwise, treat the contiguous layout as a whole.
const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m;
const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1;
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
// NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
gemm_type, KernelType::Kernel1D1D,
m_for_config, n, k, num_groups_for_config, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
// Create tensor descriptors
@@ -168,45 +185,48 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
config.block_m, gran_k_a, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, num_groups, 0);
config.block_n, gran_k_b, num_groups, 0);
// Launch kernel
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.grouped_layout = m_indices.data_ptr(),
.grouped_layout = grouped_layout.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedMasked, KernelType::Kernel1D1D,
expected_m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
// Create tensor descriptors
@@ -226,14 +246,16 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, num_groups, 0);
config.block_m, gran_k_a, num_groups, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, num_groups, 0);
config.block_n, gran_k_b, num_groups, 0);
// Launch kernel
const SM100FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
@@ -247,19 +269,19 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n,
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n,
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
int sum_k = 0, sum_sf_k = 0;
@@ -274,7 +296,8 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Create tensor descriptors
@@ -299,9 +322,11 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
config.block_n, config.block_k, 1, 0);
// Launch kernel
const SM100FP8Gemm1D1DRuntime::Args& args = {
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.gran_k_a = 128,
.gran_k_b = 128,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
@@ -315,9 +340,9 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -330,7 +355,8 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Batched, KernelType::Kernel1D1D,
m, n, k, batch_size, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
@@ -364,9 +390,11 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
config.block_n, config.block_k, batch_size, 0);
// Launch
const SM100FP8Gemm1D1DRuntime::Args& args = {
const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = batch_size,
.gran_k_a = 128,
.gran_k_b = 128,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
@@ -380,9 +408,9 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
SM100FP8Gemm1D1DRuntime::launch(runtime, args);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,149 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM100BF16HCPrenormGemmRuntime final: public LaunchRuntime<SM100BF16HCPrenormGemmRuntime> {
public:
struct Args {
int m, n, k;
int block_m, block_n, block_k;
int num_splits;
int swizzle_cd_mode;
int num_stages;
int num_mma_threads, num_cast_and_reduce_threads;
LaunchArgs launch_args;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
float* sqr_sum;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_tf32_hc_prenorm_gemm_impl<
{}, {},
{}, {}, {},
{},
{},
{},
{}, {}
>);
}};
)",
args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.num_splits,
args.swizzle_cd_mode,
args.num_stages,
args.num_mma_threads, args.num_cast_and_reduce_threads);
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum));
}
};
static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const torch::Tensor& sqr_sum,
const int& m, const int& n, const int& k,
const int& num_splits) {
constexpr int block_m = 64;
constexpr int block_k = 64;
constexpr int num_mma_threads = 128;
constexpr int num_cast_and_reduce_threads = 128;
const int block_n = align(n, 16);
DG_HOST_ASSERT(n <= block_n);
DG_HOST_ASSERT(n <= 128 and n % 8 == 0);
DG_HOST_ASSERT(k % block_k == 0);
const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
block_m, block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, a.element_size()), 0,
true);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
block_n, block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, b.element_size()), 0,
true);
const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
block_m, block_n,
static_cast<int>(d.stride(-2)), 1,
swizzle_cd_mode)
: make_tma_3d_desc(d, n, m, num_splits,
block_n, block_m, 1,
static_cast<int>(d.stride(-2)),
static_cast<int>(d.stride(-3)),
swizzle_cd_mode);
// Calculate stages
int num_stages = 12, smem_size = 0;
while (num_stages > 0) {
const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(nv_bfloat16));
const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(float));
const int smem_cd = block_m * swizzle_cd_mode;
const int smem_barriers = (num_stages * 4 + 1) * 8;
const int smem_tmem_ptr = 4;
smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages +
smem_cd + smem_barriers + smem_tmem_ptr;
if (smem_size <= SM100ArchSpec::smem_capacity)
break;
-- num_stages;
}
DG_HOST_ASSERT(num_stages > 0);
// Print configs
if (get_env("DG_JIT_DEBUG", 0)) {
printf("M: %d, N: %d, K: %d -> "
"block M: %d, block N: %d, block K: %d, split K: %d"
"stages: %d, shared memory: %d, swizzle CD: %d\n",
m, n, k, block_m, block_n, block_k, num_splits,
num_stages, smem_size, swizzle_cd_mode);
}
// Launch
const SM100BF16HCPrenormGemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.block_m = block_m, .block_n = block_n, .block_k = block_k,
.num_splits = num_splits,
.swizzle_cd_mode = swizzle_cd_mode,
.num_stages = num_stages,
.num_mma_threads = num_mma_threads,
.num_cast_and_reduce_threads = num_cast_and_reduce_threads,
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.sqr_sum = sqr_sum.data_ptr<float>()
};
const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code);
SM100BF16HCPrenormGemmRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -79,13 +79,11 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(not c.has_value());
const auto& aligned_k = align(k, 64);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
torch::kBFloat16, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
@@ -107,7 +105,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = aligned_k,
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_config = config,
@@ -138,7 +136,8 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedContiguous, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
torch::kBFloat16, d.scalar_type(), false,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
// Requires no TMA splits
@@ -192,7 +191,8 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedMasked, KernelType::KernelNoSF,
expected_m, n, k, num_groups, major_a, major_b,
torch::kBFloat16, d.scalar_type(), false,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
// Requires no TMA splits
@@ -253,7 +253,8 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::KGroupedContiguous, KernelType::KernelNoSF,
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
torch::kBFloat16, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Create tensor descriptors
@@ -300,7 +301,8 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K,
torch::kBFloat16, tensor_d.scalar_type(), false,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
@@ -346,7 +348,8 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN,
torch::kBFloat16, tensor_d.scalar_type(), false,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);

View File

@@ -88,7 +88,8 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
@@ -138,7 +139,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}
static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
@@ -156,7 +157,8 @@ static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Te
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
m, n, max_k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits

View File

@@ -87,11 +87,11 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::Kernel1D2D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
@@ -118,7 +118,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = aligned_k,
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
@@ -148,11 +148,11 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& aligned_k = align(k, 128);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
m, n, k, 1, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
// Requires no TMA splits
@@ -179,7 +179,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = aligned_k,
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
@@ -207,14 +207,14 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
const std::string& compiled_dims) {
const auto& aligned_k = align(k, 128);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
expected_m, n, k, num_groups, major_a, major_b,
torch::kFloat8_e4m3fn, d.scalar_type(), false,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
// Requires no TMA splits
@@ -241,7 +241,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = aligned_k,
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
@@ -261,4 +261,71 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& batch_size, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
const std::string& compiled_dims) {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Batched, KernelType::Kernel1D2D,
m, n, k, batch_size, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
const auto& tensor_map_a = make_tma_3d_desc(a, k, m, batch_size,
config.block_k, load_block_m, 1,
a.stride(1),
a.stride(0),
config.smem_config.swizzle_a_mode);
const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
const auto& tensor_map_b = make_tma_3d_desc(b, k, n, batch_size,
config.block_k, load_block_n, 1,
b.stride(1),
b.stride(0),
config.smem_config.swizzle_b_mode);
const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m);
const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n);
const auto& tensor_map_d = make_tma_3d_desc(d, n, m, batch_size,
store_block_n, store_block_m, 1,
d.stride(1), d.stride(0),
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, batch_size, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = k,
.num_groups = batch_size,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.sfb = sfb.data_ptr(),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,152 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM90BF16HCPrenormGemmRuntime final: public LaunchRuntime<SM90BF16HCPrenormGemmRuntime> {
public:
struct Args {
int m, n, k;
int block_m, block_n, block_k;
int num_splits;
int swizzle_cd_mode;
int num_stages;
int num_math_threads, num_tma_threads;
LaunchArgs launch_args;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
float* sqr_sum;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_tf32_hc_prenorm_gemm_impl<
{}, {},
{}, {}, {},
{},
{},
{},
{}, {}
>);
}};
)",
args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.num_splits,
args.swizzle_cd_mode,
args.num_stages,
args.num_math_threads, args.num_tma_threads);
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum));
}
};
static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& d,
const torch::Tensor& sqr_sum,
const int& m, const int& n, const int& k,
const int& num_splits) {
constexpr int block_m = 64;
constexpr int block_k = 64;
constexpr int num_math_threads = 128;
constexpr int num_tma_threads = 128;
constexpr int num_threads = num_math_threads + num_tma_threads;
const int block_n = align(n, 16);
DG_HOST_ASSERT(n <= block_n);
// Only support small N for now
DG_HOST_ASSERT(n <= 32 and n % 8 == 0);
DG_HOST_ASSERT(k % block_k == 0);
const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
block_m, block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, a.element_size()), 0,
true);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
block_n, block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, b.element_size()), 0,
true);
const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
block_m, block_n,
static_cast<int>(d.stride(-2)), 1,
swizzle_cd_mode)
: make_tma_3d_desc(d, n, m, num_splits,
block_n, block_m, 1,
static_cast<int>(d.stride(-2)),
static_cast<int>(d.stride(-3)),
swizzle_cd_mode);
// Calculate stages
int num_stages = 12, smem_size = 0;
while (num_stages > 0) {
const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(nv_bfloat16));
const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(float));
const int smem_cd = block_m * swizzle_cd_mode;
const int smem_barriers = num_stages * 2 * 8;
smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages +
smem_cd + smem_barriers;
if (smem_size <= SM90ArchSpec::smem_capacity)
break;
-- num_stages;
}
DG_HOST_ASSERT(num_stages > 0);
// Print configs
if (get_env("DG_JIT_DEBUG", 0)) {
printf("M: %d, N: %d, K: %d -> "
"block M: %d, block N: %d, block K: %d, split K: %d"
"stages: %d, shared memory: %d, swizzle CD: %d\n",
m, n, k, block_m, block_n, block_k, num_splits,
num_stages, smem_size, swizzle_cd_mode);
}
smem_size = SM90ArchSpec::smem_capacity;
// Launch
const SM90BF16HCPrenormGemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.block_m = block_m, .block_n = block_n, .block_k = block_k,
.num_splits = num_splits,
.swizzle_cd_mode = swizzle_cd_mode,
.num_stages = num_stages,
.num_math_threads = num_math_threads,
.num_tma_threads = num_tma_threads,
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.sqr_sum = sqr_sum.data_ptr<float>()
};
const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code);
SM90BF16HCPrenormGemmRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -37,7 +37,6 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
const bool& accumulate) {
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
cudaDataType_t scale_type = CUDA_R_32F;
const int& math_sms = device_runtime->get_num_sms();
// Operation description
cublasLtMatmulDesc_t desc;
@@ -45,9 +44,13 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE
const int& math_sms = device_runtime->get_num_sms();
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
#if DG_FP8_COMPATIBLE
#endif
#if DG_FP8_COMPATIBLE and DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE
bool fp8_fast_accumulate = false;
if (a.scalar_type() == torch::kFloat8_e4m3fn)
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));

View File

@@ -174,13 +174,13 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
const int& logits_stride,
const int& block_table_stride,
const int& num_sms,
const int& num_math_warp_groups) {
const int& split_kv) {
const int num_specialized_threads = 128;
const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64);
const int num_math_warp_groups = split_kv / mma_m;
const int num_math_threads = num_math_warp_groups * 128;
const int num_extra_threads = device_runtime->get_arch_major() == 10 ? 128 : 0;
const int num_q_stages = 3, num_kv_stages = 3;
const int split_kv = num_math_warp_groups * block_kv;
DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0);
const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3);
DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);
// Construct TMAs
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
@@ -196,23 +196,39 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
next_n * num_heads, 1, next_n * num_heads, 0);
// Calculate shared memory size
const int swizzle_alignment = head_dim * 8;
int smem_size = 0;
if (device_runtime->get_arch_major() == 9) {
const int swizzle_alignment = head_dim * 8;
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
// Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
const int smem_tmem_ptr = 4;
// Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
const int smem_tmem_ptr = 4;
const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
} else {
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
const int smem_kv_size_per_stage = split_kv * head_dim * static_cast<int>(kv_cache.element_size());
const int smem_kv_scale_size_per_stage = split_kv * static_cast<int>(kv_cache_scales.element_size());
const int smem_weight_size_per_stage = next_n * num_heads * static_cast<int>(weights.element_size());
const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8;
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
const int smem_tmem_ptr = 4;
smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) +
num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) +
smem_barriers + smem_umma_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
}
// Launch
const SMXXFP8PagedMQALogitsRuntime::Args& args = {
@@ -238,7 +254,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
.num_specialized_threads = num_specialized_threads,
.num_math_threads = num_math_threads,
.launch_args = LaunchArgs(num_sms,
num_specialized_threads + num_math_threads + num_extra_threads,
num_specialized_threads + num_math_threads,
smem_size)
};
const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args);

View File

@@ -3,6 +3,7 @@
#include "apis/attention.hpp"
#include "apis/einsum.hpp"
#include "apis/hyperconnection.hpp"
#include "apis/gemm.hpp"
#include "apis/layout.hpp"
#include "apis/runtime.hpp"
@@ -15,8 +16,10 @@
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepGEMM C++ library";
// TODO: make SM80 incompatible issues raise errors
deep_gemm::attention::register_apis(m);
deep_gemm::einsum::register_apis(m);
deep_gemm::hyperconnection::register_apis(m);
deep_gemm::gemm::register_apis(m);
deep_gemm::layout::register_apis(m);
deep_gemm::runtime::register_apis(m);

View File

@@ -2,9 +2,16 @@
#include <torch/version.h>
#include <cuda.h>
#include <cuda_runtime.h>
// `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1
#define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1))
// `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1
#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010)
#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010)
// `cublasGetErrorString` is supported since CUDA Runtime API 11.4.2
#define DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE (CUDART_VERSION >= 11042)
// `CUBLASLT_MATMUL_DESC_FAST_ACCUM` and `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` are supported since CUDA Runtime API 11.8
#define DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE (CUDART_VERSION >= 11080)

View File

@@ -5,6 +5,8 @@
#include <string>
#include <sstream>
#include "compatibility.hpp"
namespace deep_gemm {
class DGException final : public std::exception {
@@ -74,6 +76,25 @@ do { \
#endif
#ifndef DG_CUBLASLT_CHECK
#if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE
inline const char* cublasGetStatusString(cublasStatus_t status) {
switch(status) {
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR";
default: return "Unknown cuBLAS error";
}
}
#endif
#define DG_CUBLASLT_CHECK(cmd) \
do { \
const auto& e = (cmd); \

View File

@@ -36,15 +36,34 @@ static bool fp8_requires_k_major() {
// Tensor utils
template <int N>
static auto get_shape(const torch::Tensor& t) {
DG_HOST_ASSERT(t.dim() == N);
return [&t] <size_t... Is> (std::index_sequence<Is...>) {
return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
}(std::make_index_sequence<N>());
}
static std::tuple<int, int> check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
auto [mn, k] = get_shape<2>(ab);
if (ab.scalar_type() != torch::kFloat8_e4m3fn) {
DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2);
}
return std::make_tuple(mn, k);
}
static std::tuple<int, int, int> check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
auto [num_groups, mn, k] = get_shape<3>(ab);
if (ab.scalar_type() != torch::kFloat8_e4m3fn) {
DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2);
}
return std::make_tuple(num_groups, mn, k);
}
// Recipe
static std::tuple<int, int, int>
get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) {
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
return {1, 128, 128};
@@ -70,7 +89,7 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf,
DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
// Always do shape checks
const auto& sf_dtype = sf.scalar_type();
const auto sf_dtype = sf.scalar_type();
DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
if (num_groups.has_value())

View File

@@ -6,6 +6,9 @@
namespace deep_gemm {
// TODO: Use `torch::kFloat4_e2m1fn_x2`
constexpr auto kPackedFP4 = torch::kUInt8;
template <typename T>
static T ceil_div(const T& a, const T& b) {
return (a + b - 1) / b;

View File

@@ -1,8 +1,6 @@
import os
import subprocess
import torch
from torch.version import cuda as cuda_version
from packaging import version
# Set some default environment provided at setup
try:
@@ -29,9 +27,15 @@ from ._C import (
cublaslt_gemm_tn, cublaslt_gemm_tt,
)
if version.parse(cuda_version) >= version.parse('12.1'):
try:
# DeepGEMM Kernels
from ._C import (
# FP8 FP4 GEMMs
fp8_fp4_gemm_nt, fp8_fp4_gemm_nn,
fp8_fp4_gemm_tn, fp8_fp4_gemm_tt,
m_grouped_fp8_fp4_gemm_nt_contiguous,
m_grouped_fp8_fp4_gemm_nn_contiguous,
m_grouped_fp8_fp4_gemm_nt_masked,
# FP8 GEMMs
fp8_gemm_nt, fp8_gemm_nn,
fp8_gemm_tn, fp8_gemm_tt,
@@ -55,6 +59,8 @@ if version.parse(cuda_version) >= version.parse('12.1'):
fp8_mqa_logits,
get_paged_mqa_logits_metadata,
fp8_paged_mqa_logits,
# Hyperconnection kernels
tf32_hc_prenorm_gemm,
# Layout kernels
transform_sf_into_required_layout,
get_mk_alignment_for_contiguous_layout
@@ -64,6 +70,9 @@ if version.parse(cuda_version) >= version.parse('12.1'):
# TODO: remove these later
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
except ImportError:
# Expected behavior for CUDA runtime version before 12.1
pass
# Some utils
from . import testing
@@ -71,7 +80,10 @@ from . import utils
from .utils import *
# Legacy Triton kernels for A100
from . import legacy
try:
from . import legacy
except Exception as e:
print(f'Failed to load legacy DeepGEMM A100 Triton kernels: {e}')
# Initialize CPP modules
def _find_cuda_home() -> str:
@@ -97,4 +109,4 @@ _C.init(
_find_cuda_home() # CUDA home
)
__version__ = '2.2.0'
__version__ = '2.3.0'

View File

@@ -51,6 +51,8 @@ struct Scheduler {
uint32_t current_group_idx = 0;
// Only used for masked layout
uint32_t current_m_cumsum = 0;
// Only used for countiguous psum layout
uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
// Only used for k-grouped layout
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
uint32_t next_group_idx, next_shape_k;
@@ -72,12 +74,16 @@ struct Scheduler {
current_shape_k = shape_k;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
num_blocks = num_m_blocks * num_n_blocks;
} else if (kGemmType == GemmType::MGroupedContiguous) {
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
num_blocks = num_m_blocks * num_n_blocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::MGroupedMasked) {
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::KGroupedContiguous) {
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
this->grouped_layout = grouped_layout;
current_psum_m = __ldg(grouped_layout);
num_m_blocks = ceil_div(current_psum_m, BLOCK_M);
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
this->grouped_layout = grouped_layout;
get_next_k_group(current_group_idx, current_shape_k);
next_group_idx = current_group_idx + 1;
@@ -131,7 +137,7 @@ struct Scheduler {
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
} else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
const auto offset = kWithGroupOffset ? current_group_idx : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
@@ -172,6 +178,28 @@ struct Scheduler {
}
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
while (true) {
// Within current group
if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
break;
// Move to check the next group
if (++ current_group_idx == kNumGroups)
return false;
// NOTES: `num_m_blocks` varies with the increase of the group index
last_psum_m = align(current_psum_m, 128u);
current_psum_m = __ldg(grouped_layout + current_group_idx);
current_m_block_cumsum += num_m_blocks;
num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M);
}
get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
// NOTES: `last_psum_m` is aligned with 128
m_block_idx += last_psum_m / BLOCK_M;
DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M");
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
while (true) {
// End of the task
@@ -248,6 +276,9 @@ struct Scheduler {
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
} else {
// Unreachable
DG_TRAP_ONLY_DEVICE_ASSERT(false);
}
}
};

View File

@@ -97,7 +97,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
const auto& layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
const auto& num_non_contiguous = 128 / get_atom_base(layout_type);
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
// NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
// also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
@@ -131,8 +132,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id
}
__device__ __forceinline__
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) {
desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id;
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id;
return static_cast<uint64_t>(static_cast<uint32_t>(desc)) << 32;
}
@@ -154,6 +155,20 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
__device__ __forceinline__
void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) {
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbarrier_addr), "l"(cache_hint)
: "memory"
);
}
// UMMA versions with relaxed assertions
struct SM100_MMA_F16BF16_SS {
__device__ static void
@@ -231,4 +246,21 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS {
}
};
struct SM100_MMA_F16BF16_WS_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
} // namespace `deep_gemm::sm100`

View File

@@ -152,6 +152,51 @@ struct BF16MMASelector {
using type = decltype(select_type());
};
template <int N_, typename MMA>
struct TF32MMARS {
template <size_t ...Idx>
__forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
__forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 8;
static constexpr int kNumAccum = M * N / 128;
};
template <int N, bool kUseRS = true>
struct TF32MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (kUseRS) {
if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
}
}
static constexpr auto select_type() {
if constexpr (kUseRS) {
return TF32MMARS<N, decltype(select_mma())>();
} else {
DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
}
}
using type = decltype(select_type());
};
template <typename dtype_t>
struct SM90_U32x2_STSM_N {

View File

@@ -2,14 +2,36 @@
namespace deep_gemm {
enum class GemmType {
Normal = 0,
MGroupedContiguous = 1,
MGroupedMasked = 2,
KGroupedContiguous = 3,
Batched = 4
enum class MmaKind {
BF16 = 0,
MXFP8FP4 = 1,
};
constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) {
switch (mma_kind) {
case MmaKind::BF16: return 2;
case MmaKind::MXFP8FP4: return 1;
default: return 0;
}
}
enum class GemmType {
Normal = 0,
MGroupedContiguous = 1,
MGroupedMasked = 2,
KGroupedContiguous = 3,
Batched = 4,
MGroupedContiguousWithPsumLayout = 5,
};
constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) {
switch (gemm_type) {
case GemmType::MGroupedContiguous: return true;
case GemmType::MGroupedContiguousWithPsumLayout: return true;
default: return false;
}
}
enum class KernelType {
Kernel1D1D = 0,
Kernel1D2D = 1,

View File

@@ -148,6 +148,10 @@ __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
}
__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}
template <typename old_t>
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});

View File

@@ -388,7 +388,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory

View File

@@ -14,6 +14,7 @@ namespace deep_gemm {
using namespace deep_gemm::sm100;
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kGranKA, uint32_t kGranKB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
@@ -22,7 +23,8 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
GemmType kGemmType, bool kWithAccumulation,
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
typename epilogue_type_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
@@ -45,16 +47,21 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
constexpr uint32_t kNumTMAStoreStages = 2;
constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t);
constexpr uint32_t kNumUTCCPAlignedElems = 128;
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad);
const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4);
const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4);
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
@@ -78,8 +85,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
@@ -89,7 +96,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
// NOTES: Make sure we have enough shared memory for UMMA padding
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
@@ -118,10 +125,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// SFA/SFB shared memory
@@ -225,28 +232,31 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::float_e4m3_t, kIsBatchedMM>(
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
tma_copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)));
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
tma_copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx));
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
}
// Arrive at full barriers
@@ -260,9 +270,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// TODO: refactor `UMMA_M` calculation
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
float, cutlass::float_ue8m0_t,
constexpr uint32_t UMMA_K = 32;
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
auto sf_desc = make_sf_desc(nullptr);
@@ -313,19 +322,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// SFA and SFB copy
// TODO: process shared memory descriptor by addition
// TODO: process shared memory descriptor by addition
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
}
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
@@ -337,17 +347,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
@@ -391,11 +404,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
full_barriers[stage_idx]->wait(phase);
// Transpose for UTCCP at certain stages
const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
@@ -454,7 +470,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
// Store into shared memory

View File

@@ -143,7 +143,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + gridDim.x, q_iter_idx + 1};
};
uint32_t seq_k_start[BLOCK_Q];
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
@@ -152,8 +152,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
start = start / 4 * 4;
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
@@ -278,9 +279,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const auto& v_offset = lane_idx;
// Preload weights
constexpr uint32_t kNumWeightsInReg = 52;
constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
float weights[BLOCK_Q][kNumWeightsInReg];
DG_STATIC_ASSERT(kNumWeightsInReg <= kNumHeads and kNumWeightsInReg % 4 == 0, "Invalid kNumWeightsInReg");
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
@@ -337,7 +338,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
float* accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
@@ -367,14 +368,14 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
sum_1 = transform_smem(j + 2, sum_1);
}
float result = sum_0.x + sum_0.y + sum_1.x + sum_1.y;
result *= scale_kv;
auto sum = __fadd2_rn(sum_0, sum_1);
float result = scale_kv * (sum.x + sum.y);
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
if constexpr (kIsCompressedLogits) {
if (kv_offset + v_offset >= seq_k_start[i])
if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result;
} else {
logits[q_idx * stride_logits + kv_offset + v_offset] = result;

View File

@@ -22,8 +22,9 @@ template <uint32_t kNextN, uint32_t kNumHeads,
bool kIsContextLens2D,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1)
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint64_t logits_stride, const uint64_t block_table_stride,
const uint32_t* context_lens, float* logits,
@@ -40,9 +41,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const auto& lane_idx = get_lane_idx();
// Prefetch TMA descriptors
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
@@ -54,78 +53,58 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
static constexpr uint32_t SMEM_UMMA_SIZE = kNumMathWarpGroups * 2 * 8 + static_cast<uint32_t>(sizeof(uint32_t));
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q data and barriers on shared memory
// Q and KV data on shared memory
auto smem_q = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
// Separate math warpgroups and tma load warps into KV groups
// Each math warpgroup corresponds to a tma load warp
const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
// Per group KV data and barriers on shared memory
const auto& smem_kv_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * i);
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
});
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
// UMMA barriers and TMEM pointer on shared memory
auto umma_barrier_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kNumMathWarpGroups);
// Barriers and TMEM pointer on shared memory
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32)); // 0 16
const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4)); // 16 20
const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 4)); // 20
const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
// Initialize barriers
if (is_tma_load_warp and cute::elect_one_sync()) {
if (kv_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
}
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
}
if (kv_group_idx < kNumMathWarpGroups) {
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(128);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(kNumMathThreads);
}
cutlass::arch::fence_barrier_init();
}
@@ -144,12 +123,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 32;
constexpr uint32_t kNumMathRegisters = 104;
constexpr uint32_t kNumSpecializedRegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Scheduler
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(batch_size, blockIdx.x, context_lens, schedule_meta);
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Q and KV pipeline
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
@@ -161,19 +141,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// UMMA settings
// Construct instruction with layout F
constexpr uint32_t UMMA_M = 64;
// Construct instruction with layout D
constexpr uint32_t UMMA_M = 128;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = kNextN * kNumHeads;
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
if (is_tma_load_warp) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
if (kv_group_idx >= kNumMathWarpGroups)
return;
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
if (kv_group_idx == 0 and cute::elect_one_sync()) {
if (cute::elect_one_sync()) {
tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
@@ -199,6 +178,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
kv_idx = next_kv_idx;
num_kv = next_num_kv;
// Read KV block index
// TODO: deal with `-1`?
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0);
}
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
// Wait Q consumer release and issue TMA Q
if (prefetch_q) {
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
@@ -206,25 +193,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
issue_tma_q(q_stage_idx, q_idx + 1);
}
// Read KV block index
// TODO: deal with `-1`?
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
__ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
}
const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
int kv_block_idx[kNumBlocksPerSplit];
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
kv_block_idx_ptr += kNumBlocksPerSplit;
// Wait KV consumer release
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
if (cute::elect_one_sync()) {
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
0, 0, 1, kv_block_idx[i]);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
0, kv_block_idx[i]);
}
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
@@ -245,32 +233,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
uint32_t q_idx = batch_size, kv_idx;
uint32_t next_q_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
uint32_t umma_phase = 0;
auto smem_kv = PatternVisitor([&](const uint32_t& stage_idx) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_SIZE_PER_STAGE * stage_idx);
});
uint32_t umma_phase = 1;
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
if (q_idx != next_q_idx) {
if (q_idx != next_q_idx)
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
}
q_idx = next_q_idx;
kv_idx = next_kv_idx;
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
empty_umma_barriers[i]->wait(umma_phase & 1);
empty_umma_barriers[i]->wait(umma_phase);
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_kv[kv_stage_idx] + i * SMEM_KV_PIPE_SIZE, 0, k * UMMA_K);
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_q[q_stage_idx], 0, k * UMMA_K);
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
@@ -285,10 +267,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Offsets
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
float weights[kNextN][kNumHeads / 4];
const auto& sub_warp_offset = (warp_idx % 4) * 16;
const auto& v_0_offset = lane_idx / 4 + 0;
const auto& v_1_offset = lane_idx / 4 + 8;
const uint32_t thread_idx = threadIdx.x;
// Weights
constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads));
float weights[kNextN][kNumWeightsInReg];
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
uint32_t q_idx = batch_size, kv_idx;
@@ -310,9 +294,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
}
@@ -321,75 +304,80 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
kv_idx = next_kv_idx;
// Calculate KV offset in advance
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV;
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
// Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset),
ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset));
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx);
empty_umma_barriers[warpgroup_idx]->arrive();
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait(umma_phase & 1);
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
umma_phase ^= 1;
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
// Reduce over the head dim and store
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
uint32_t shifted_accum[kNumLDTMElems];
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
empty_umma_barriers[warpgroup_idx]->arrive();
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
// Load from the tensor memory
constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128;
uint32_t shifted_accum[kNumLDTMElems];
DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM");
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 16) {
cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
// Transform
const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) {
auto a = make_float2(fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k]), 0),
fmaxf(*reinterpret_cast<float*>(&shifted_accum[j * 4 + k + 2]), 0));
auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]);
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
// Intra-thread reduction
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 8; ++ j) {
sum_0 = transform_2(j, 0, sum_0);
sum_1 = transform_2(j, 1, sum_1);
for (int j = 0; j < kNumWeightsInReg; j += 4) {
sum_0 = transform_reg(j, sum_0);
sum_1 = transform_reg(j + 2, sum_1);
}
auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv);
// Inter-thread reduction
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < 2; ++ j) {
const auto& offset = 1u << j;
v.x += __shfl_xor_sync(0xffffffffu, v.x, offset);
v.y += __shfl_xor_sync(0xffffffffu, v.y, offset);
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
sum_0 = transform_smem(j, sum_0);
sum_1 = transform_smem(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
float result = scale_kv * (sum.x + sum.y);
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
logits[kv_offset + i * logits_stride + v_0_offset] = v.x;
logits[kv_offset + i * logits_stride + v_1_offset] = v.y;
logits[kv_offset + i * logits_stride + thread_idx] = result;
}
}
} else {

View File

@@ -0,0 +1,345 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/reduction.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
__device__ __forceinline__
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
// Calculate the index of the bank group to be written in the atom
const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
// - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)`
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups;
auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
col ^= row % (kSwizzleMode / kSwizzleBase);
return row * 128 + col * kSwizzleBase;
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumSplits,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Configs
constexpr uint32_t kNumCastStages = 2;
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
constexpr auto kMajorA = cute::UMMA::Major::K;
constexpr auto kMajorB = cute::UMMA::Major::K;
DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages");
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads");
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Data on shared memory (layout as ordered below)
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 4 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
full_cast_barriers[i]->init(kNumCastAndReduceThreads);
empty_barriers[i]->init(1);
empty_cast_barriers[i]->init(1);
}
tmem_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
const uint32_t m_block_idx = block_idx / kNumSplits;
const uint32_t k_split_idx = block_idx % kNumSplits;
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
const uint32_t m_offset = shape_m * k_split_idx;
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
// Dispatch warps into different roles
if (warp_idx < kNumMMAThreads / 32) {
// TMA load warp
if (warp_idx == 0 and cute::elect_one_sync()) {
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
// Compute offsets
uint32_t m_idx = m_block_idx * BLOCK_M;
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
}
// MMA issue warp
if (warp_idx == 1) {
// Make instruction descriptor
constexpr uint32_t UMMA_M = BLOCK_M;
constexpr uint32_t UMMA_N = BLOCK_N;
constexpr uint32_t UMMA_K = 32 / sizeof(float);
constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float);
using umma_t = cute::SM100_MMA_TF32_TS<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
BLOCK_M, BLOCK_N, kMajorA, kMajorB>;
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Launch MMAs
// We can not unroll this part
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
const auto& cast_stage_idx = s % kNumCastStages;
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
tcgen05_after_thread_sync();
// Issue UMMA
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
}
// Commit
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_cast_barriers[cast_stage_idx]));
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
}
// Commit to epilogue threads
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
}
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Only support layout F (M = 64) and D (M = 128)
DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M");
// Wait UMMA arrival
tmem_full_barrier->wait(0);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Source and destination memory address
uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup;
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) + // Base pointer
warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset
get_swizzled_smem_offset<kSwizzleCDMode>(i, lane_idx); // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
if constexpr (BLOCK_M == 64)
__syncwarp();
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0);
if (warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kNumSplits == 1) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
} else {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
}
cute::tma_store_arrive();
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
} else {
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads");
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32;
// TODO: make even larger block K
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
// Launch reductions
float2 sum[2] = {float2{0, 0}, float2{0, 0}};
#pragma unroll kNumStages
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
// Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b)
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
const auto& smem_base_ptr = reinterpret_cast<uint8_t*>(smem_a[stage_idx]) + // Base pointer
sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset
// 4 lanes shared a bank group
uint32_t uint32_values[2][kNumLoads];
DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads");
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; i += 2) {
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
uint32_values[0][i + 1], uint32_values[1][i + 1],
smem_ptr);
}
// Wait tensor memory empty
const auto& cast_stage_idx = s % kNumCastStages;
empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1);
// Cast, reduce and store into tensor memory
float2 fp32x2_values[2][kNumLoads];
const auto& upper_view = reinterpret_cast<uint32_t*>(&fp32x2_values[0]);
const auto& lower_view = reinterpret_cast<uint32_t*>(&fp32x2_values[1]);
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; ++ i) {
#pragma unroll
for (uint32_t u = 0; u < 2; ++ u) {
fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast<nv_bfloat162*>(&uint32_values[u][i]));
sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]);
}
// Store upper and lower part at the same time
const auto idx_0 = i * 2, idx_1 = i * 2 + 1;
cute::SM100_TMEM_STORE_16dp256b1x::copy(
upper_view[idx_0], upper_view[idx_1],
lower_view[idx_0], lower_view[idx_1],
cast_stage_idx * BLOCK_K + i * 8);
}
cutlass::arch::fence_view_async_tmem_store();
// Arrive for issuing MMAs
tcgen05_before_thread_sync();
full_cast_barriers[cast_stage_idx]->arrive();
}
// Intra-warp reduction and write back
#pragma unroll
for (uint32_t u = 0; u < 2; ++ u) {
const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y);
const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
if (lane_idx % 4 == 0 and m_idx < shape_m)
sqr_sum[m_offset + m_idx] = reduced_sum;
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
} // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -350,7 +350,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
// Use TMA store to write back to global memory
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;

View File

@@ -171,20 +171,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
empty_barriers[stage_idx]->wait(phase ^ 1);
// Issue TMA A
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t k_idx = k_block_idx * BLOCK_K;
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, &full_barrier,
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
num_tma_multicast_a, batch_idx);
tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_block_idx),
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
num_tma_multicast_a);
// Issue TMA B
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, &full_barrier,
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
num_tma_multicast_b, batch_idx);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
}
@@ -222,7 +225,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto previous_group_offset = scheduler.get_global_idx<true>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
@@ -413,9 +416,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset),
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
auto n_idx = epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset);
auto m_idx = scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx);
if constexpr (kGemmType == GemmType::Batched) {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr,
n_idx, m_idx, scheduler.current_group_idx);
} else {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx);
}
cute::tma_store_arrive();
}
__syncwarp();

View File

@@ -127,7 +127,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + gridDim.x, q_iter_idx + 1};
};
uint32_t seq_k_start[BLOCK_Q];
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
@@ -136,8 +136,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
start = start / 4 * 4;
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
@@ -304,9 +305,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// NOTES: we have redundant writes here, consider more carefully
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
if constexpr (kIsCompressedLogits) {
if (kv_offset + v_0_offset >= seq_k_start[i])
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0;
if (kv_offset + v_1_offset >= seq_k_start[i])
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1;
} else {
logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0;

View File

@@ -58,7 +58,7 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne
}
template <uint32_t kNextN, bool kIsContextLens2D,
uint32_t BLOCK_KV, uint32_t kNumMathWarpGroups>
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
struct PagedMQALogitsScheduler {
uint32_t batch_size;
const uint32_t* context_lens;
@@ -79,8 +79,8 @@ struct PagedMQALogitsScheduler {
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups;
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups;
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
current_num_kv = get_num_kv(current_q_idx);
}
@@ -93,7 +93,7 @@ struct PagedMQALogitsScheduler {
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
return false;
current_kv_idx += kNumMathWarpGroups;
current_kv_idx += kNumBlocksPerSplit;
if (current_kv_idx >= current_num_kv) {
++ current_q_idx;
current_kv_idx = 0;

View File

@@ -0,0 +1,287 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/reduction.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
__device__ __forceinline__
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups);
auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups);
col ^= row % kGroupsInSwizzleRange;
return (row * kNumBankGroups + col) % kGroupsInSwizzleRange;
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumSplits,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// kSwizzleAMode and kSwizzleBMode must be 128 for now
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode");
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads");
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Data on shared memory (layout as ordered below)
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(128);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
__syncthreads();
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
const uint32_t m_block_idx = block_idx / kNumSplits;
const uint32_t k_split_idx = block_idx % kNumSplits;
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
const uint32_t m_offset = shape_m * k_split_idx;
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 256;
// TMA load warp
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
// Compute offsets
uint32_t m_idx = m_block_idx * BLOCK_M;
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
}
} else if (warp_idx < kNumMathThreads / 32) {
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
constexpr uint32_t WGMMA_M = 64;
constexpr uint32_t WGMMA_N = BLOCK_N;
constexpr uint32_t WGMMA_K = 8;
using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
float accum[WGMMA::kNumAccum] = {0};
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
float sqr_sum_acc_0 = 0;
float sqr_sum_acc_1 = 0;
#pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128;
constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K;
float a[kNumRegPerWgmma * kNumWgmmaPerBlockK];
// Assume swizzle A mode is 128
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
// Load BF16 A fragment from shared memory into registers, and transpose to FP32
uint32_t row = warp_idx * 16 + lane_idx / 4;
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; ++ i) {
// Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a
uint32_t bank_group_idx = (row ^ i) % 8;
nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
uint32_t elem_offset = lane_idx % 4;
nv_bfloat16 a_bf16[kNumRegPerWgmma];
a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset];
a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4];
a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset];
a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4];
auto a_bf16x2_ptr = reinterpret_cast<nv_bfloat162*>(a_bf16);
auto a_float2_ptr = reinterpret_cast<float2*>(a);
float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]);
float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]);
a_float2_ptr[i * 2 + 0] = a_float2_0;
a_float2_ptr[i * 2 + 1] = a_float2_1;
sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x;
sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
}
warpgroup_wait<0>();
if (s > 0)
empty_barriers[(s - 1) % kNumStages]->arrive();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K");
#pragma unroll
for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
#pragma unroll
for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
}
}
warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
}
const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);
const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
if (lane_idx % 4 == 0) {
if (m_idx < shape_m)
sqr_sum[m_offset + m_idx] = reduced_sum_0;
if (m_idx + 8 < shape_m)
sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
}
warpgroup_wait<0>();
empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
// Write accum to shared memory
// Every 2 threads (one pair) will write to the same bank group (16 bytes).
// Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d
uint32_t is_odd_pair = lane_idx / 2 % 2;
// Four threads per group; write the data to the same row.
uint32_t row_idx = lane_idx / 4;
// Even/odd index pairs write to the same column, we need to reorder idx:
// group even pair indices consecutively, and likewise for odd ones.
uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx;
auto shifted_smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) +
(warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows
lane_idx % 2 * 8; // One thread of a pair writes 8 bytes
#pragma unroll
for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) {
// Get the swizzled bank group index (16 bytes per group)
uint32_t bank_group_idx = get_swizzled_bank_group_idx<kSwizzleCDMode>(i + is_odd_pair, reordered_pair_idx);
auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group
// 0/1 write to the same row, 2/3 write to another row
auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
st_shared(smem_ptr, values[0], values[1]);
st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, 1);
// Issue TMA stores
if (warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kNumSplits == 1) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
} else {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
}
cute::tma_store_arrive();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
} // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -5,6 +5,8 @@ from typing import Iterable
def calc_diff(x: torch.Tensor, y: torch.Tensor):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()
if denominator == 0: # Which means that all elements in x and y are 0
return 0.0
sim = 2 * (x * y).sum() / denominator
return 1 - sim

View File

@@ -1,10 +1,16 @@
from .._C import (
get_tma_aligned_size,
get_mk_alignment_for_contiguous_layout,
get_mn_major_tma_aligned_tensor,
get_mn_major_tma_aligned_packed_ue8m0_tensor,
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
)
try:
from .._C import (
get_tma_aligned_size,
get_mn_major_tma_aligned_tensor,
get_mn_major_tma_aligned_packed_ue8m0_tensor,
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
)
except ImportError:
# Expected behavior for CUDA runtime version before 12.1
pass
# Valid for all CUDA versions
from .._C import get_mk_alignment_for_contiguous_layout
# Some alias
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout

View File

@@ -15,35 +15,35 @@ def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
padded_n = align(n, 128)
padded_n = align(n, gran_k)
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
x_padded[:, :n] = x
x_view = x_padded.view(m, -1, 128)
x_view = x_padded.view(m, -1, gran_k)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(0) % 128 == 0
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(0) % gran_k == 0
m, n = x.shape
x_view = x.view(-1, 128, n)
x_view = x.view(-1, gran_k, n)
x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device)
x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
@@ -58,3 +58,50 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled, sf.squeeze()
def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
ax = x.abs().clamp_max(6.0)
# {0, 0.5, 1, 1.5, 2, 3, 4, 6}
# midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0
boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0],
device=x.device, dtype=ax.dtype)
idx = torch.bucketize(ax, boundaries)
code = idx.to(torch.uint8)
sign = (x < 0) & (idx != 0)
code = code | (sign.to(torch.uint8) << 3)
return code # uint8, 0..15
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
assert n % 2 == 0
padded_n = align(n, gran_k)
x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
x_padded[:, :n] = x
x_view = x_padded.view(m, -1, gran_k)
x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4)
sf = x_amax / 6.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = x_view * (1.0 / sf.unsqueeze(2))
codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n)
codes2 = codes.view(m, padded_n // 2, 2)
packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8
return packed[:, :n // 2].contiguous(), sf
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
assert a.dtype == torch.uint8
assert a.dim() == 2
m, n2 = a.shape
n = n2 * 2
assert (m % 2) == 0
lo = a & 0x0F
hi = (a >> 4) & 0x0F
codes = torch.empty((m, n), device=a.device, dtype=torch.uint8)
codes[:, 0::2], codes[:, 1::2] = lo, hi
codes_t = codes.transpose(0, 1).contiguous()
codes2 = codes_t.view(n, m // 2, 2)
out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)
return out.contiguous()

View File

@@ -1,12 +1,13 @@
import enum
import random
import torch
from typing import Generator, List
from typing import Generator, List, Optional, Tuple
from deep_gemm.testing import get_arch_major
from deep_gemm.utils import (
align, ceil_div,
per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8,
per_token_cast_to_fp4, transpose_packed_fp4,
get_mk_alignment_for_contiguous_layout
)
@@ -35,6 +36,51 @@ class MajorTypeAB(enum.Enum):
def is_mn_major(self):
return self.value == 1
class QuantConfig:
_legacy_quant_config = (128, 128, False, False)
def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config):
self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value
def print(self):
print(f' > Testing with gran_k_a={self.gran_k_a}, gran_k_b={self.gran_k_b}, '
f'is_fp4_a={self.is_fp4_a}, is_fp4_b={self.is_fp4_b}')
def is_legacy(self) -> bool:
return (self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b) == self._legacy_quant_config
def get_recipes(self, is_wgrad: bool = False) -> Tuple[Tuple, Tuple, Tuple]:
recipe, recipe_a, recipe_b = None, None, None
if self.is_legacy():
recipe = (1, 1, 128) if is_wgrad else None
else:
recipe_a = (1, self.gran_k_a)
recipe_b = (1, self.gran_k_b) if self.is_fp4_b or is_wgrad else (self.gran_k_b, self.gran_k_b)
return recipe, recipe_a, recipe_b
def max_diff(self) -> float:
if self.is_fp4_a and self.is_fp4_b:
return 0.02
if self.is_fp4_a or self.is_fp4_b:
return 0.01
return 0.001
@staticmethod
def get_list_from_dtype(dtype: torch.dtype) -> List:
if dtype == torch.bfloat16:
return [None]
quant_config_list = [QuantConfig()]
if get_arch_major() == 10:
quant_config_list.append(QuantConfig((128, 32, False, True)))
return quant_config_list
def reset_seed(seed: int = 0):
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
def get_ue8m0_usage(kernel_type: KernelType) -> bool:
@@ -60,9 +106,14 @@ def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
yield major_a, major_b
def get_psum_layout_usage() -> tuple:
return (False, True) if get_arch_major() == 10 else (False, )
def enumerate_normal(dtype: torch.dtype) -> Generator:
assert dtype in (torch.float8_e4m3fn, torch.bfloat16)
quant_config_list = QuantConfig.get_list_from_dtype(dtype)
fp32_output_nk = [(256, 7168), (129280, 7168)]
bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]
m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ]
@@ -73,39 +124,61 @@ def enumerate_normal(dtype: torch.dtype) -> Generator:
nk_list += fp32_output_nk
for kernel_type in get_kernel_types(dtype):
# Forward
for m in m_fwd_list:
for i in range(len(nk_list)):
n, k = nk_list[i]
out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float
yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
for quant_config in quant_config_list:
if len(quant_config_list) > 1:
quant_config.print()
reset_seed()
# Backward
for m in m_bwd_list:
for n, k in nk_list:
override_major = MajorTypeAB.MNMajor
override_kernel_type = kernel_type
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
override_major = MajorTypeAB.KMajor
override_kernel_type = KernelType.Kernel1D1D
yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad
yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
# Forward
for m in m_fwd_list:
for i in range(len(nk_list)):
n, k = nk_list[i]
out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float
yield kernel_type, quant_config, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype
# Backward
for m in m_bwd_list:
for n, k in nk_list:
override_major = MajorTypeAB.MNMajor
override_kernel_type = kernel_type
if get_arch_major() == 9 and dtype == torch.float8_e4m3fn:
override_major = MajorTypeAB.KMajor
override_kernel_type = KernelType.Kernel1D1D
yield kernel_type, quant_config, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad
yield override_kernel_type, quant_config, n, m, k, override_major, override_major, True, torch.float # Wgrad
yield override_kernel_type, quant_config, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad
def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
quant_config_list = QuantConfig.get_list_from_dtype(dtype)
m_group_list = [(4, 8192), (8, 4096)]
n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)]
for kernel_type in get_kernel_types(dtype):
for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)):
for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn):
yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
for quant_config in quant_config_list:
if len(quant_config_list) > 1:
quant_config.print()
for use_psum_layout in get_psum_layout_usage():
reset_seed()
for num_groups, expected_m_per_group in m_group_list:
for n, k in n_k_list:
for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn):
yield kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout
def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
quant_config_list = QuantConfig.get_list_from_dtype(dtype)
max_m = 4096
m_group_list = [(6, 1024), (32, 192), (32, 50)]
n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)]
for kernel_type in get_kernel_types(dtype):
for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
for n, k in ((4096, 7168), (7168, 2048), ):
yield kernel_type, num_groups, max_m, m, n, k
for quant_config in quant_config_list:
if len(quant_config_list) > 1:
quant_config.print()
for use_psum_layout in get_psum_layout_usage():
reset_seed()
for num_groups, m in m_group_list:
for n, k in n_k_list:
yield kernel_type, quant_config, num_groups, max_m, m, n, k, use_psum_layout
def enumerate_k_grouped_contiguous(dtype: torch.dtype):
@@ -145,11 +218,46 @@ def enumerate_transpose():
yield mn + delta, k
def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
if is_fp4:
x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
x = x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1])
else:
x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \
else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
x = x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1])
return x
def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
num_groups, mn, k = x.size()
if is_fp4:
x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.uint8) if major.is_k_major() else \
torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.uint8),
torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float))
for i in range(num_groups):
x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1])
x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1])
else:
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn),
torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8 \
else torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float))
for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \
else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1])
return x
def generate_normal(m: int, n: int, k: int,
major_a: MajorTypeAB, major_b: MajorTypeAB,
accumulate: bool, out_dtype: torch.dtype,
kernel_type: KernelType,
use_ue8m0: bool = False, use_bf16: bool = False):
use_ue8m0: bool = False, use_bf16: bool = False,
quant_config: Optional[QuantConfig] = None):
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \
@@ -161,25 +269,28 @@ def generate_normal(m: int, n: int, k: int,
a = a if major_a.is_k_major() else a.T.contiguous().T
b = b if major_b.is_k_major() else b.T.contiguous().T
return a, b, c, d, ref_d
quant_config = QuantConfig() if quant_config is None else quant_config
a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0,
use_block_cast_for_fp8=not (kernel_type.is_1d1d() and accumulate))
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \
else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1])
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1])
return a_fp8, b_fp8, c, d, ref_d
return a, b, c, d, ref_d
def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int,
major_a: MajorTypeAB, major_b: MajorTypeAB,
use_ue8m0: bool = False, use_bf16: bool = False):
use_ue8m0: bool = False, use_bf16: bool = False,
use_psum_layout: bool = False,
quant_config: Optional[QuantConfig] = None):
actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms]
m = sum(aligned_ms)
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
grouped_layout = torch.empty(num_groups, device='cuda', dtype=torch.int32) if use_psum_layout \
else torch.empty(m, device='cuda', dtype=torch.int32)
d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)
@@ -187,48 +298,61 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n:
for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
actual_end = start + actual_m
aligned_end = start + aligned_m
m_indices[start:actual_end] = i
m_indices[actual_end:aligned_end] = -1
ref_d[start:aligned_end] = a[start:aligned_end] @ b[i].t()
if use_psum_layout:
grouped_layout[i] = actual_end
else:
grouped_layout[start: actual_end] = i
grouped_layout[actual_end: aligned_end] = -1
a[actual_end: aligned_end] = 0
ref_d[start: aligned_end] = a[start: aligned_end] @ b[i].t()
start = aligned_end
ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d)
if use_bf16:
b = b if major_b.is_k_major() else b.mT.contiguous().mT
return m, a, b, m_indices, d, ref_d
return m, a, b, grouped_layout, d, ref_d
assert major_a.is_k_major()
a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn),
torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float))
quant_config = QuantConfig() if quant_config is None else quant_config
a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True)
return m, a, b, grouped_layout, d, ref_d
def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor):
num_groups, max_m, _ = x.size()
x_psum = torch.empty_like(x).view(num_groups * max_m, -1)
last_psum_m = 0
for i in range(num_groups):
b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0)
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].mT.contiguous().mT, b_fp8[1])
return m, a_fp8, b_fp8, m_indices, d, ref_d
x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m]
last_psum_m = align(psum_m[i], 128)
return x_psum
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int,
use_ue8m0: bool = False, use_bf16: bool = False):
use_ue8m0: bool = False, use_bf16: bool = False,
use_psum_layout: bool = False,
quant_config: Optional[QuantConfig] = None):
a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
ref_d = torch.einsum('gmk,gnk->gmn', a, b)
masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
for j in range(num_groups):
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j]
assert masked_m.amax().item() <= max_m
if use_bf16:
return a, b, masked_m, d, ref_d
return a, b, masked_m, psum_m, d, ref_d
a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float))
b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float))
for i in range(num_groups):
a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0)
b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0)
quant_config = QuantConfig() if quant_config is None else quant_config
a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0)
b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True)
return a_fp8, b_fp8, masked_m, d, ref_d
return a, b, masked_m, psum_m, d, ref_d
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int],

View File

@@ -1,12 +1,14 @@
import dataclasses
import random
import torch
from typing import Tuple
from typing import Tuple, List
import deep_gemm
from deep_gemm.testing import (
bench_kineto,
calc_diff, count_bytes,
ignore_env, get_arch_major
ignore_env, get_arch_major,
test_filter
)
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
@@ -154,7 +156,7 @@ def test_mqa_logits():
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"
assert diff < 1e-3, f'{diff=}'
else:
ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True)
@@ -204,8 +206,6 @@ def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
def test_paged_mqa_logits():
# TODO: fully refactor with PyTest
print('Testing FP8 Paged MQA Logits:')
max_model_len = 111 * 1000
for is_context_lens_2d in (False, True):
@@ -264,7 +264,7 @@ def test_paged_mqa_logits():
else:
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True),
('fp8_paged_mqa_logits', 'clean_logits'))
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: '
f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, '
f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s', end='')
@@ -273,6 +273,8 @@ def test_paged_mqa_logits():
print()
if __name__ == '__main__':
torch.manual_seed(0)
random.seed(0)

View File

@@ -9,7 +9,7 @@ from deep_gemm.testing import (
calc_diff, count_bytes
)
from generators import (
get_arch_major,
get_arch_major, layout_masked_to_psum, align,
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
)
@@ -18,11 +18,7 @@ from generators import (
def test_gemm() -> None:
print('Testing GEMM:')
scores = []
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16):
# TODO: support accumulation for SM90 BF16 GEMM
if get_arch_major() == 9 and accumulate:
continue
for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
@@ -56,29 +52,30 @@ def test_gemm() -> None:
def test_m_grouped_gemm_contiguous() -> None:
print('Testing m-grouped contiguous GEMM:')
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16):
for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(torch.bfloat16):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
for test_alias in (False, True):
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True)
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
use_bf16=True, use_psum_layout=use_psum_layout)
func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous"
if test_alias:
assert major_a.is_k_major()
b = b if major_b.is_k_major() else b.mT
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, m_indices)
d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d)
getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout)
diff = calc_diff(d, ref_d)
assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True)
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
use_bf16=True, use_psum_layout=use_psum_layout)
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices)
deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, use_psum_layout=use_psum_layout)
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): '
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, psum={use_psum_layout}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
@@ -89,29 +86,52 @@ def test_m_grouped_gemm_masked() -> None:
print('Testing m-grouped masked GEMM:')
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16):
# Test correctness
for i in range(10):
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
for _, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16):
num_tests = 8
sum_t, max_t = 0, 0
sum_ops, sum_bytes = 0, 0
for i in range(num_tests):
a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k,
use_bf16=True, use_psum_layout=use_psum_layout)
if use_psum_layout:
a_psum = layout_masked_to_psum(a, psum_m)
d_psum = layout_masked_to_psum(d, psum_m)
# noinspection PyShadowingNames
def test_func():
if use_psum_layout:
deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a_psum, b, d_psum, psum_m,
use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group)
else:
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
test_func()
for j in range(num_groups):
diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
if masked_m[j].item() == 0:
continue
if use_psum_layout:
d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]]
else:
d_slice = d[j, :masked_m[j].item()]
diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()])
assert diff < 1e-5, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
# Construct full cases
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True)
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group)
# Test performance with fixed shapes
valid_m = masked_m.sum().item()
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
# Test performance with fixed shapes
valid_m = masked_m.sum().item()
t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): '
f'{t * 1e6:4.0f} us | '
f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
sum_t += t
max_t = max(max_t, t)
sum_ops += 2 * valid_m * n * k
sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)
print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, '
f'psum={1 if use_psum_layout else 0}): '
f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | '
f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | '
f'{sum_bytes / sum_t / 1e9:4.0f} GB/s')
print()
@@ -148,7 +168,7 @@ def test_k_grouped_gemm_contiguous() -> None:
def test_cublaslt_gemm() -> None:
print('Testing cuBLASLt GEMM:')
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16):
for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
@@ -159,7 +179,8 @@ def test_cublaslt_gemm() -> None:
diff = calc_diff(d, ref_d)
assert diff < 6e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})'
t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), 'nvjet', suppress_kineto_output=True,)
t_nvjet, t_gemv, t_gemm = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'gemv', 'gemm'), suppress_kineto_output=True)
t = t_nvjet + t_gemv + t_gemm
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'{t * 1e6:5.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '

View File

@@ -80,7 +80,6 @@ def test_bhd_hdr_bhr():
print()
@test_filter(lambda: get_arch_major() >= 10)
def test_fp8_bhr_hdr_bhd(use_ue8m0: bool = True):
print('Testing FP8 "bhr, hdr -> bhd":')
for h, r, d in [(8, 4096, 1024)]:

View File

@@ -1,175 +0,0 @@
import copy
import numpy as np
import random
import torch
import deep_gemm
from deep_gemm.testing import (
bench_kineto,
calc_diff, count_bytes,
ignore_env, get_arch_major
)
from generators import (
KernelType, get_ue8m0_usage,
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
)
@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9)
def test_gemm() -> None:
print('Testing GEMM:')
scores = []
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
acc_opt = f'acc={int(accumulate)}'
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None
for test_alias in (False, True):
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}'
if test_alias:
a = a if major_a.is_k_major() else (a[0].T, a[1].T)
b = b if major_b.is_k_major() else (b[0].T, b[1].T)
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe)
diff = calc_diff(d, ref_d)
assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
f'{diff:.5f}, alias={test_alias}')
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0)
t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe),
'fp8_gemm', suppress_kineto_output=True)
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True)
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
if cublas_t > 0:
scores.append((cublas_t + split_k_t) / t)
print(f"Average speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n")
def test_m_grouped_gemm_contiguous() -> None:
print('Testing m-grouped contiguous GEMM:')
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
for test_alias in (False, True):
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
func_name = f"m_grouped_fp8_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous"
if test_alias:
assert major_a.is_k_major()
b = b if major_b.is_k_major() else (b[0].mT, b[1].mT)
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d)
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_masked() -> None:
print('Testing m-grouped masked GEMM:')
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn):
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
# Test correctness
for i in range(10):
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
for j in range(num_groups):
if masked_m[j].item() == 0:
continue
diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
# Construct full cases
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
# Test performance with fixed shapes
valid_m = masked_m.sum().item()
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): '
f'{t * 1e6:4.0f} us | '
f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
print()
def test_k_grouped_gemm_contiguous() -> None:
print('Testing k-grouped contiguous GEMM:')
k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \
else deep_gemm.k_grouped_fp8_gemm_tn_contiguous
for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn):
use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)
for test_empty_groups in (False, True):
new_ks = copy.deepcopy(ks)
if test_empty_groups and len(ks) > 1:
new_ks[random.randint(0, num_groups - 1)] = 0
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}'
# Test performance
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
# noinspection PyShadowingNames
def test_func():
k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
print()
if __name__ == '__main__':
torch.manual_seed(0)
random.seed(0)
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
test_k_grouped_gemm_contiguous()

207
tests/test_fp8_fp4.py Normal file
View File

@@ -0,0 +1,207 @@
import copy
import numpy as np
import random
import torch
import deep_gemm
from deep_gemm.testing import (
bench_kineto,
calc_diff, count_bytes,
ignore_env, get_arch_major
)
from generators import (
KernelType, get_ue8m0_usage, layout_masked_to_psum, align,
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
)
@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9)
def test_gemm() -> None:
print('Testing GEMM:')
scores = []
for kernel_type, quant_config, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
acc_opt = f'acc={int(accumulate)}'
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=(kernel_type.is_1d1d() and accumulate))
for test_alias in (False, True):
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config)
func_name = f'fp8_fp4_gemm_{major_opt.lower() if test_alias else "nt"}'
if test_alias:
a = a if major_a.is_k_major() else (a[0].T, a[1].T)
b = b if major_b.is_k_major() else (b[0].T, b[1].T)
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
diff = calc_diff(d, ref_d)
assert diff < quant_config.max_diff(), (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
f'{diff:.5f}, alias={test_alias}')
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config)
t = bench_kineto(lambda: deep_gemm.fp8_fp4_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b),
'fp8_gemm', suppress_kineto_output=True)
cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) \
if not quant_config.is_fp4_a and not quant_config.is_fp4_b else (0, 0)
print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | '
f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS')
if cublas_t > 0:
scores.append((cublas_t + split_k_t) / t)
print(f"Average FP8xFP8 GEMM speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n")
def test_m_grouped_gemm_contiguous() -> None:
print('Testing m-grouped contiguous GEMM:')
for kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
recipe, recipe_a, recipe_b = quant_config.get_recipes()
for test_alias in (False, True):
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout,
quant_config=quant_config)
func_name = f"m_grouped_fp8_fp4_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous"
if test_alias:
assert major_a.is_k_major()
b = b if major_b.is_k_major() else (b[0].mT, b[1].mT)
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout,
recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
diff = calc_diff(d, ref_d)
assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout,
quant_config=quant_config)
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout,
recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}, psum={use_psum_layout}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_masked() -> None:
print('Testing m-grouped masked GEMM:')
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
for kernel_type, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn):
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
recipe, recipe_a, recipe_b = quant_config.get_recipes()
num_tests = 8
sum_t, max_t = 0, 0
sum_ops, sum_bytes = 0, 0
for i in range(num_tests):
a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k,
use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout,
quant_config=quant_config)
if use_psum_layout:
a_psum = (layout_masked_to_psum(a[0], psum_m), layout_masked_to_psum(a[1], psum_m))
d_psum = layout_masked_to_psum(d, psum_m)
# noinspection PyShadowingNames
def test_func():
if use_psum_layout:
deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, disable_ue8m0_cast=disable_ue8m0_cast,
use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group,
recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
else:
deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast,
recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b)
test_func()
for j in range(num_groups):
if masked_m[j].item() == 0:
continue
if use_psum_layout:
d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]]
else:
d_slice = d[j, :masked_m[j].item()]
diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()])
assert diff < quant_config.max_diff(), f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
# Test performance with fixed shapes
valid_m = masked_m.sum().item()
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
sum_t += t
max_t = max(max_t, t)
sum_ops += 2 * valid_m * n * k
sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)
print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, '
f'{kernel_opt}, psum={1 if use_psum_layout else 0}): '
f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | '
f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | '
f'{sum_bytes / sum_t / 1e9:4.0f} GB/s')
print()
@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9)
def test_k_grouped_gemm_contiguous() -> None:
print('Testing k-grouped contiguous GEMM:')
k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \
else deep_gemm.k_grouped_fp8_gemm_tn_contiguous
for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn):
use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)
for test_empty_groups in (False, True):
new_ks = copy.deepcopy(ks)
if test_empty_groups and len(ks) > 1:
new_ks[random.randint(0, num_groups - 1)] = 0
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0)
new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c)
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}'
# Test performance
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0)
ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
# noinspection PyShadowingNames
def test_func():
k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
print()
if __name__ == '__main__':
torch.manual_seed(0)
random.seed(0)
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
test_k_grouped_gemm_contiguous()

View File

@@ -0,0 +1,57 @@
import torch
import random
import deep_gemm
from deep_gemm.testing import (
test_filter,
bench_kineto,
calc_diff, count_bytes
)
from deep_gemm.utils import align
from generators import get_arch_major
@test_filter(lambda: get_arch_major() >= 9)
def test_hc_prenorm_gemm() -> None:
# Needs TF32 precision for PyTorch GEMMs
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
print('Testing hyperconnection prenorm GEMM:')
for m in (13, 137, 4096, 8192):
for n, k in [(24, 28672), (24, 7680), (24, 7168)]:
for num_splits in [None, 16]:
a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda')
b = torch.randn((n, k), dtype=torch.float, device='cuda')
d = torch.empty((m, n), dtype=torch.float, device='cuda') if num_splits is None else \
torch.empty((num_splits, m, n), dtype=torch.float, device='cuda')
s = torch.empty((m, ), dtype=torch.float, device='cuda') if num_splits is None else \
torch.empty((num_splits, m), dtype=torch.float, device='cuda')
deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits)
final_d = d if num_splits is None else d.sum(0)
final_s = s if num_splits is None else s.sum(0)
ref_d = a.float() @ b.T
ref_s = a.float().square().sum(-1)
diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s))
assert diff < 1e-8, f'{m=}, {n=}, {k=}, {diff:.10f}'
t = bench_kineto(lambda: deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits), 'tf32_hc_prenorm_gemm', suppress_kineto_output=True)
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, num_splits={(num_splits or 0):2}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, d, s) / 1e9 / t:4.0f} GB/s')
print()
if __name__ == '__main__':
torch.manual_seed(0)
random.seed(0)
print('Library path:')
print(f' > {deep_gemm.__path__}\n')
test_hc_prenorm_gemm()

View File

@@ -13,7 +13,7 @@ from generators import (
def test_m_grouped_gemm_contiguous_tl() -> None:
print('Testing m-grouped contiguous Triton GEMM:')
for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16):
for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, _ in enumerate_m_grouped_contiguous(torch.bfloat16):
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'