Files
DeepGEMM/csrc/apis/einsum.hpp
Chenggang Zhao 7f2a703ed5 [Public release 26/04] Introducing Mega MoE, FP4 Indexer and other features/fixes (#304)
* Merge with private repo

* Update README

* Update README

* Update README

* Add PyTorch requirements

* Fix sync scopes for MQA logits (#256)

* Update README
2026-04-17 09:45:14 +08:00

232 lines
9.9 KiB
C++

#pragma once
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/layout.hpp"
#include "../utils/compatibility.hpp"
#include "gemm.hpp"
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
#endif
namespace deep_gemm::einsum {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Hardcoded kernel entry for the `bmk,bnk->mn` expression.
//   - `a` and `b` are read as 3D tensors `[s, m, k]` and `[s, n, k]` and must be contiguous
//   - `d` is the output, must be contiguous, and is either FP32 or BF16
//   - `c` is the optional accumulator: for FP32 output it must be present and alias `d` exactly
//     (same data pointer, sizes and strides); for BF16 output it must be absent
// BF16 output is produced by accumulating into a zero-filled FP32 workspace and casting back into `d`.
static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d,
                       const std::optional<torch::Tensor>& c) {
    // Currently, FP32 output only supports the accumulating form of the expression
    if (d.scalar_type() == torch::kFloat) {
        // Check presence first: dereferencing an empty optional is undefined behavior,
        // and callers are allowed to pass an FP32 `d` without any `c`
        DG_HOST_ASSERT(c.has_value());
        DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides());
    } else {
        DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
        DG_HOST_ASSERT(not c.has_value());

        // Accumulate in FP32 starting from zero, on the current CUDA stream
        const auto workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
        DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(),
                                              c10::cuda::getCurrentCUDAStream()));

        // Re-enter through the FP32 path, with the workspace acting as both output and accumulator
        bmk_bnk_mn(a, b, workspace, workspace);

        // This line has an implicit FP32-to-BF16 casting
        d.copy_(workspace);
        return;
    }

    DG_HOST_ASSERT(a.is_contiguous());
    DG_HOST_ASSERT(b.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());

    // The leading and trailing dims of `a` and `b` must agree
    const auto [s , m, k ] = get_shape<3>(a);
    const auto [s_, n, k_] = get_shape<3>(b);
    DG_HOST_ASSERT(s == s_ and k == k_);

    // Dispatch implementation
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else if (arch_major == 10) {
        sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
// BF16 kernel entry for the `bhr,hdr->bhd` expression.
// `A`, `B` and `D` are read as 3D tensors `[b, h, r]`, `[h, d, r]` and `[b, h, d]`; each must be
// BF16 with a unit stride on its last dim. `use_cublaslt` selects the cuBLASLt implementation
// instead of the architecture-specific one.
static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) {
    // Shape checks: every shared dim must agree across the three tensors
    const auto [b , h , r ] = get_shape<3>(A);
    const auto [h_, d , r_] = get_shape<3>(B);
    const auto [b_, h__, d_] = get_shape<3>(D);
    DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);

    // Type and layout checks
    DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
    DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
    DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);

    // Dispatch implementation: an explicit cuBLASLt request wins over the architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (use_cublaslt) {
        cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 9) {
        sm90_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 10) {
        sm100_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// BF16 kernel entry for the `bhd,hdr->bhr` expression.
// `A`, `B` and `D` are read as 3D tensors `[b, h, d]`, `[h, d, r]` and `[b, h, r]`; each must be
// BF16 with a unit stride on its last dim. `use_cublaslt` selects the cuBLASLt implementation
// instead of the architecture-specific one.
static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) {
    // Shape checks: every shared dim must agree across the three tensors
    const auto [b , h , d ] = get_shape<3>(A);
    const auto [h_, d_ , r ] = get_shape<3>(B);
    const auto [b_, h__, r_] = get_shape<3>(D);
    DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);

    // Type and layout checks
    DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
    DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
    DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);

    // Dispatch implementation: an explicit cuBLASLt request wins over the architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (use_cublaslt) {
        cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 9) {
        sm90_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 10) {
        sm100_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// BF16 einsum front-end: validates dtypes, then routes `expr` to one of the hardcoded kernels.
//   - `a`, `b`: BF16 inputs
//   - `d`: output, BF16 or FP32
//   - `c`: optional FP32 accumulator; when given, `d` must be FP32 as well
//   - `use_cublaslt`: forwarded to the `bhr,hdr->bhd` / `bhd,hdr->bhr` kernels,
//     rejected for `bmk,bnk->mn`
// Any other expression is reported as unsupported.
static void einsum(const std::string& expr,
                   const torch::Tensor& a,
                   const torch::Tensor& b,
                   const torch::Tensor& d,
                   const std::optional<torch::Tensor>& c,
                   const bool& use_cublaslt) {
    // Input/output dtype checks
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // An accumulator implies FP32 on both the accumulator and the output
    if (c.has_value()) {
        DG_HOST_ASSERT(c->scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
    }

    // Some hardcoded Einstein sum kernels
    // TODO: support any expression
    // TODO: canonicalize expression
    if (expr == "bmk,bnk->mn") {
        // This kernel has no cuBLASLt variant
        DG_HOST_ASSERT(not use_cublaslt);
        bmk_bnk_mn(a, b, d, c);
        return;
    }
    if (expr == "bhr,hdr->bhd") {
        // No accumulation for this expression
        DG_HOST_ASSERT(not c.has_value());
        bhr_hdr_bhd(a, b, d, use_cublaslt);
        return;
    }
    if (expr == "bhd,hdr->bhr") {
        // No accumulation for this expression
        DG_HOST_ASSERT(not c.has_value());
        bhd_hdr_bhr(a, b, d, use_cublaslt);
        return;
    }
    DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
}
// FP8 batched matrix multiplication with scaling factors.
//   - `a` / `sfa`: FP8 (e4m3) operand read as `[B, M, K]`, and its scaling factors
//   - `b` / `sfb`: FP8 (e4m3) operand read as `[B, N, K]`, and its scaling factors
//   - `d`: output read as `[B, M, N]`, BF16 or FP32, with a unit stride on its last dim
//   - `c`: optional accumulator, forwarded to the early-return check and to the kernels
//   - `recipe`: optional scaling recipe, forwarded to the scaling-factor layout transform
//   - `compiled_dims`: forwarded to the kernels unchanged
// `a` and `b` may each have their unit stride on either of the last two dims.
static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
                    const torch::Tensor& b, const torch::Tensor& sfb,
                    const torch::Tensor& d,
                    const std::optional<torch::Tensor>& c,
                    std::optional<std::tuple<int, int, int>> recipe,
                    const std::string& compiled_dims) {
    // Shape must be `[B, M, K] @ [B, N, K].T`
    // A unit stride on the last dim means K-major, otherwise the operand is treated as MN-major
    const auto major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
    const auto major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
    DG_HOST_ASSERT(a.stride(-1) == 1 or a.stride(-2) == 1);
    DG_HOST_ASSERT(b.stride(-1) == 1 or b.stride(-2) == 1);
    DG_HOST_ASSERT(d.stride(-1) == 1);

    // Type and shape checks
    const auto [batch_size  , m , k ] = get_shape<3>(a);
    const auto [batch_size_ , n , k_] = get_shape<3>(b);
    const auto [batch_size__, m_, n_] = get_shape<3>(d);
    // All three tensors must share the batch dim, including the output `d`
    DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Early return for trivial cases
    if (batch_size == 0 or gemm::early_return(m, n, k, d, c))
        return;

    // Transform scaling factors
    const auto [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
        sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false);

    // Dispatch implementation
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 10) {
        sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
    } else {
        // NOTE(review): every architecture other than 10 takes the SM90 path here — confirm that is intended
        const auto major_sfb = get_major_type_ab(sfb);
        // The SM90 kernel is only called with a K granularity of 128 on both operands
        DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);
        sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims);
    }
}
// FP8 einsum front-end: maps a hardcoded expression onto `fp8_bmm` by permuting dims.
//   - `a`, `b`: (tensor, scaling factors) pairs
//   - `d`: output tensor
//   - `c`: optional accumulator, permuted the same way as `d` where `d` is permuted
//   - `recipe`: scaling recipe forwarded to `fp8_bmm`
// `bhr,hdr->bhd` is accepted on every architecture; `bhd,hdr->bhr` and `bhd,bhr->hdr`
// only when the architecture major is 10. Anything else is reported as unsupported.
static void fp8_einsum(const std::string& expr,
                       const std::pair<torch::Tensor, torch::Tensor>& a,
                       const std::pair<torch::Tensor, torch::Tensor>& b,
                       const torch::Tensor& d,
                       const std::optional<torch::Tensor>& c,
                       const std::tuple<int, int, int>& recipe) {
    const auto arch_major = device_runtime->get_arch_major();
    const bool is_arch_10 = (arch_major == 10);

    // Unpack the (tensor, scaling factors) pairs
    const auto& [a_data, a_sf] = a;
    const auto& [b_data, b_sf] = b;

    // Swaps the two leading dims of the accumulator, if there is one
    const auto swap_leading_dims = [](const std::optional<torch::Tensor>& t) -> std::optional<torch::Tensor> {
        if (not t.has_value())
            return std::nullopt;
        return std::make_optional(t.value().permute({1, 0, 2}));
    };

    // Some hardcoded Einstein sum kernels
    if (expr == "bhr,hdr->bhd") {
        // Permute dims to satisfy the order of (batch_size, m, n, k)
        // (batch_size, m, n, k): (h, b, d, r)
        fp8_bmm(a_data.permute({1, 0, 2}), a_sf.permute({1, 0, 2}),
                b_data, b_sf,
                d.permute({1, 0, 2}), swap_leading_dims(c),
                recipe, "nk");
        return;
    }

    if (is_arch_10 and expr == "bhd,hdr->bhr") {
        // (batch_size, m, n, k): (h, b, r, d)
        fp8_bmm(a_data.permute({1, 0, 2}), a_sf.permute({1, 0, 2}),
                b_data.permute({0, 2, 1}), b_sf.permute({0, 2, 1}),
                d.permute({1, 0, 2}), swap_leading_dims(c),
                recipe, "nk");
        return;
    }

    if (is_arch_10 and expr == "bhd,bhr->hdr") {
        // (batch_size, m, n, k): (h, d, r, b)
        // The output and accumulator are passed through without permutation
        fp8_bmm(a_data.permute({1, 2, 0}), a_sf.permute({1, 2, 0}),
                b_data.permute({1, 2, 0}), b_sf.permute({1, 2, 0}),
                d, c,
                recipe, "mn");
        return;
    }

    DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
}
#endif
// Exposes the einsum entry points on the pybind11 module `m`.
// No bindings are added when `DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE` evaluates to false.
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
    // BF16 einsum: `c` defaults to none and `use_cublaslt` to false
    m.def(
        "einsum", &einsum,
        py::arg("expr"),
        py::arg("a"),
        py::arg("b"),
        py::arg("d"),
        py::arg("c") = std::nullopt,
        py::arg("use_cublaslt") = false);

    // FP8 einsum: `c` defaults to none and `recipe` to (1, 128, 128)
    m.def(
        "fp8_einsum", &fp8_einsum,
        py::arg("expr"),
        py::arg("a"),
        py::arg("b"),
        py::arg("d"),
        py::arg("c") = std::nullopt,
        py::arg("recipe") = std::make_tuple(1, 128, 128));
#endif
}
} // namespace deep_gemm::einsum