Add more GPU architectures support (#112)
* Add more GPU architectures support * Update layout.py * Optimize performance, Add SM90 support, Add 1D2D SM100 support * Add fmtlib submodule at commit 553ec11 --------- Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
25
csrc/utils/math.hpp
Normal file
25
csrc/utils/math.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Integer ceiling division: returns (a + b - 1) / b.
// For non-negative `a` and positive `b` this is the smallest q with q * b >= a.
// Declared constexpr so that constexpr callers (e.g. `align`) can actually be
// evaluated at compile time; the runtime result is unchanged.
// Note: the intermediate `a + b - 1` can overflow when `a` is close to the
// maximum value of T.
template <typename T>
static constexpr T ceil_div(const T& a, const T& b) {
    return (a + b - 1) / b;
}
|
||||
|
||||
// Rounds `a` up to a multiple of `b`: takes the ceil_div(a, b) quotient and
// scales it back by `b`.
template <typename T>
static constexpr T align(const T& a, const T& b) {
    const T num_chunks = ceil_div(a, b);
    return num_chunks * b;
}
|
||||
|
||||
static int get_tma_aligned_size(const int& x, const int& element_size) {
|
||||
constexpr int kNumTMAAlignmentBytes = 16;
|
||||
DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0);
|
||||
return align(x, kNumTMAAlignmentBytes / element_size);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
Reference in New Issue
Block a user