Files
nvfp4-megamoe-kernel/third_party/DeepGEMM/csrc/utils/math.hpp

30 lines
705 B
C++

// TODO: merge this file with `math.cuh` (the device part)
#pragma once
#include <torch/python.h>
#include "exception.hpp"
namespace deep_gemm {
// TODO: use `torch::kFloat4_e2m1fn_x2`
constexpr auto kPackedFP4 = torch::kInt8;
template <typename T>
static T ceil_div(const T& a, const T& b) {
return (a + b - 1) / b;
}
template <typename T>
static constexpr T align(const T& a, const T& b) {
return ceil_div(a, b) * b;
}
static int get_tma_aligned_size(const int& x, const int& element_size) {
constexpr int kNumTMAAlignmentBytes = 16;
DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0);
return align(x, kNumTMAAlignmentBytes / element_size);
}
} // namespace deep_gemm