Add more GPU architectures support (#112)

* Add more GPU architectures support

* Update layout.py

* Optimize performance, Add SM90 support, Add 1D2D SM100 support

* Add fmtlib submodule at commit 553ec11

---------

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
Ray Wang
2025-07-18 11:32:22 +08:00
committed by GitHub
parent 03d0be3d2d
commit 9da4a23561
67 changed files with 5586 additions and 2965 deletions

58
csrc/utils/exception.hpp Normal file
View File

@@ -0,0 +1,58 @@
#pragma once
#include <exception>
#include <string>
namespace deep_gemm {
class DGException final : public std::exception {
std::string message = {};
public:
explicit DGException(const char *name, const char* file, const int line, const std::string& error) {
message = std::string("Failed: ") + name + " error " + file + ":" + std::to_string(line) + " '" + error + "'";
}
const char *what() const noexcept override {
return message.c_str();
}
};
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
throw DGException("Assertion", __FILE__, __LINE__, #cond); \
} \
} while (0)
#endif
#ifndef DG_HOST_UNREACHABLE
#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason))
#endif
#ifndef DG_CUDA_DRIVER_CHECK
#define DG_CUDA_DRIVER_CHECK(cmd) \
do { \
const auto& e = (cmd); \
if (e != CUDA_SUCCESS) { \
throw DGException("CUDA driver", __FILE__, __LINE__, ""); \
} \
} while (0)
#endif
#ifndef DG_CUDA_RUNTIME_CHECK
#define DG_CUDA_RUNTIME_CHECK(cmd) \
do { \
const auto& e = (cmd); \
if (e != cudaSuccess) { \
throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
} \
} while (0)
#endif
} // namespace deep_gemm

6
csrc/utils/format.hpp Normal file
View File

@@ -0,0 +1,6 @@
#pragma once
// Just a wrapper for the `fmt` headers
#define FMT_HEADER_ONLY
#include <fmt/base.h>
#include <fmt/format.h>

35
csrc/utils/hash.hpp Normal file
View File

@@ -0,0 +1,35 @@
#pragma once
#include <string>
namespace deep_gemm {
static uint64_t fnv1a(const std::string& data, const uint64_t& seed) {
uint64_t h = seed;
const uint64_t& prime = 0x100000001b3ull;
for (const char& c: data) {
h ^= static_cast<uint8_t>(c);
h *= prime;
}
return h;
}
static std::string get_hex_digest(const std::string& data) {
const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull);
const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull);
// Split-mix 64
const auto& split_mix = [](uint64_t z) {
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull;
z = (z ^ (z >> 27)) * 0x94d049bb133111ebull;
return z ^ (z >> 31);
};
std::ostringstream oss;
oss << std::hex << std::setfill('0')
<< std::setw(16) << split_mix(state_0)
<< std::setw(16) << split_mix(state_1);
return oss.str();
}
} // namespace deep_gemm

100
csrc/utils/layout.hpp Normal file
View File

@@ -0,0 +1,100 @@
#pragma once
#include <cute/arch/mma_sm100_umma.hpp>
#include <torch/python.h>
#include "math.hpp"
#include "exception.hpp"
#include "../jit/device_runtime.hpp"
namespace deep_gemm {
// Major-ness stuffs
static void major_check(const torch::Tensor& t) {
const auto dim = t.dim();
DG_HOST_ASSERT(dim == 2 or dim == 3);
if (dim == 3)
DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1));
DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1);
}
static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) {
major_check(t);
return t.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
}
static void check_major_type_cd(const torch::Tensor& t) {
// NOTES: the library only supports row-major output layouts
major_check(t);
DG_HOST_ASSERT(t.stride(-1) == 1);
}
static bool fp8_requires_k_major() {
return device_runtime->get_arch_major() == 9;
}
// Tensor utils
template <int N>
static auto get_shape(const torch::Tensor& t) {
return [&t] <size_t... Is> (std::index_sequence<Is...>) {
return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
}(std::make_index_sequence<N>());
}
// Recipe
static std::tuple<int, int, int>
get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) {
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
return {1, 128, 128};
} else if (arch_major == 10) {
DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt);
return sfb_dtype == torch::kFloat ?
std::make_tuple(1, 128, 128): // Legacy format or 1D2D kernels
std::make_tuple(1, 1, 128); // 1D1D kernels
}
DG_HOST_UNREACHABLE("Unknown recipe");
}
// SF layouts
static torch::Tensor check_sf_layout(const torch::Tensor& sf,
const int& mn, const int& k,
const int& gran_mn, const int& gran_k,
const std::optional<int>& num_groups,
const bool& tma_stride_check = false,
const bool& contiguous_check = false,
const std::optional<torch::ScalarType>& type_check = std::nullopt) {
// Type check
if (type_check.has_value())
DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
// Always do shape checks
const auto& sf_dtype = sf.scalar_type();
DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
if (num_groups.has_value())
DG_HOST_ASSERT(sf.size(-3) == num_groups.value());
DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn));
DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)));
// TMA stride checks: TMA aligned and MN-major
if (tma_stride_check) {
if (num_groups.has_value())
DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1));
DG_HOST_ASSERT(sf.stride(-2) == 1);
DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()));
}
// Hopper SFB must be contiguous
if (contiguous_check)
DG_HOST_ASSERT(sf.is_contiguous());
return sf;
}
// Value matrix layout
static int get_mk_alignment_for_contiguous_layout() {
return 128;
}
} // namespace deep_gemm

25
csrc/utils/math.hpp Normal file
View File

@@ -0,0 +1,25 @@
#pragma once
#include <torch/python.h>
#include "exception.hpp"
namespace deep_gemm {
template <typename T>
static T ceil_div(const T& a, const T& b) {
return (a + b - 1) / b;
}
template <typename T>
static constexpr T align(const T& a, const T& b) {
return ceil_div(a, b) * b;
}
static int get_tma_aligned_size(const int& x, const int& element_size) {
constexpr int kNumTMAAlignmentBytes = 16;
DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0);
return align(x, kNumTMAAlignmentBytes / element_size);
}
} // namespace deep_gemm

70
csrc/utils/system.hpp Normal file
View File

@@ -0,0 +1,70 @@
#pragma once
#include <random>
#include <string>
#include <memory>
#include "exception.hpp"
namespace deep_gemm {
// ReSharper disable once CppNotAllPathsReturnValue
template <typename dtype_t>
static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) {
const auto& c_str = std::getenv(name.c_str());
if (c_str == nullptr)
return default_value;
// Read the env and convert to the desired type
if constexpr (std::is_same_v<dtype_t, std::string>) {
return std::string(c_str);
} else if constexpr (std::is_same_v<dtype_t, int>) {
int value;
std::sscanf(c_str, "%d", &value);
return value;
} else {
DG_HOST_ASSERT(false and "Unexpected type");
}
}
static std::tuple<int, std::string> call_external_command(std::string command) {
command = command + " 2>&1";
const auto& deleter = [](FILE* f) { if (f) pclose(f); };
std::unique_ptr<FILE, decltype(deleter)> pipe(popen(command.c_str(), "r"), deleter);
DG_HOST_ASSERT(pipe != nullptr);
std::array<char, 512> buffer;
std::string output;
while (fgets(buffer.data(), buffer.size(), pipe.get()))
output += buffer.data();
const auto& exit_code = WEXITSTATUS(pclose(pipe.release()));
return {exit_code, output};
}
static std::filesystem::path make_dirs(const std::filesystem::path& path) {
// OK if existed
std::error_code capture;
const bool& created = std::filesystem::create_directories(path, capture);
DG_HOST_ASSERT(created or capture.value() == 0);
if (created and get_env<int>("DG_JIT_DEBUG"))
printf("Create directory: %s\n", path.c_str());
return path;
}
static std::string get_uuid() {
static std::random_device rd;
static std::mt19937 gen([]() {
return rd() ^ std::chrono::steady_clock::now().time_since_epoch().count();
}());
static std::uniform_int_distribution<uint32_t> dist;
std::stringstream ss;
ss << getpid() << "-"
<< std::hex << std::setfill('0')
<< std::setw(8) << dist(gen) << "-"
<< std::setw(8) << dist(gen) << "-"
<< std::setw(8) << dist(gen);
return ss.str();
}
} // deep_gemm