78 lines
3.7 KiB
C++
78 lines
3.7 KiB
C++
#pragma once
|
|
|
|
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
|
|
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
|
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
|
#include "../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
|
|
|
#include "layout.hpp"
|
|
|
|
namespace deep_gemm::attention {
|
|
|
|
/// FP8 GEMM `D = A @ B.T` launched with an `EpilogueHeadSplits<left, mid, right>` epilogue.
///
/// @param a                   `(tensor, scaling factors)`; the tensor must be `[M, K]` `kFloat8_e4m3fn`.
/// @param b                   `(tensor, scaling factors)`; the tensor must be `[N, K]` `kFloat8_e4m3fn`.
/// @param d                   Output of shape `[M, N + N / (left + right) * mid]`, `kBFloat16` or `kFloat`.
/// @param head_splits         `(left, mid, right)`. All entries must be non-negative, `left + right`
///                            must be positive and must divide `N`.
///                            NOTE(review): presumably the `mid` columns of each head in `d` are
///                            skipped by the epilogue, as the function name suggests - confirm
///                            against the `EpilogueHeadSplits` kernel code.
/// @param recipe              Scaling-factor recipe; when empty it is derived from the scaling-factor
///                            dtypes via `get_default_recipe`. Only `(1, 1, 128)` and `(1, 128, 128)`
///                            are accepted.
/// @param compiled_dims       Forwarded unchanged to the selected kernel launcher.
/// @param disable_ue8m0_cast  Forwarded unchanged to `layout::transform_sf_into_required_layout`.
///
/// Any violated precondition is reported through `DG_HOST_ASSERT`; an unsupported
/// architecture / scaling-factor dtype combination ends in `DG_HOST_UNREACHABLE`.
static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::tuple<int, int, int> &head_splits,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto& major_a = get_major_type_ab(a.first);
    const auto& major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto& [m , k ] = get_shape<2>(a.first);
    const auto& [n , k_] = get_shape<2>(b.first);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check head splits and N
    const auto& [left, mid, right] = head_splits;
    // The splits come straight from the caller: reject negative sizes, and reject
    // `left + right == 0` explicitly so that the modulo/division below can never
    // divide by zero (undefined behaviour) and instead fails with a host assertion
    DG_HOST_ASSERT(left >= 0 and mid >= 0 and right >= 0);
    DG_HOST_ASSERT(left + right > 0);
    DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);

    // Do nothing if the problem is empty
    if (m == 0)
        return;

    // Transform SFA and SFB into compute-required layout
    if (not recipe.has_value())
        recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    const auto& recipe_value = recipe.value();
    DG_HOST_ASSERT(recipe_value == std::make_tuple(1, 1, 128) or recipe_value == std::make_tuple(1, 128, 128));
    const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe_value, std::nullopt, true, disable_ue8m0_cast);
    const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe_value, std::nullopt, false, disable_ue8m0_cast);

    // Dispatch into different implementations, keyed on the device architecture
    // and on the dtype of the transformed scaling factors
    const auto& arch_major = device_runtime->get_arch_major();
    const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe_value) != 1) {
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
        sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, compiled_dims, epilogue_type);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
}
|
|
|
|
// Registers the host entry points of this header on the pybind11 module `m`.
static void register_apis(pybind11::module_& m) {
    m.def(
        "fp8_gemm_nt_skip_head_mid",
        &fp8_gemm_nt_skip_head_mid,
        // Arguments without a default value
        py::arg("a"),
        py::arg("b"),
        py::arg("d"),
        py::arg("head_splits"),
        // Arguments with a default value
        py::arg("recipe") = std::nullopt,
        py::arg("compiled_dims") = "nk",
        py::arg("disable_ue8m0_cast") = false
    );
}
|
|
|
|
} // namespace deep_gemm::attention
|