Add more GPU architectures support (#112)
* Add more GPU architectures support
* Update layout.py
* Optimize performance, Add SM90 support, Add 1D2D SM100 support
* Add fmtlib submodule at commit 553ec11

---------

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
58
csrc/utils/exception.hpp
Normal file
58
csrc/utils/exception.hpp
Normal file
@@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Exception type thrown by the `DG_*` check macros.
// The text returned by `what()` has the form:
//   Failed: <name> error <file>:<line> '<error>'
class DGException final : public std::exception {
public:
    // `name`: category of the failure (e.g. "Assertion"), `file`/`line`: source location,
    // `error`: free-form detail text placed between single quotes
    explicit DGException(const char* name, const char* file, const int line, const std::string& error) {
        message += "Failed: ";
        message += name;
        message += " error ";
        message += file;
        message += ":";
        message += std::to_string(line);
        message += " '";
        message += error;
        message += "'";
    }

    // The returned pointer refers to storage owned by this exception object
    const char* what() const noexcept override {
        return message.c_str();
    }

private:
    std::string message = {};
};
|
||||
|
||||
// Compile-time assertion: forwards the condition and the remaining arguments (the message)
// to `static_assert`. A definition that already exists is left untouched.
#if !defined(DG_STATIC_ASSERT)
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
|
||||
|
||||
// Host-side runtime assertion: when `cond` is false, throws `DGException` tagged "Assertion"
// with the current file/line and the stringified condition as the detail text.
// Wrapped in `do { ... } while` so that it behaves as a single statement.
#if !defined(DG_HOST_ASSERT)
#define DG_HOST_ASSERT(cond)                                          \
do {                                                                  \
    if (!(cond))                                                      \
        throw DGException("Assertion", __FILE__, __LINE__, #cond);    \
} while (false)
#endif
|
||||
|
||||
// Marks a code path that must never execute: unconditionally throws `DGException`
// (tagged "Assertion") with `reason` as the detail text.
// Kept as a parenthesized throw-expression, so it expands to an expression rather than a statement.
#if !defined(DG_HOST_UNREACHABLE)
#define DG_HOST_UNREACHABLE(reason) \
    (throw DGException("Assertion", __FILE__, __LINE__, reason))
#endif
|
||||
|
||||
// Evaluates `cmd` once and throws `DGException` (tagged "CUDA driver") unless the result equals
// `CUDA_SUCCESS`. The numeric status code is reported as the detail text, mirroring
// `DG_CUDA_RUNTIME_CHECK`, so that a failure can be identified from the message alone.
#ifndef DG_CUDA_DRIVER_CHECK
#define DG_CUDA_DRIVER_CHECK(cmd) \
do { \
    const auto& e = (cmd); \
    if (e != CUDA_SUCCESS) { \
        throw DGException("CUDA driver", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
    } \
} while (0)
#endif
|
||||
|
||||
// Evaluates `cmd` once and throws `DGException` (tagged "CUDA runtime") unless the result equals
// `cudaSuccess`; the detail text is the status code converted to a decimal integer string.
#if !defined(DG_CUDA_RUNTIME_CHECK)
#define DG_CUDA_RUNTIME_CHECK(cmd)                                                                     \
do {                                                                                                   \
    const auto& e = (cmd);                                                                             \
    if (e != cudaSuccess)                                                                              \
        throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast<int>(e)));   \
} while (false)
#endif
|
||||
|
||||
} // namespace deep_gemm
|
||||
6
csrc/utils/format.hpp
Normal file
6
csrc/utils/format.hpp
Normal file
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
// Just a wrapper for the `fmt` headers
|
||||
#define FMT_HEADER_ONLY
|
||||
#include <fmt/base.h>
|
||||
#include <fmt/format.h>
|
||||
35
csrc/utils/hash.hpp
Normal file
35
csrc/utils/hash.hpp
Normal file
@@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// 64-bit FNV-1a style hash of the bytes in `data`.
// `seed` is used as the initial state, so different seeds give independent hash streams;
// an empty `data` returns `seed` unchanged.
static uint64_t fnv1a(const std::string& data, const uint64_t& seed) {
    constexpr uint64_t kPrime = 0x100000001b3ull;

    uint64_t hash = seed;
    for (size_t i = 0; i < data.size(); ++ i) {
        // XOR in the next byte (as unsigned), then multiply by the prime (wraps modulo 2^64)
        hash = (hash ^ static_cast<uint8_t>(data[i])) * kPrime;
    }
    return hash;
}
|
||||
|
||||
static std::string get_hex_digest(const std::string& data) {
|
||||
const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull);
|
||||
const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull);
|
||||
|
||||
// Split-mix 64
|
||||
const auto& split_mix = [](uint64_t z) {
|
||||
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull;
|
||||
z = (z ^ (z >> 27)) * 0x94d049bb133111ebull;
|
||||
return z ^ (z >> 31);
|
||||
};
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << std::hex << std::setfill('0')
|
||||
<< std::setw(16) << split_mix(state_0)
|
||||
<< std::setw(16) << split_mix(state_1);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
100
csrc/utils/layout.hpp
Normal file
100
csrc/utils/layout.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/mma_sm100_umma.hpp>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "math.hpp"
|
||||
#include "exception.hpp"
|
||||
#include "../jit/device_runtime.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Major-ness utilities
|
||||
// Validates the stride layout of a 2D or 3D tensor (throws `DGException` through `DG_HOST_ASSERT`):
//  - 3D tensors must have a leading stride equal to the element count of one trailing 2D slice;
//  - one of the two trailing dimensions must have unit stride.
// NOTES: the assertion expressions are part of the error message, keep their text stable
static void major_check(const torch::Tensor& t) {
    const auto dim = t.dim();
    DG_HOST_ASSERT(dim == 2 or dim == 3);

    // Batched case: slices must be packed back-to-back
    if (dim == 3) {
        DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1));
    }

    // Either of the last two dimensions must be the innermost one
    DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1);
}
|
||||
|
||||
// Classifies an input operand after validating it with `major_check`:
// unit stride on the last dimension means `Major::K`, otherwise `Major::MN`.
static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) {
    major_check(t);
    if (t.stride(-1) == 1)
        return cute::UMMA::Major::K;
    return cute::UMMA::Major::MN;
}
|
||||
|
||||
// Validates an output tensor: it must pass `major_check` and additionally
// have unit stride on its last dimension.
static void check_major_type_cd(const torch::Tensor& t) {
    major_check(t);

    // NOTES: the library only supports row-major output layouts
    DG_HOST_ASSERT(t.stride(-1) == 1);
}
|
||||
|
||||
static bool fp8_requires_k_major() {
|
||||
return device_runtime->get_arch_major() == 9;
|
||||
}
|
||||
|
||||
// Tensor utils
|
||||
template <int N>
|
||||
static auto get_shape(const torch::Tensor& t) {
|
||||
return [&t] <size_t... Is> (std::index_sequence<Is...>) {
|
||||
return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
|
||||
}(std::make_index_sequence<N>());
|
||||
}
|
||||
|
||||
// Recipe
|
||||
static std::tuple<int, int, int>
|
||||
get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) {
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
|
||||
return {1, 128, 128};
|
||||
} else if (arch_major == 10) {
|
||||
DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt);
|
||||
return sfb_dtype == torch::kFloat ?
|
||||
std::make_tuple(1, 128, 128): // Legacy format or 1D2D kernels
|
||||
std::make_tuple(1, 1, 128); // 1D1D kernels
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unknown recipe");
|
||||
}
|
||||
|
||||
// SF layouts
|
||||
// Validates the dtype, shape and (optionally) strides of the SF tensor `sf`, then returns it unchanged.
// Any violated check throws `DGException` through `DG_HOST_ASSERT`.
//  - `mn`, `k`: extents covered by `sf`; `gran_mn`, `gran_k`: granularity along each of them
//  - `num_groups`: when set, `sf` must carry one extra leading dimension of that size
//  - `tma_stride_check`: additionally require an MN-major, TMA-aligned stride layout
//  - `contiguous_check`: additionally require `sf` to be contiguous
//  - `type_check`: when set, `sf` must have exactly this scalar type
// NOTE(review): "SF" presumably stands for scale factors, and the `* 4` for non-float dtypes
// for four values packed per element, confirm against the kernels
static torch::Tensor check_sf_layout(const torch::Tensor& sf,
                                     const int& mn, const int& k,
                                     const int& gran_mn, const int& gran_k,
                                     const std::optional<int>& num_groups,
                                     const bool& tma_stride_check = false,
                                     const bool& contiguous_check = false,
                                     const std::optional<torch::ScalarType>& type_check = std::nullopt) {
    // Type check
    if (type_check) {
        DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
    }

    // Always do shape checks
    const auto& sf_dtype = sf.scalar_type();
    DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
    DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
    if (num_groups) {
        DG_HOST_ASSERT(sf.size(-3) == num_groups.value());
    }
    DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn));
    DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)));

    // TMA stride checks: TMA aligned and MN-major
    if (tma_stride_check) {
        if (num_groups) {
            DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1));
        }
        DG_HOST_ASSERT(sf.stride(-2) == 1);
        DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()));
    }

    // Hopper SFB must be contiguous
    if (contiguous_check) {
        DG_HOST_ASSERT(sf.is_contiguous());
    }
    return sf;
}
|
||||
|
||||
// Value matrix layout
|
||||
// Alignment used by the contiguous layout; currently a fixed constant
static int get_mk_alignment_for_contiguous_layout() {
    constexpr int kAlignment = 128;
    return kAlignment;
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
25
csrc/utils/math.hpp
Normal file
25
csrc/utils/math.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Integer division rounded up: the smallest `q` with `q * b >= a`.
// Intended for a non-negative `a` and a positive `b` (`a + b - 1` must not overflow `T`).
// `constexpr` so that callers declared `constexpr` (e.g. `align`) can really be evaluated at compile time.
template <typename T>
static constexpr T ceil_div(const T& a, const T& b) {
    return (a + b - 1) / b;
}
|
||||
|
||||
// Rounds `a` up to the nearest multiple of `b` (see `ceil_div` for the rounding)
template <typename T>
static constexpr T align(const T& a, const T& b) {
    const T num_chunks = ceil_div(a, b);
    return num_chunks * b;
}
|
||||
|
||||
static int get_tma_aligned_size(const int& x, const int& element_size) {
|
||||
constexpr int kNumTMAAlignmentBytes = 16;
|
||||
DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0);
|
||||
return align(x, kNumTMAAlignmentBytes / element_size);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
70
csrc/utils/system.hpp
Normal file
70
csrc/utils/system.hpp
Normal file
@@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include <sys/wait.h>
#include <unistd.h>

#include <array>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <iomanip>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <system_error>
#include <tuple>
#include <type_traits>

#include "exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
template <typename dtype_t>
|
||||
static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) {
|
||||
const auto& c_str = std::getenv(name.c_str());
|
||||
if (c_str == nullptr)
|
||||
return default_value;
|
||||
|
||||
// Read the env and convert to the desired type
|
||||
if constexpr (std::is_same_v<dtype_t, std::string>) {
|
||||
return std::string(c_str);
|
||||
} else if constexpr (std::is_same_v<dtype_t, int>) {
|
||||
int value;
|
||||
std::sscanf(c_str, "%d", &value);
|
||||
return value;
|
||||
} else {
|
||||
DG_HOST_ASSERT(false and "Unexpected type");
|
||||
}
|
||||
}
|
||||
|
||||
// Runs `command` through `popen` with stderr redirected into stdout (" 2>&1" is appended)
// and returns `{exit_code, captured_output}`.
//  - Normal termination: `exit_code` is the exit status of the command.
//  - Termination by a signal: `exit_code` is `128 + signal number` (never zero), so that a killed
//    command cannot be mistaken for a successful one.
//  - `pclose` failure or any other abnormal status: `exit_code` is -1.
// Throws `DGException` (through `DG_HOST_ASSERT`) if the pipe cannot be opened.
static std::tuple<int, std::string> call_external_command(std::string command) {
    command = command + " 2>&1";
    // The deleter only matters if an exception escapes before the explicit `pclose` below
    const auto& deleter = [](FILE* f) { if (f) pclose(f); };
    std::unique_ptr<FILE, decltype(deleter)> pipe(popen(command.c_str(), "r"), deleter);
    DG_HOST_ASSERT(pipe != nullptr);

    std::array<char, 512> buffer;
    std::string output;
    while (fgets(buffer.data(), static_cast<int>(buffer.size()), pipe.get()))
        output += buffer.data();

    // Close explicitly to obtain the wait status; decode it only with the matching macros,
    // as `WEXITSTATUS` is meaningful only when `WIFEXITED` holds
    const int status = pclose(pipe.release());
    int exit_code = -1;
    if (status != -1) {
        if (WIFEXITED(status))
            exit_code = WEXITSTATUS(status);
        else if (WIFSIGNALED(status))
            exit_code = 128 + WTERMSIG(status);
    }
    return {exit_code, output};
}
|
||||
|
||||
static std::filesystem::path make_dirs(const std::filesystem::path& path) {
|
||||
// OK if existed
|
||||
std::error_code capture;
|
||||
const bool& created = std::filesystem::create_directories(path, capture);
|
||||
DG_HOST_ASSERT(created or capture.value() == 0);
|
||||
if (created and get_env<int>("DG_JIT_DEBUG"))
|
||||
printf("Create directory: %s\n", path.c_str());
|
||||
return path;
|
||||
}
|
||||
|
||||
// Returns an identifier of the form `<pid>-<8 hex>-<8 hex>-<8 hex>`: the decimal process id
// followed by three zero-padded 32-bit random values in lowercase hexadecimal.
// The generator is a function-local static, seeded once from `std::random_device`
// mixed with the steady clock.
static std::string get_uuid() {
    static std::random_device rd;
    static std::mt19937 gen([]() {
        const auto ticks = std::chrono::steady_clock::now().time_since_epoch().count();
        return rd() ^ ticks;
    }());
    static std::uniform_int_distribution<uint32_t> dist;

    constexpr int kNumRandomParts = 3;

    std::stringstream ss;
    // The process id is written before switching the stream to hexadecimal
    ss << getpid() << "-";
    ss << std::hex << std::setfill('0');
    for (int i = 0; i < kNumRandomParts; ++ i) {
        if (i > 0)
            ss << "-";
        ss << std::setw(8) << dist(gen);
    }
    return ss.str();
}
|
||||
|
||||
} // deep_gemm
|
||||
Reference in New Issue
Block a user