* Merge with private repo * Update README * Update README * Update README * Add PyTorch requirements * Fix sync scopes for MQA logits (#256) * Update README
71 lines
2.7 KiB
C++
71 lines
2.7 KiB
C++
#pragma once
|
|
|
|
#include "../utils/compatibility.hpp"
|
|
|
|
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
|
#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp"
|
|
#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp"
|
|
#endif
|
|
|
|
namespace deep_gemm::hyperconnection {
|
|
|
|
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
|
/// @brief Validate operands and dispatch the TF32 hyper-connection pre-norm GEMM.
///
/// Checks layouts, shapes and dtypes of all tensors, returns early when there are
/// no rows, then forwards to the SM90 or SM100 implementation based on the device's
/// major compute capability.
///
/// @param a          2-D tensor of shape (m, k), BF16, K-major.
/// @param b          2-D tensor of shape (n, k), FP32, K-major.
/// @param d          FP32 tensor of shape (m, n), or (num_splits, m, n) when
///                   `num_splits` is given; its layout is checked by `check_major_type_cd`.
///                   NOTE(review): presumably written by the kernel (output) — confirm.
/// @param sqr_sum    FP32 contiguous tensor of shape (m,), or (num_splits, m) when
///                   `num_splits` is given. NOTE(review): presumably an output — confirm.
/// @param num_splits Optional split count; when present it must equal the leading
///                   dimension of both `d` and `sqr_sum` and be >= 1. When absent the
///                   kernels are called with a split count of 1.
static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
                                 const torch::Tensor& b,
                                 const torch::Tensor& d,
                                 const torch::Tensor& sqr_sum,
                                 const std::optional<int>& num_splits) {
    // Memory layout: both inputs must be K-major; D goes through the shared C/D layout check
    DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K);
    DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K);
    check_major_type_cd(d);

    // The squared-sum tensor has to be densely packed
    DG_HOST_ASSERT(sqr_sum.is_contiguous());

    // Problem sizes are taken from A (m, k) and B (n, k); D and sqr_sum must agree.
    // In split mode both D and sqr_sum carry an extra leading `num_splits` dimension.
    const auto [m, k ] = get_shape<2>(a);
    const auto [n, k_] = get_shape<2>(b);
    if (not num_splits) {
        const auto [m_, n_] = get_shape<2>(d);
        const auto [m__] = get_shape<1>(sqr_sum);
        DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    } else {
        const auto [num_splits_, m_, n_] = get_shape<3>(d);
        const auto [num_splits__, m__] = get_shape<2>(sqr_sum);
        DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1);
        DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    }
    DG_HOST_ASSERT(n > 0 and k > 0);

    // Element types: BF16 for A, FP32 everywhere else
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat);

    // No rows means no work: return before touching the architecture dispatch
    if (m == 0) {
        return;
    }

    // An absent split count is forwarded to the kernels as a single split
    const int effective_splits = num_splits.value_or(1);

    // Pick the implementation matching the device's major compute capability
    const auto arch_major = device_runtime->get_arch_major();
    switch (arch_major) {
        case 9:
            sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, effective_splits);
            break;
        case 10:
            sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, effective_splits);
            break;
        default:
            DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
|
|
|
|
#endif
|
|
|
|
static void register_apis(pybind11::module_& m) {
|
|
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
|
|
m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm,
|
|
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"),
|
|
py::arg("num_splits") = std::nullopt);
|
|
#endif
|
|
}
|
|
|
|
} // namespace deep_gemm::hyperconnection
|