Files
DeepGEMM/csrc/apis/layout.hpp
Chenggang Zhao 7f2a703ed5 [Public release 26/04] Introducing Mega MoE, FP4 Indexer and other features/fixes (#304)
* Merge with private repo

* Update README

* Update README

* Update README

* Add PyTorch requirements

* Fix sync scopes for MQA logits (#256)

* Update README
2026-04-17 09:45:14 +08:00

144 lines
7.2 KiB
C++

#pragma once
#include "../jit_kernels/heuristics/runtime.hpp"
#include "../utils/layout.hpp"
#include "../utils/compatibility.hpp"
#if DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/smxx_layout.hpp"
#endif
namespace deep_gemm::layout {
#if DG_TENSORMAP_COMPATIBLE
/**
 * Bring one scaling-factor (SF) tensor into the layout selected by the device
 * architecture, the SF dtype and the recipe granularity.
 *
 * @param sf                 SF tensor; only `torch::kFloat` and `torch::kInt` are dispatched on.
 * @param mn                 MN extent forwarded to `check_sf_layout`; also the number of row
 *                           indices generated when broadcasting along dim -2.
 * @param k                  K extent forwarded to `check_sf_layout`.
 * @param recipe             Either a 3-tuple, where element 0 (for SFA) or element 1 (for SFB)
 *                           is the MN granularity and element 2 the K granularity, or a
 *                           2-tuple `(gran_mn, gran_k)`.
 * @param num_groups         Forwarded unchanged to `check_sf_layout`.
 * @param is_sfa             Must be set with a 3-tuple recipe and empty with a 2-tuple recipe.
 * @param disable_ue8m0_cast Lets the FP32 cases with K granularity 128 take the same path as
 *                           arch major 9 on any architecture; asserted false on the
 *                           arch-major-10 FP32 path.
 * @return The tensor produced by the helper chosen for the matching case.
 *
 * A combination that matches none of the cases ends in `DG_HOST_UNREACHABLE`.
 */
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
                                                       const int& mn, const int& k,
                                                       const std::variant<std::tuple<int, int, int>,
                                                                          std::tuple<int, int>>& recipe,
                                                       const std::optional<int>& num_groups,
                                                       const std::optional<bool>& is_sfa,
                                                       const bool& disable_ue8m0_cast) {
    const auto arch_major = device_runtime->get_arch_major();

    // Resolve the MN/K granularity from whichever recipe form was passed
    int gran_mn, gran_k;
    if (std::holds_alternative<std::tuple<int, int, int>>(recipe)) {
        // Joint recipe: the caller has to say which side this SF belongs to
        DG_HOST_ASSERT(is_sfa.has_value());
        const auto& [gran_m, gran_n, gran_k_shared] = std::get<std::tuple<int, int, int>>(recipe);
        gran_mn = is_sfa.value() ? gran_m : gran_n;
        gran_k = gran_k_shared;
    } else if (std::holds_alternative<std::tuple<int, int>>(recipe)) {
        // Per-tensor recipe: a side selector would be meaningless
        DG_HOST_ASSERT(not is_sfa.has_value());
        std::tie(gran_mn, gran_k) = std::get<std::tuple<int, int>>(recipe);
    } else {
        DG_HOST_UNREACHABLE("Invalid recipe");
    }

    // Check the incoming tensor before any transformation
    check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);

    const auto sf_dtype = sf.scalar_type();
    const bool is_fp32 = sf_dtype == torch::kFloat;
    const bool is_int = sf_dtype == torch::kInt;
    const bool stays_fp32 = arch_major == 9 or disable_ue8m0_cast;
    const bool gran_k_is_32_or_128 = gran_k == 32 or gran_k == 128;

    // (FP32, 1, 128) on arch major 9, or with the UE8M0 cast disabled:
    // transform to TMA-aligned and MN-major
    if (is_fp32 and gran_mn == 1 and gran_k == 128 and stays_fp32)
        return get_mn_major_tma_aligned_tensor(sf);

    // (FP32, 128, 128) on arch major 9, or with the UE8M0 cast disabled:
    // no transform, the tensor is returned through a second `check_sf_layout` call
    // NOTE(review): the two boolean flags are defined by `check_sf_layout` - confirm their meaning there
    if (is_fp32 and gran_mn == 128 and gran_k == 128 and stays_fp32)
        return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);

    // (FP32, x, gran_k) on arch major 10: broadcast to per-row granularity if needed,
    // then hand over to the packed-UE8M0 helper
    if (is_fp32 and gran_k_is_32_or_128 and arch_major == 10) {
        DG_HOST_ASSERT(not disable_ue8m0_cast);
        if (gran_mn == 1)
            return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
        // Row `i` reads the SF entry of block `i / gran_mn` along dim -2
        const auto row_to_block = torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn);
        const auto broadcasted = sf.index_select(-2, row_to_block);
        return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
    }

    // (INT, 1, gran_k) on arch major 10: no transform, the tensor is returned through `check_sf_layout`
    if (is_int and gran_mn == 1 and gran_k_is_32_or_128 and arch_major == 10)
        return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);

    DG_HOST_UNREACHABLE("Unknown SF transformation");
}
/**
 * Transform the SFA/SFB pair of a GEMM and report the K granularity used for each side.
 *
 * Exactly one recipe form may be in effect: the joint `recipe`, or the
 * `recipe_a` + `recipe_b` pair. When neither `recipe_a` nor `recipe` is given,
 * `recipe` is filled in place from `get_default_recipe` - note that it is a
 * non-const reference, so the caller's object is modified.
 *
 * @param sfa, sfb            SF tensors of the two operands.
 * @param m, n, k             `m` is passed as the MN extent for SFA, `n` for SFB; `k` for both.
 * @param recipe              Joint 3-tuple recipe (in/out, see above).
 * @param recipe_a, recipe_b  Per-tensor 2-tuple recipes; both set or both empty.
 * @param num_groups_a, num_groups_b  Forwarded to the single-tensor transform.
 * @param disable_ue8m0_cast  Forwarded to the single-tensor transform.
 * @return `(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b)`.
 */
static std::tuple<torch::Tensor, torch::Tensor, int, int> transform_sf_pair_into_required_layout(
    const torch::Tensor& sfa, const torch::Tensor& sfb,
    const int& m, const int& n, const int& k,
    std::optional<std::tuple<int, int, int>>& recipe,
    const std::optional<std::tuple<int, int>>& recipe_a,
    const std::optional<std::tuple<int, int>>& recipe_b,
    const std::optional<int>& num_groups_a,
    const std::optional<int>& num_groups_b,
    const bool& disable_ue8m0_cast = false) {
    // Fall back to the default recipe when the caller specified nothing
    if (not recipe_a.has_value() and not recipe.has_value())
        recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type());

    // Must be either 'recipe' or the 'recipe_a' + 'recipe_b' pair.
    DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value());
    DG_HOST_ASSERT(recipe_a.has_value() != recipe.has_value());

    // Transform one side: the joint recipe wins when present (and then needs the
    // side selector), otherwise the side's own recipe is used without a selector
    const auto transform_one = [&](const torch::Tensor& sf, const int& mn,
                                   const std::optional<std::tuple<int, int>>& own_recipe,
                                   const std::optional<int>& num_groups,
                                   const bool& is_sfa) -> torch::Tensor {
        if (recipe.has_value())
            return transform_sf_into_required_layout(sf, mn, k, recipe.value(), num_groups, is_sfa, disable_ue8m0_cast);
        return transform_sf_into_required_layout(sf, mn, k, own_recipe.value(), num_groups, std::nullopt, disable_ue8m0_cast);
    };

    // K granularity of one side: from its own recipe if given, else from the joint recipe
    const auto pick_gran_k = [&](const std::optional<std::tuple<int, int>>& own_recipe) -> int {
        if (own_recipe.has_value())
            return std::get<1>(own_recipe.value());
        return std::get<2>(recipe.value());
    };

    // SFA first, then SFB
    const auto transformed_sfa = transform_one(sfa, m, recipe_a, num_groups_a, true);
    const auto transformed_sfb = transform_one(sfb, n, recipe_b, num_groups_b, false);

    const int gran_k_a = pick_gran_k(recipe_a);
    const int gran_k_b = pick_gran_k(recipe_b);
    return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b);
}
/**
 * Transform the SF tensor of a K-grouped GEMM into the layout required by the
 * current architecture.
 *
 * @param sf         2-D SF tensor (asserted).
 * @param ks         Per-group K sizes as a host vector; only forwarded on the arch-major-10 FP32 path.
 * @param ks_tensor  Tensor counterpart of `ks`; only forwarded on the arch-major-10 FP32 path.
 * @param recipe     `(1, 1, gran_k)` with `gran_k` equal to 32 or 128 (asserted).
 * @return The tensor produced by the helper chosen for the dtype/architecture pair.
 *
 * INT SF on arch major 10 is reported as "Unimplemented"; every other
 * unmatched combination as "Unknown cases".
 */
static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
                                                                 const std::vector<int>& ks,
                                                                 const torch::Tensor& ks_tensor,
                                                                 const std::tuple<int, int, int>& recipe) {
    // Shape and recipe preconditions
    DG_HOST_ASSERT(sf.dim() == 2);
    DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1);
    const int gran_k = std::get<2>(recipe);
    DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);

    const auto arch_major = device_runtime->get_arch_major();
    const auto sf_dtype = sf.scalar_type();

    if (sf_dtype == torch::kFloat) {
        // FP32 on arch major 9
        if (arch_major == 9)
            return get_mn_major_tma_aligned_tensor(sf);
        // FP32 on arch major 10
        if (arch_major == 10)
            return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks, gran_k);
    } else if (sf_dtype == torch::kInt and arch_major == 10) {
        // INT on arch major 10
        DG_HOST_UNREACHABLE("Unimplemented");
    }

    DG_HOST_UNREACHABLE("Unknown cases");
}
#endif
/**
 * Register the layout-related functions of this namespace on the Python module `m`.
 *
 * The SF transformation and TMA-alignment helpers are only bound when
 * `DG_TENSORMAP_COMPATIBLE` is set; the three contiguous-layout alignment
 * accessors, which forward to `heuristics_runtime`, are always bound.
 *
 * @param m  pybind11 module receiving the bindings.
 */
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
    m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
          py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
          py::arg("num_groups") = std::nullopt,
          py::arg("is_sfa") = std::nullopt,
          py::arg("disable_ue8m0_cast") = false);
    m.def("get_tma_aligned_size", &get_tma_aligned_size);
    m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
    m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
    m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
#endif

    // These callables are kept by the module after this function returns, so they
    // deliberately capture nothing: a `[&]` default would let a local added here
    // later be captured by reference and dangle. `heuristics_runtime` is not a
    // local of this function and needs no capture.
    m.def("set_mk_alignment_for_contiguous_layout", [](const int& new_value) {
        heuristics_runtime->set_mk_alignment_for_contiguous_layout(new_value);
    });
    m.def("get_mk_alignment_for_contiguous_layout", []() {
        return heuristics_runtime->get_mk_alignment_for_contiguous_layout();
    });
    m.def("get_theoretical_mk_alignment_for_contiguous_layout", [](const std::optional<int>& expected_m) {
        return heuristics_runtime->get_theoretical_mk_alignment_for_contiguous_layout(expected_m);
    }, py::arg("expected_m") = std::nullopt);
}
} // namespace deep_gemm::layout