Add more GPU architectures support (#112)
* Add more GPU architectures support * Update layout.py * Optimize performance, Add SM90 support, Add 1D2D SM100 support * Add fmtlib submodule at commit 553ec11 --------- Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
13
csrc/indexing/main.cu
Normal file
13
csrc/indexing/main.cu
Normal file
@@ -0,0 +1,13 @@
|
||||
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
|
||||
// Entry point that performs no work and reports success.
// NOTE(review): presumably this translation unit exists only so tooling parses the
// kernel headers it includes (the file lives under `csrc/indexing`) — confirm.
int main()
{
    return 0;
}
|
||||
31
csrc/jit/cache.hpp
Normal file
31
csrc/jit/cache.hpp
Normal file
@@ -0,0 +1,31 @@
|
||||
#pragma once
|
||||
|
||||
#include <filesystem>
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "kernel_runtime.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class KernelRuntimeCache {
|
||||
std::unordered_map<std::filesystem::path, std::shared_ptr<KernelRuntime>> cache;
|
||||
|
||||
public:
|
||||
// TODO: consider cache capacity
|
||||
KernelRuntimeCache() = default;
|
||||
|
||||
std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
|
||||
// Hit the runtime cache
|
||||
if (const auto& iterator = cache.find(dir_path); iterator != cache.end())
|
||||
return iterator->second;
|
||||
|
||||
if (KernelRuntime::check_validity(dir_path))
|
||||
return cache[dir_path] = std::make_shared<KernelRuntime>(dir_path);
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
// Cache instance queried by `Compiler::build` before and after compiling a kernel.
// NOTE(review): `static` at namespace scope in a header gives every including translation
// unit its own instance — confirm that a per-TU cache is intended.
static auto kernel_runtime_cache = std::make_shared<KernelRuntimeCache>();
|
||||
|
||||
} // namespace deep_gemm
|
||||
172
csrc/jit/compiler.hpp
Normal file
172
csrc/jit/compiler.hpp
Normal file
@@ -0,0 +1,172 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
#include <algorithm>
#include <cstdio>
#include <filesystem>
#include <fstream>
#include <regex>
#include <sstream>
#include <string>
#include <vector>

#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/hash.hpp"
#include "../utils/system.hpp"
#include "cache.hpp"
#include "device_runtime.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class Compiler {
|
||||
std::string library_version;
|
||||
std::filesystem::path library_root_path;
|
||||
|
||||
std::string get_library_version() const {
|
||||
// Recursively walk through all subdirectories and update hash
|
||||
std::stringstream ss;
|
||||
for (const auto& entry: std::filesystem::recursive_directory_iterator(library_include_path / "deep_gemm")) {
|
||||
if (entry.is_regular_file() and entry.path().extension() == ".cuh") {
|
||||
std::ifstream file(entry.path(), std::ios::binary);
|
||||
std::string content((std::istreambuf_iterator(file)), std::istreambuf_iterator<char>());
|
||||
ss << content;
|
||||
}
|
||||
}
|
||||
return get_hex_digest(ss.str());
|
||||
}
|
||||
|
||||
public:
|
||||
std::string signature, flags;
|
||||
std::filesystem::path library_include_path;
|
||||
std::filesystem::path cache_dir_path;
|
||||
|
||||
explicit Compiler(const std::filesystem::path& library_root_path) {
|
||||
// Static library paths
|
||||
this->library_root_path = library_root_path;
|
||||
this->library_include_path = library_root_path / "include";
|
||||
this->library_version = get_library_version();
|
||||
|
||||
// Cache settings
|
||||
cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
|
||||
if (const auto& env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty())
|
||||
cache_dir_path = env_cache_dir_path;
|
||||
|
||||
// The compiler flags applied to all derived compilers
|
||||
signature = "unknown-compiler";
|
||||
std::string ptxas_flags = "--ptxas-options=--register-usage-level=10";
|
||||
if (get_env<int>("DG_JIT_PTXAS_VERBOSE", 0))
|
||||
ptxas_flags += ",--verbose";
|
||||
flags = fmt::format("-std=c++20 --diag-suppress=39,161,174,177,186,940 {}", ptxas_flags);
|
||||
}
|
||||
|
||||
virtual ~Compiler() = default;
|
||||
|
||||
std::filesystem::path make_tmp_dir() const {
|
||||
return make_dirs(cache_dir_path / "tmp");
|
||||
}
|
||||
|
||||
std::filesystem::path get_tmp_file_path() const {
|
||||
return make_tmp_dir() / get_uuid();
|
||||
}
|
||||
|
||||
void put(const std::filesystem::path& path, const std::string& data) const {
|
||||
const auto tmp_file_path = get_tmp_file_path();
|
||||
|
||||
// Write into the temporary file
|
||||
std::ofstream out(tmp_file_path, std::ios::binary);
|
||||
DG_HOST_ASSERT(out.write(data.data(), data.size()));
|
||||
out.close();
|
||||
|
||||
// Atomically replace
|
||||
std::filesystem::rename(tmp_file_path, path);
|
||||
}
|
||||
|
||||
std::shared_ptr<KernelRuntime> build(const std::string& name, const std::string& code) const {
|
||||
const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code);
|
||||
const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature));
|
||||
|
||||
// Hit the runtime cache
|
||||
if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr)
|
||||
return runtime;
|
||||
|
||||
// Create the kernel directory
|
||||
make_dirs(dir_path);
|
||||
|
||||
// Compile into a temporary CUBIN
|
||||
const auto tmp_cubin_path = get_tmp_file_path();
|
||||
compile(code, dir_path, tmp_cubin_path);
|
||||
|
||||
// Replace into the cache directory
|
||||
make_dirs(dir_path);
|
||||
std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin");
|
||||
|
||||
// Put into the runtime cache
|
||||
const auto& runtime = kernel_runtime_cache->get(dir_path);
|
||||
DG_HOST_ASSERT(runtime != nullptr);
|
||||
return runtime;
|
||||
}
|
||||
|
||||
virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0;
|
||||
};
|
||||
|
||||
class NVCCCompiler final: public Compiler {
|
||||
std::filesystem::path nvcc_path;
|
||||
|
||||
std::pair<int, int> get_nvcc_version() const {
|
||||
DG_HOST_ASSERT(std::filesystem::exists(nvcc_path));
|
||||
|
||||
// Call the version command
|
||||
const auto& command = std::string(nvcc_path) + " --version";
|
||||
const auto& [return_code, output] = call_external_command(command);
|
||||
DG_HOST_ASSERT(return_code == 0);
|
||||
|
||||
// The version should be at least 12.3, for the best performance with 12.9
|
||||
int major, minor;
|
||||
std::smatch match;
|
||||
DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))")));
|
||||
std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor);
|
||||
DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3");
|
||||
if (major < 12 or (major == 12 and minor < 9))
|
||||
printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance");
|
||||
return {major, minor};
|
||||
}
|
||||
|
||||
public:
|
||||
NVCCCompiler(const std::filesystem::path& library_root_path,
|
||||
const std::filesystem::path& cuda_home_path_by_torch):
|
||||
Compiler(library_root_path) {
|
||||
// Override the compiler signature
|
||||
nvcc_path = cuda_home_path_by_torch / "bin" / "nvcc";
|
||||
if (const auto& env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty())
|
||||
nvcc_path = env_nvcc_path;
|
||||
const auto& [nvcc_major, nvcc_minor] = get_nvcc_version();
|
||||
signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
|
||||
|
||||
// The override the compiler flags
|
||||
flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a "
|
||||
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
|
||||
"-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
|
||||
flags, library_include_path.c_str(), device_runtime->get_arch());
|
||||
}
|
||||
|
||||
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
|
||||
// Write the code into the cache directory
|
||||
const auto& code_path = dir_path / "kernel.cu";
|
||||
put(code_path, code);
|
||||
|
||||
// Compile
|
||||
const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
|
||||
printf("Running NVCC command: %s", command.c_str());
|
||||
const auto& [return_code, output] = call_external_command(command);
|
||||
if (return_code != 0) {
|
||||
printf("NVCC compilation failed: %s", output.c_str());
|
||||
DG_HOST_ASSERT(false and "NVCC compilation failed");
|
||||
}
|
||||
|
||||
// Print PTXAS log
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
|
||||
printf("%s", output.c_str());
|
||||
}
|
||||
};
|
||||
|
||||
// Compiler instance slot; starts out null.
// NOTE(review): presumably assigned during library initialization elsewhere — confirm.
// Being `static` in a header, each including translation unit gets its own pointer.
static std::shared_ptr<Compiler> compiler = nullptr;
|
||||
|
||||
} // namespace deep_gemm
|
||||
50
csrc/jit/device_runtime.hpp
Normal file
50
csrc/jit/device_runtime.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class DeviceRuntime {
|
||||
int num_sms = 0;
|
||||
std::shared_ptr<cudaDeviceProp> cached_prop;
|
||||
|
||||
public:
|
||||
explicit DeviceRuntime() = default;
|
||||
|
||||
std::shared_ptr<cudaDeviceProp> get_prop() {
|
||||
if (cached_prop == nullptr)
|
||||
cached_prop = std::make_shared<cudaDeviceProp>(*at::cuda::getCurrentDeviceProperties());
|
||||
return cached_prop;
|
||||
}
|
||||
|
||||
std::pair<int, int> get_arch_pair() {
|
||||
const auto prop = get_prop();
|
||||
return {prop->major, prop->minor};
|
||||
}
|
||||
|
||||
int get_arch() {
|
||||
const auto& [major, minor] = get_arch_pair();
|
||||
return major * 10 + minor;
|
||||
}
|
||||
|
||||
int get_arch_major() {
|
||||
return get_arch_pair().first;
|
||||
}
|
||||
|
||||
void set_num_sms(const int& new_num_sms) {
|
||||
DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount);
|
||||
num_sms = new_num_sms;
|
||||
}
|
||||
|
||||
int get_num_sms() {
|
||||
if (num_sms == 0)
|
||||
num_sms = get_prop()->multiProcessorCount;
|
||||
return num_sms;
|
||||
}
|
||||
};
|
||||
|
||||
// Device runtime instance; `NVCCCompiler` reads the target architecture from it.
// NOTE(review): `static` in a header means one instance per including translation unit — confirm intent.
static auto device_runtime = std::make_shared<DeviceRuntime>();
|
||||
|
||||
} // namespace deep_gemm
|
||||
139
csrc/jit/kernel_runtime.hpp
Normal file
139
csrc/jit/kernel_runtime.hpp
Normal file
@@ -0,0 +1,139 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
#include "../utils/format.hpp"
|
||||
#include "../utils/system.hpp"
|
||||
#include "device_runtime.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Launch geometry of a kernel: 2D grid, threads per block, dynamic shared memory size
// and cluster dimension.
struct LaunchArgs {
    std::pair<int, int> grid_dim;
    int num_threads;
    int smem_size;
    int cluster_dim;

    // 1D grid: the second grid dimension is fixed to 1.
    LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
        LaunchArgs(std::pair<int, int>(grid_dim_x, 1), num_threads, smem_size, cluster_dim) {}

    // 2D grid given as `{x, y}`.
    LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
        grid_dim(grid_dim),
        num_threads(num_threads),
        smem_size(smem_size),
        cluster_dim(cluster_dim) {}
};
|
||||
|
||||
template <typename T>
|
||||
concept HasLaunchArgs = requires (const T& t) {
|
||||
{ t.launch_args } -> std::convertible_to<decltype(t.launch_args)>;
|
||||
};
|
||||
|
||||
class KernelRuntime final {
|
||||
public:
|
||||
static std::filesystem::path cuda_home;
|
||||
|
||||
cudaLibrary_t library;
|
||||
cudaKernel_t kernel;
|
||||
|
||||
explicit KernelRuntime(const std::filesystem::path& dir_path) {
|
||||
// NOLINT(*-pro-type-member-init)
|
||||
const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump";
|
||||
const auto& cubin_path = dir_path / "kernel.cubin";
|
||||
if (get_env<int>("DG_JIT_DEBUG"))
|
||||
printf("Loading CUBIN: %s\n", cubin_path.c_str());
|
||||
|
||||
// Find the only symbol
|
||||
// TODO: use kernel enumeration for newer drivers
|
||||
const std::vector<std::string> illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"};
|
||||
const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str()));
|
||||
DG_HOST_ASSERT(exit_code == 0);
|
||||
std::istringstream iss(symbols);
|
||||
std::vector<std::string> symbol_names;
|
||||
for (std::string line; std::getline(iss, line); ) {
|
||||
if (line.find("STT_FUNC") == 0 and std::ranges::none_of(illegal_names, [&](const auto& name) { return line.find(name) != std::string::npos; })) {
|
||||
const auto& last_space = line.rfind(' ');
|
||||
symbol_names.push_back(line.substr(last_space + 1));
|
||||
}
|
||||
}
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
printf("Symbol names: ");
|
||||
for (const auto& symbol: symbol_names)
|
||||
printf("%s, ", symbol.c_str());
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// Load from the library
|
||||
DG_HOST_ASSERT(symbol_names.size() == 1);
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, symbol_names[0].c_str()));
|
||||
}
|
||||
|
||||
static void set_cuda_home(const std::string& cuda_home_path_by_torch) {
|
||||
cuda_home = cuda_home_path_by_torch;
|
||||
}
|
||||
|
||||
static bool check_validity(const std::filesystem::path& dir_path) {
|
||||
return std::filesystem::exists(dir_path / "kernel.cu") and
|
||||
std::filesystem::exists(dir_path / "kernel.cubin");
|
||||
}
|
||||
|
||||
~KernelRuntime() noexcept(false) {
|
||||
const auto& error = cudaLibraryUnload(library);
|
||||
DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading);
|
||||
}
|
||||
};
|
||||
|
||||
// Declare after defining
|
||||
decltype(KernelRuntime::cuda_home) KernelRuntime::cuda_home;
|
||||
|
||||
template <typename Derived>
|
||||
class LaunchRuntime {
|
||||
public:
|
||||
template <typename Args> requires HasLaunchArgs<Args>
|
||||
static std::string generate(const Args& args) {
|
||||
const auto& code = Derived::generate_impl(args);
|
||||
if (get_env<int>("DG_JIT_DEBUG", 0))
|
||||
printf("Generated kernel code: %s\n", code.c_str());
|
||||
return code;
|
||||
}
|
||||
|
||||
template <typename Args> requires HasLaunchArgs<Args>
|
||||
static void launch(const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) {
|
||||
const auto& kernel = kernel_runtime->kernel;
|
||||
const auto& stream = at::cuda::getCurrentCUDAStream();
|
||||
const LaunchArgs& launch_args = args.launch_args;
|
||||
|
||||
// Set dynamic shared memory size
|
||||
if (launch_args.smem_size > 0)
|
||||
DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, launch_args.smem_size));
|
||||
|
||||
// Launch config
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = {static_cast<unsigned>(launch_args.grid_dim.first),
|
||||
static_cast<unsigned>(launch_args.grid_dim.second),
|
||||
1};
|
||||
config.blockDim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
|
||||
config.dynamicSmemBytes = launch_args.smem_size;
|
||||
config.stream = stream;
|
||||
config.numAttrs = 0;
|
||||
|
||||
// Clusters
|
||||
cudaLaunchAttribute attr;
|
||||
if (launch_args.cluster_dim > 1) {
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {static_cast<unsigned>(launch_args.cluster_dim), 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
}
|
||||
|
||||
// Launch in the derived class
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n",
|
||||
launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads,
|
||||
launch_args.smem_size, launch_args.cluster_dim, stream.id());
|
||||
}
|
||||
Derived::launch_impl(kernel, config, args);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
298
csrc/jit_kernels/heuristics/common.hpp
Normal file
298
csrc/jit_kernels/heuristics/common.hpp
Normal file
@@ -0,0 +1,298 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct MulticastConfig {
|
||||
int num_multicast;
|
||||
bool is_multicast_on_a;
|
||||
|
||||
MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a):
|
||||
num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) {
|
||||
DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2);
|
||||
}
|
||||
};
|
||||
|
||||
// Shared memory budget and swizzle modes selected for a kernel.
// Every member has a default so that a default-constructed instance (as declared in
// `get_best_config` before the stage search) never holds indeterminate values; the type
// stays an aggregate, so designated initialization keeps working.
struct SharedMemoryConfig {
    // Total shared memory size in bytes
    int smem_size = 0;
    // Swizzle modes for the A, B and C/D tiles
    int swizzle_a_mode = 0;
    int swizzle_b_mode = 0;
    int swizzle_cd_mode = 0;
};
|
||||
|
||||
// Thread counts of a kernel. Only the fields of the architecture used to build the
// config are filled in; the other architecture's fields stay zero.
struct ThreadConfig {
    int num_threads;

    // SM90
    int num_tma_threads;
    int num_math_threads;

    // SM100
    int num_non_epilogue_threads;
    int num_epilogue_threads;

    // SM90 layout: total = TMA threads + math threads.
    static ThreadConfig sm90(const int& num_tma_threads,
                             const int& num_math_threads) {
        ThreadConfig result{};
        result.num_tma_threads = num_tma_threads;
        result.num_math_threads = num_math_threads;
        result.num_threads = num_tma_threads + num_math_threads;
        return result;
    }

    // SM100 layout: total = non-epilogue threads + epilogue threads.
    static ThreadConfig sm100(const int& num_non_epilogue_threads,
                              const int& num_epilogue_threads) {
        ThreadConfig result{};
        result.num_non_epilogue_threads = num_non_epilogue_threads;
        result.num_epilogue_threads = num_epilogue_threads;
        result.num_threads = num_non_epilogue_threads + num_epilogue_threads;
        return result;
    }
};
|
||||
|
||||
struct GemmConfig {
|
||||
// Templated configs
|
||||
GemmType gemm_type;
|
||||
KernelType kernel_type;
|
||||
at::ScalarType ab_dtype, cd_dtype;
|
||||
cute::UMMA::Major major_a;
|
||||
cute::UMMA::Major major_b;
|
||||
bool with_accumulation;
|
||||
int block_m, block_n, block_k;
|
||||
int num_stages, num_last_stages;
|
||||
|
||||
// Runtime configs
|
||||
int num_sms;
|
||||
|
||||
// Structured configs
|
||||
MulticastConfig multicast_config;
|
||||
SharedMemoryConfig smem_config;
|
||||
ThreadConfig thread_config;
|
||||
};
|
||||
|
||||
static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
|
||||
const int& num_multicast, const int& num_sms,
|
||||
const bool& require_divisible) {
|
||||
const bool& divisible = ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible;
|
||||
return divisible and num_sms % num_multicast == 0;
|
||||
}
|
||||
|
||||
// Returns the largest of 128/64/32/16 that divides `block_size * elem_size`.
// Hits `DG_HOST_UNREACHABLE` when none of them does.
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
    // `> 0` means interleaving
    // 16B actually means non-swizzling (but interleaving)
    const int num_bytes = block_size * elem_size;
    for (const int candidate_mode: {128, 64, 32, 16}) {
        if (num_bytes % candidate_mode == 0)
            return candidate_mode;
    }
    DG_HOST_UNREACHABLE("Unreachable");
}
|
||||
|
||||
template <typename ArchSpec>
|
||||
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
|
||||
const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& num_stages, const MulticastConfig& multicast_config) {
|
||||
const int& ab_elem_size = static_cast<int>(c10::elementSize(ab_dtype));
|
||||
const int& cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));
|
||||
|
||||
const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m);
|
||||
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
|
||||
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
|
||||
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
|
||||
const int& swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);
|
||||
|
||||
// Different archs have different epilogue pipelines
|
||||
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
|
||||
|
||||
// A/B shared memory
|
||||
const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size;
|
||||
const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size;
|
||||
|
||||
// SF shared memory
|
||||
const auto& [smem_sfa_per_stage, smem_sfb_per_stage] =
|
||||
ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype);
|
||||
const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k);
|
||||
|
||||
// M-barriers and tensor memory pointers
|
||||
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
|
||||
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
|
||||
|
||||
// Sum them up
|
||||
int smem_size = 0;
|
||||
smem_size += smem_cd;
|
||||
smem_size += num_stages * smem_a_per_stage;
|
||||
smem_size += num_stages * smem_b_per_stage;
|
||||
smem_size += num_stages * smem_sfa_per_stage;
|
||||
smem_size += num_stages * smem_sfb_per_stage;
|
||||
smem_size += smem_extra_sfb;
|
||||
smem_size += smem_barrier;
|
||||
smem_size += smem_tmem_ptr;
|
||||
|
||||
return SharedMemoryConfig {
|
||||
.smem_size = smem_size,
|
||||
.swizzle_a_mode = swizzle_a_mode,
|
||||
.swizzle_b_mode = swizzle_b_mode,
|
||||
.swizzle_cd_mode = swizzle_cd_mode,
|
||||
};
|
||||
}
|
||||
|
||||
template <typename ArchSpec>
|
||||
static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type,
|
||||
const int& m, const int& n, const int& k, const int& num_groups,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const bool& with_accumulation, const int& num_sms) {
|
||||
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn);
|
||||
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
|
||||
|
||||
// Select M/N block sizes
|
||||
// TODO: support `% 16 == 8` block size on SM90
|
||||
const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
|
||||
std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
|
||||
std::vector<int> block_ns;
|
||||
for (int i = 16; i <= 256; i += 16)
|
||||
block_ns.push_back(i);
|
||||
|
||||
// K block size is selected in a fixed manner
|
||||
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
|
||||
|
||||
// Some util functions
|
||||
const auto& get_num_blocks = [=](const int& block_m, const int& block_n) {
|
||||
return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups;
|
||||
};
|
||||
const auto& get_num_waves = [=](const int& block_m, const int& block_n) {
|
||||
return ceil_div(get_num_blocks(block_m, block_n), num_sms);
|
||||
};
|
||||
const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) {
|
||||
const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms;
|
||||
return num_last_blocks == 0 ? num_sms : num_last_blocks;
|
||||
};
|
||||
|
||||
// Decide block sizes by waves
|
||||
int best_block_m = 0, best_block_n = 0;
|
||||
int best_num_waves = 0, best_last_util = 0;
|
||||
for (const auto& block_m: block_ms) {
|
||||
for (const auto& block_n: block_ns) {
|
||||
const int& num_waves = get_num_waves(block_m, block_n);
|
||||
const auto& last_util = get_last_wave_util(block_m, block_n);
|
||||
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n))
|
||||
continue;
|
||||
|
||||
bool success = false;
|
||||
if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) {
|
||||
success = true;
|
||||
} else if (num_waves == best_num_waves) {
|
||||
// Check last wave utilization
|
||||
success = last_util > best_last_util;
|
||||
if (last_util == best_last_util) {
|
||||
// Case 1: same `block_m`, smaller `block_n` (wasted)
|
||||
success |= block_m == best_block_m and block_n < best_block_n;
|
||||
// Case 2: same `block_n`, smaller `block_m` (wasted)
|
||||
success |= block_n == best_block_n and block_m < best_block_m;
|
||||
// Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
|
||||
success |= block_m != best_block_m and block_n > best_block_n;
|
||||
}
|
||||
}
|
||||
|
||||
// Replace with the new config if successful
|
||||
if (success) {
|
||||
best_block_m = block_m, best_block_n = block_n;
|
||||
best_num_waves = num_waves, best_last_util = last_util;
|
||||
}
|
||||
}
|
||||
}
|
||||
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
|
||||
|
||||
// Decide the number of TMA multicasts and whether broadcast on A
|
||||
MulticastConfig best_multicast_config = {1, true};
|
||||
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
|
||||
gemm_type, m, n, best_block_m, best_block_n, num_sms);
|
||||
const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
|
||||
bool order[2] = {false, true};
|
||||
if (best_block_m > best_block_n)
|
||||
std::swap(order[0], order[1]);
|
||||
for (const bool& is_multicast_on_a: order) {
|
||||
if (m >= 512 and is_legal[static_cast<int>(is_multicast_on_a)]) {
|
||||
best_multicast_config = {2, is_multicast_on_a};
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Always pick the largest number of stage
|
||||
constexpr int smem_capacity = ArchSpec::smem_capacity;
|
||||
int best_num_stages = 0;
|
||||
SharedMemoryConfig best_smem_config;
|
||||
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
|
||||
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
|
||||
continue;
|
||||
|
||||
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
|
||||
m, n, k,
|
||||
best_block_m, best_block_n, block_k,
|
||||
major_a, major_b,
|
||||
ab_dtype, cd_dtype,
|
||||
num_stages, best_multicast_config);
|
||||
if (best_smem_config.smem_size <= smem_capacity) {
|
||||
best_num_stages = num_stages;
|
||||
break;
|
||||
}
|
||||
}
|
||||
DG_HOST_ASSERT(best_num_stages != 0);
|
||||
|
||||
// Recompute the minimal number of SMs required
|
||||
// NOTES: less L2 cache usage and less GPU frequency drop
|
||||
int num_min_sms = num_sms;
|
||||
if (ArchSpec::should_minimize_num_sms()) {
|
||||
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
|
||||
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
|
||||
DG_HOST_ASSERT(num_min_sms <= num_sms);
|
||||
}
|
||||
|
||||
const auto& config = GemmConfig {
|
||||
.gemm_type = gemm_type,
|
||||
.kernel_type = kernel_type,
|
||||
.ab_dtype = ab_dtype,
|
||||
.cd_dtype = cd_dtype,
|
||||
.major_a = major_a,
|
||||
.major_b = major_b,
|
||||
.with_accumulation = with_accumulation,
|
||||
.block_m = best_block_m,
|
||||
.block_n = best_block_n,
|
||||
.block_k = block_k,
|
||||
.num_stages = best_num_stages,
|
||||
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
|
||||
.num_sms = num_min_sms,
|
||||
.multicast_config = best_multicast_config,
|
||||
// ReSharper disable once CppLocalVariableMightNotBeInitialized
|
||||
.smem_config = best_smem_config,
|
||||
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
|
||||
};
|
||||
|
||||
// Print configs for the first time
|
||||
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
|
||||
auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b,
|
||||
ab_dtype, cd_dtype, with_accumulation, num_sms);
|
||||
static std::set<decltype(key)> printed;
|
||||
if (not printed.contains(key)) {
|
||||
printf("Gemm type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
|
||||
"A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, "
|
||||
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
|
||||
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
|
||||
"swizzle B: %d, swizzle CD: %d, threads: %d\n",
|
||||
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
|
||||
static_cast<int>(major_a), static_cast<int>(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype),
|
||||
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
|
||||
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,
|
||||
static_cast<int>(best_multicast_config.is_multicast_on_a),
|
||||
best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode,
|
||||
best_smem_config.swizzle_cd_mode, config.thread_config.num_threads);
|
||||
printed.insert(key);
|
||||
}
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
144
csrc/jit_kernels/heuristics/sm100.hpp
Normal file
144
csrc/jit_kernels/heuristics/sm100.hpp
Normal file
@@ -0,0 +1,144 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/mma_sm100_desc.hpp>
|
||||
// Reuse some types in the JIT modules
|
||||
#include <deep_gemm/common/types.hpp>
|
||||
|
||||
#include "common.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct SM100ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
|
||||
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
|
||||
}
|
||||
|
||||
static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) {
|
||||
return block_n / (config.is_multicast_on_a ? 1 : config.num_multicast);
|
||||
}
|
||||
|
||||
// Returns `block_m` capped at 128.
static int get_cd_store_block_m(const int& block_m) {
    constexpr int layout_ad_m = 128;
    return block_m < layout_ad_m ? block_m : layout_ad_m;
}
|
||||
|
||||
// Returns `block_n` unchanged.
static int get_cd_store_block_n(const int& block_n)
{
    return block_n;
}
|
||||
|
||||
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
|
||||
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
|
||||
constexpr int num_utccp_aligned_elems = 128;
|
||||
DG_HOST_ASSERT(block_m % num_utccp_aligned_elems == 0);
|
||||
switch (ab_dtype) {
|
||||
case torch::kBFloat16: return {0, 0};
|
||||
case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
|
||||
default: DG_HOST_UNREACHABLE("Unknown dtype");
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n) {
|
||||
// Layout A/D does not support `block_m == 64` and `block_n % 16 != 0`
|
||||
if (block_m == 64 or block_n % 16 != 0)
|
||||
return false;
|
||||
|
||||
// Performance is lower with 1D1D and `block_m == 256`
|
||||
if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m != 128)
|
||||
return false;
|
||||
|
||||
// 1D2D kernels' maximum block N is 128
|
||||
// 1D2D kernels require more friendly block Ns
|
||||
if (kernel_type == KernelType::Kernel1D2D and (block_n > 128 or 128 % block_n != 0))
|
||||
return false;
|
||||
|
||||
// Check tensor memory validity
|
||||
int sf_block_m = 0, sf_block_n = 0;
|
||||
if (kernel_type == KernelType::Kernel1D1D) {
|
||||
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
|
||||
sf_block_m = sf_block_m_, sf_block_n = sf_block_n_;
|
||||
}
|
||||
if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512)
|
||||
return false;
|
||||
|
||||
// NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
|
||||
// since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA
|
||||
return major_b == cute::UMMA::Major::K or block_n % 64 == 0;
|
||||
}
|
||||
|
||||
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& num_stages,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool should_minimize_num_sms() {
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
// TODO: support other layouts
|
||||
return {
|
||||
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
|
||||
false,
|
||||
};
|
||||
}
|
||||
|
||||
static ThreadConfig get_thread_config(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n) {
|
||||
return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D1D ? 128 : block_m);
|
||||
}
|
||||
|
||||
static int get_smem_cd_size(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n,
|
||||
const int& swizzle_cd_mode,
|
||||
const at::ScalarType& cd_dtype) {
|
||||
constexpr static int layout_ad_m = 128;
|
||||
return (kernel_type == KernelType::Kernel1D1D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2;
|
||||
}
|
||||
|
||||
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
|
||||
if (ab_dtype == torch::kBFloat16)
|
||||
return {0, 0};
|
||||
|
||||
int smem_sfa_per_stage = 0;
|
||||
int smem_sfb_per_stage = 0;
|
||||
if (kernel_type == KernelType::Kernel1D1D) {
|
||||
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
|
||||
smem_sfa_per_stage = sf_block_m * 4;
|
||||
smem_sfb_per_stage = sf_block_n * 4;
|
||||
} else {
|
||||
smem_sfa_per_stage = block_m * 4;
|
||||
smem_sfb_per_stage = 0;
|
||||
}
|
||||
return {smem_sfa_per_stage, smem_sfb_per_stage};
|
||||
}
|
||||
|
||||
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static int get_barrier_smem_size(const int& num_stages) {
|
||||
// TODO: remove SF barriers for BF16 GEMMs
|
||||
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
|
||||
// NOTES: 1D2D kernel will not use the with-SF full barriers
|
||||
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
|
||||
return num_stages * 8 * 3 + 2 * 8 * 2;
|
||||
}
|
||||
|
||||
static int get_tmem_ptr_smem_size() {
|
||||
return 4;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
115
csrc/jit_kernels/heuristics/sm90.hpp
Normal file
115
csrc/jit_kernels/heuristics/sm90.hpp
Normal file
@@ -0,0 +1,115 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/mma_sm100_desc.hpp>
|
||||
// Reuse some types in the JIT modules
|
||||
#include <deep_gemm/common/types.hpp>
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
struct SM90ArchSpec {
|
||||
static constexpr int smem_capacity = 232448;
|
||||
|
||||
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
|
||||
return block_m;
|
||||
}
|
||||
|
||||
static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
static int get_cd_store_block_m(const int& block_m) {
|
||||
return block_m;
|
||||
}
|
||||
|
||||
static int get_cd_store_block_n(const int& block_n) {
|
||||
return block_n;
|
||||
}
|
||||
|
||||
static bool is_block_size_legal(const KernelType& kernel_type,
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n) {
|
||||
// FP32 output does not support `block_m == 256`
|
||||
if (cd_dtype == at::kFloat and block_m == 256)
|
||||
return false;
|
||||
|
||||
// Must be some fixed block N selections
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and (block_n != 136 or block_n != 152))
|
||||
return false;
|
||||
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 or block_n != 160))
|
||||
return false;
|
||||
|
||||
// Avoid bank conflicts for FP32 output
|
||||
if (cd_dtype == torch::kFloat and block_n % 16 == 0)
|
||||
return false;
|
||||
|
||||
// The block sizes cannot be too large (for enough registers), so at least one dim less than 128
|
||||
return block_m <= 128 or block_n <= 128;
|
||||
}
|
||||
|
||||
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& num_stages,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
// Unrolling both stages and `num_former_iters` will cause large code size
|
||||
if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
|
||||
return num_stages <= 4;
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool should_minimize_num_sms() {
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
|
||||
const int& m, const int& n, const int& block_m, const int& block_n,
|
||||
const int& num_sms) {
|
||||
return {
|
||||
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
|
||||
is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
|
||||
};
|
||||
}
|
||||
|
||||
static ThreadConfig get_thread_config(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n) {
|
||||
return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128);
|
||||
}
|
||||
|
||||
static int get_smem_cd_size(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n,
|
||||
const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) {
|
||||
return block_m * block_n * static_cast<int>(c10::elementSize(cd_dtype));
|
||||
}
|
||||
|
||||
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n, const int& block_k,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
|
||||
if (ab_dtype == torch::kBFloat16)
|
||||
return {0, 0};
|
||||
|
||||
int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
|
||||
int smem_sfb_per_stage = 0;
|
||||
// TODO: figure out here
|
||||
if (kernel_type == KernelType::Kernel1D1D)
|
||||
smem_sfb_per_stage = align(block_n * 4, block_k);
|
||||
return {smem_sfa_per_stage, smem_sfb_per_stage};
|
||||
}
|
||||
|
||||
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
|
||||
const int& block_m, const int& block_n, const int& block_k) {
|
||||
const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
|
||||
return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
|
||||
}
|
||||
|
||||
static int get_barrier_smem_size(const int& num_stages) {
|
||||
// For 1D1D kernels, there is an extra barrier for accumulation
|
||||
return (num_stages + 1) * 8 * 2;
|
||||
}
|
||||
|
||||
static int get_tmem_ptr_smem_size() {
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
173
csrc/jit_kernels/impls/runtime_utils.hpp
Normal file
173
csrc/jit_kernels/impls/runtime_utils.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Orders `(k, mn)` as `(inner, outer)`: K is the inner dimension for K-major
// layouts, otherwise MN is.
static std::pair<int, int> get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) {
    if (major == cute::UMMA::Major::K)
        return {k, mn};
    return {mn, k};
}
|
||||
|
||||
// Returns the tensor dimension index used to query the outer stride:
// `-2` for K-major layouts, `-1` otherwise.
static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
    if (major == cute::UMMA::Major::K)
        return -2;
    return -1;
}
|
||||
|
||||
// Returns `dim` when `name` appears in `compiled_dims`, otherwise 0.
static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
    const bool is_compiled = compiled_dims.find(name) != std::string::npos;
    return is_compiled ? dim : 0;
}
|
||||
|
||||
// Spells a major-ness enumerator as the C++ expression naming it.
static std::string to_string(const cute::UMMA::Major& major) {
    if (major == cute::UMMA::Major::K)
        return "cute::UMMA::Major::K";
    if (major == cute::UMMA::Major::MN)
        return "cute::UMMA::Major::MN";
    DG_HOST_UNREACHABLE("Unknown major");
}
|
||||
|
||||
static std::string to_string(const GemmType& type) {
|
||||
switch (type) {
|
||||
case GemmType::Normal: return "GemmType::Normal";
|
||||
case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous";
|
||||
case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked";
|
||||
case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous";
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unknown GEMM type");
|
||||
}
|
||||
|
||||
// Maps a tensor dtype to a C++ type name; only int, float and BF16 are supported.
static std::string to_string(const at::ScalarType& dtype) {
    if (dtype == torch::kInt)
        return "int";
    if (dtype == torch::kFloat)
        return "float";
    if (dtype == torch::kBFloat16)
        return "cutlass::bfloat16_t";
    DG_HOST_UNREACHABLE("Unsupported dtype");
}
|
||||
|
||||
// Maps a tensor dtype to the tensor map data type enumerator.
// NOTES: FP8 (e4m3) tensors are described as `UINT8`
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
    if (dtype == torch::kInt)
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    if (dtype == torch::kFloat)
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    if (dtype == torch::kBFloat16)
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    if (dtype == torch::kFloat8_e4m3fn)
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    DG_HOST_UNREACHABLE("Unsupported dtype");
}
|
||||
|
||||
// Maps a swizzle mode (0, 16, 32, 64 or 128) to the tensor map swizzle enumerator.
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
    switch (mode) {
        // Both 0 and 16 mean no swizzling
        case 0:
        case 16:
            return CU_TENSOR_MAP_SWIZZLE_NONE;
        case 32:
            return CU_TENSOR_MAP_SWIZZLE_32B;
        case 64:
            return CU_TENSOR_MAP_SWIZZLE_64B;
        case 128:
            return CU_TENSOR_MAP_SWIZZLE_128B;
        default:
            DG_HOST_UNREACHABLE("Unsupported swizzling mode");
    }
}
|
||||
|
||||
// Encodes a 2D tiled tensor map for `t`.
//  - `gmem_inner_dim` / `gmem_outer_dim`: global tensor extents in elements
//  - `smem_inner_dim` / `smem_outer_dim`: box extents in elements; with a non-zero
//    `swizzle_mode`, the inner box extent is overridden by `swizzle_mode / element size`
//  - `gmem_outer_stride`: stride of the outer dimension in elements
//  - `swizzle_mode`: 0, 16, 32, 64 or 128 (see `mode_into_tensor_map_swizzle`)
// Fails through `DG_CUDA_DRIVER_CHECK` if the driver rejects the encoding.
static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
                                    int gmem_inner_dim, int gmem_outer_dim,
                                    int smem_inner_dim, int smem_outer_dim,
                                    const int& gmem_outer_stride,
                                    const int& swizzle_mode) {
    const auto& elem_size = static_cast<int>(t.element_size());
    if (swizzle_mode != 0)
        smem_inner_dim = swizzle_mode / elem_size;

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
    const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
    // NOTES: widen before multiplying, the byte stride may not fit into `int` for large tensors
    const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride) * static_cast<cuuint64_t>(elem_size), };
    const cuuint32_t elem_strides[2] = {1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
               gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
               gmem_outer_stride, swizzle_mode, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
        2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
|
||||
|
||||
// Builds the tensor map for operand A. With more than one group, A must be
// K-major and the groups are folded into the MN extent.
static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_m, const int& shape_k,
                                   const int& block_m, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode) {
    if (num_groups > 1)
        DG_HOST_ASSERT(major == cute::UMMA::Major::K);

    // `.first` is the inner extent, `.second` the outer one
    const auto gmem_dims = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
    const auto smem_dims = get_inner_outer_dims(major, block_k, block_m);
    return make_tma_2d_desc(t,
                            gmem_dims.first, gmem_dims.second,
                            smem_dims.first, smem_dims.second,
                            outer_stride,
                            swizzle_mode);
}
|
||||
|
||||
// Builds the tensor map for operand B; the global outer extent is multiplied by `num_groups`.
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_n, const int& shape_k,
                                   const int& block_n, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode) {
    // `.first` is the inner extent, `.second` the outer one
    const auto gmem_dims = get_inner_outer_dims(major, shape_k, shape_n);
    const auto smem_dims = get_inner_outer_dims(major, block_k, block_n);

    // `num_groups` is always applied into the outer dimensions
    const int grouped_outer_dim = gmem_dims.second * num_groups;
    return make_tma_2d_desc(t,
                            gmem_dims.first, grouped_outer_dim,
                            smem_dims.first, smem_dims.second,
                            outer_stride,
                            swizzle_mode);
}
|
||||
|
||||
// Builds the tensor map for C/D: N is the inner dimension, M (times `num_groups`) the outer one.
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
                                    const int& shape_m, const int& shape_n,
                                    const int& block_m, const int& block_n,
                                    const int& outer_stride,
                                    const int& num_groups,
                                    const int& swizzle_mode) {
    // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
    // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
    const int grouped_shape_m = shape_m * num_groups;
    return make_tma_2d_desc(t,
                            shape_n, grouped_shape_m,
                            block_n, block_m,
                            outer_stride,
                            swizzle_mode);
}
|
||||
|
||||
// Builds the tensor map for a scaling-factor tensor. Only MN-major, non-swizzled
// scaling factors are supported; the MN extent is padded by `get_tma_aligned_size`
// and also serves as the outer stride.
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
                                    const torch::Tensor& t,
                                    int shape_mn, int shape_k,
                                    const int& block_mn, const int& block_k,
                                    const int& num_groups,
                                    const int& swizzle_mode) {
    DG_HOST_ASSERT(major == cute::UMMA::Major::MN);

    // TODO: maybe swizzle SF as well
    DG_HOST_ASSERT(swizzle_mode == 0);

    shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));

    // K extent covered by one SF element: `block_k` for float SFs, `4 * block_k` otherwise
    // NOTE(review): presumably non-float SFs pack 4 factors per element -- confirm
    const int k_per_sf_elem = block_k * (t.scalar_type() == torch::kFloat ? 1 : 4);
    const int num_sf_rows = ceil_div(shape_k, k_per_sf_elem) * num_groups;
    return make_tma_2d_desc(t,
                            shape_mn, num_sf_rows,
                            block_mn, 1,
                            shape_mn,
                            swizzle_mode);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
351
csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Normal file
351
csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp
Normal file
@@ -0,0 +1,351 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime of the SM100 FP8 1D1D GEMM kernel: generates the source that
// instantiates the kernel template and launches the compiled kernel.
class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntime> {
public:
    // Everything needed to generate and launch one kernel instance
    struct Args {
        int m, n, k, num_groups;
        // NOTES: held by reference, the referenced string must outlive this struct
        const std::string& compiled_dims;

        GemmConfig gemm_config;
        LaunchArgs launch_args;

        void* grouped_layout;
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_sfa;
        CUtensorMap tensor_map_sfb;
        CUtensorMap tensor_map_c;
        CUtensorMap tensor_map_d;
    };

    // Renders the source instantiating `sm100_fp8_gemm_1d1d_impl` with the
    // template arguments taken from `args`.
    static std::string generate_impl(const Args& args) {
        const auto& cfg = args.gemm_config;
        const auto& smem = cfg.smem_config;
        const auto& threads = cfg.thread_config;
        const auto& multicast = cfg.multicast_config;

        // Dimensions not listed in `compiled_dims` are emitted as 0
        const auto compiled_m = get_compiled_dim(args.m, 'm', args.compiled_dims);
        const auto compiled_n = get_compiled_dim(args.n, 'n', args.compiled_dims);
        const auto compiled_k = get_compiled_dim(args.k, 'k', args.compiled_dims);

        // NOTES: the kernel source text below is kept verbatim
        return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif

#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{}, {},
{}, {},
{}, {},
{},
{}, {}
>);
}};
)",
        to_string(cfg.major_a), to_string(cfg.major_b),
        compiled_m, compiled_n, compiled_k,
        cfg.block_m, cfg.block_n, cfg.block_k,
        args.num_groups,
        smem.swizzle_a_mode, smem.swizzle_b_mode, smem.swizzle_cd_mode,
        cfg.num_stages, cfg.num_last_stages,
        threads.num_non_epilogue_threads, threads.num_epilogue_threads,
        multicast.num_multicast, multicast.is_multicast_on_a,
        to_string(cfg.gemm_type),
        cfg.with_accumulation,
        to_string(cfg.cd_dtype));
    }

    // Launches `kernel` with the runtime arguments in kernel parameter order
    static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
            args.grouped_layout, args.m, args.n, args.k,
            args.tensor_map_a, args.tensor_map_b,
            args.tensor_map_sfa, args.tensor_map_sfb,
            args.tensor_map_c, args.tensor_map_d));
    }
};
|
||||
|
||||
// Runs a normal (non-grouped) FP8 GEMM with the SM100 1D1D kernel, optionally
// accumulating onto `c`. `d` receives the result.
static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                const torch::Tensor& b, const torch::Tensor& sfb,
                                const std::optional<torch::Tensor>& c,
                                const torch::Tensor& d,
                                const int& m, const int& n, const int& k,
                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                const std::string& compiled_dims) {
    // The kernel is launched with K rounded up to a multiple of 128,
    // while the heuristics and the tensor maps use the original K
    const auto padded_k = align(k, 128);
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D1D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());
    const auto& multicast = config.multicast_config;
    const auto& smem = config.smem_config;

    // Block extents used by the loads and the stores
    const int load_block_m = SM100ArchSpec::get_ab_load_block_m(multicast, config.block_m);
    const int load_block_n = SM100ArchSpec::get_ab_load_block_n(multicast, config.block_n);
    const int store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m);
    const int store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n);

    // Create tensor descriptors; without an accumulator the C descriptor describes D
    const auto& cd = c.value_or(d);
    const auto& tensor_map_a = make_tma_a_desc(
        major_a, a, m, k, load_block_m, config.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        smem.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(
        major_b, b, n, k, load_block_n, config.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
        smem.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(
        d, m, n, store_block_m, store_block_n,
        static_cast<int>(d.stride(-2)), 1,
        smem.swizzle_cd_mode);
    const auto& tensor_map_c = make_tma_cd_desc(
        cd, m, n, store_block_m, store_block_n,
        static_cast<int>(cd.stride(-2)), 1,
        smem.swizzle_cd_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfa, m, k, config.block_m, config.block_k, 1, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfb, n, k, config.block_n, config.block_k, 1, 0);

    // Duplicate the accumulator if necessary: a distinct `c` is copied into `d`,
    // an aliasing `c` must match the layout of `d`
    if (c.has_value()) {
        const bool is_in_place = c->data_ptr() == d.data_ptr();
        if (not is_in_place) {
            // ReSharper disable once CppExpressionWithoutSideEffects
            d.copy_(c.value());
        } else {
            DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
        }
    }

    // Generate, build and launch
    const SM100FP8Gemm1D1DRuntime::Args args = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  smem.smem_size,
                                  multicast.num_multicast),
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_c = tensor_map_c,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
|
||||
|
||||
// Runs an M-grouped contiguous FP8 GEMM with the SM100 1D1D kernel.
// `m_indices` is passed to the kernel as the grouped layout; B and SFB carry
// `num_groups` groups, while A, SFA and D are described as a single group.
static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                     const torch::Tensor& b, const torch::Tensor& sfb,
                                                     const torch::Tensor& d,
                                                     const torch::Tensor& m_indices,
                                                     const int& num_groups, const int& m, const int& n, const int& k,
                                                     const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                     const std::string& compiled_dims) {
    // The kernel is launched with K rounded up to a multiple of 128
    const auto padded_k = align(k, 128);
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
        m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());
    const auto& multicast = config.multicast_config;
    const auto& smem = config.smem_config;

    // Block extents used by the loads and the stores
    const int load_block_m = SM100ArchSpec::get_ab_load_block_m(multicast, config.block_m);
    const int load_block_n = SM100ArchSpec::get_ab_load_block_n(multicast, config.block_n);
    const int store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m);
    const int store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n);

    // Create tensor descriptors
    const auto& tensor_map_a = make_tma_a_desc(
        major_a, a, m, k, load_block_m, config.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        smem.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(
        major_b, b, n, k, load_block_n, config.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        smem.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(
        d, m, n, store_block_m, store_block_n,
        static_cast<int>(d.stride(-2)), 1,
        smem.swizzle_cd_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfa, m, k, config.block_m, config.block_k, 1, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfb, n, k, config.block_n, config.block_k, num_groups, 0);

    // Generate, build and launch
    // NOTES: there is no accumulator, so the C descriptor reuses the one of D
    const SM100FP8Gemm1D1DRuntime::Args args = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  smem.smem_size,
                                  multicast.num_multicast),
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_c = tensor_map_d,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
|
||||
|
||||
// Runs a masked M-grouped FP8 GEMM with the SM100 1D1D kernel.
// `masked_m` is passed to the kernel as the grouped layout; every operand carries
// `num_groups` groups. The heuristics are driven by `expected_m` rather than `m`.
static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                 const std::string& compiled_dims) {
    // The kernel is launched with K rounded up to a multiple of 128
    const auto padded_k = align(k, 128);
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedMasked, KernelType::Kernel1D1D,
        expected_m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());
    const auto& multicast = config.multicast_config;
    const auto& smem = config.smem_config;

    // Block extents used by the loads and the stores
    const int load_block_m = SM100ArchSpec::get_ab_load_block_m(multicast, config.block_m);
    const int load_block_n = SM100ArchSpec::get_ab_load_block_n(multicast, config.block_n);
    const int store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m);
    const int store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n);

    // Create tensor descriptors
    const auto& tensor_map_a = make_tma_a_desc(
        major_a, a, m, k, load_block_m, config.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
        smem.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(
        major_b, b, n, k, load_block_n, config.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        smem.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(
        d, m, n, store_block_m, store_block_n,
        static_cast<int>(d.stride(-2)), num_groups,
        smem.swizzle_cd_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfa, m, k, config.block_m, config.block_k, num_groups, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(
        cute::UMMA::Major::MN, sfb, n, k, config.block_n, config.block_k, num_groups, 0);

    // Generate, build and launch
    // NOTES: there is no accumulator, so the C descriptor reuses the one of D
    const SM100FP8Gemm1D1DRuntime::Args args = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  smem.smem_size,
                                  multicast.num_multicast),
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_c = tensor_map_d,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
|
||||
|
||||
// Runs a K-grouped contiguous FP8 GEMM with the SM100 1D1D kernel.
//  - `ks`: per-group K sizes; must be non-empty and every entry a multiple of 128
//  - `ks_tensor`: passed to the kernel as the grouped layout
//    (presumably the device-side copy of `ks` -- TODO confirm)
//  - `c`: optional accumulator; when given, it must alias `d` with identical sizes and strides
// Both A and B must be MN-major.
static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                    const torch::Tensor& b, const torch::Tensor& sfb,
                                    const std::optional<torch::Tensor>& c,
                                    const torch::Tensor& d,
                                    const int& m, const int& n,
                                    const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                    const std::string& compiled_dims) {
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);

    // NOTES: an empty group list would make the `max_element` dereference below undefined
    DG_HOST_ASSERT(not ks.empty());

    // Total K, and total SF extent along K (each group is rounded up to units of 512)
    int sum_k = 0, sum_sf_k = 0;
    for (const auto& k: ks) {
        DG_HOST_ASSERT(k % 128 == 0);
        sum_k += k, sum_sf_k += ceil_div(k, 512);
    }
    const auto& num_groups = static_cast<int>(ks.size());

    // Get config using max K for better performance
    const auto& max_k = *std::ranges::max_element(ks);
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
        m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Create tensor descriptors; without an accumulator the C descriptor describes D
    const auto& cd = c.value_or(d);
    const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
                                               SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                               config.block_k,
                                               static_cast<int>(a.stride(0)), 1,
                                               config.smem_config.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
                                               SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                               config.block_k,
                                               static_cast<int>(b.stride(0)), 1,
                                               config.smem_config.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
                                                SM100ArchSpec::get_cd_store_block_m(config.block_m),
                                                SM100ArchSpec::get_cd_store_block_n(config.block_n),
                                                static_cast<int>(d.stride(1)), num_groups,
                                                config.smem_config.swizzle_cd_mode);
    const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
                                                SM100ArchSpec::get_cd_store_block_m(config.block_m),
                                                SM100ArchSpec::get_cd_store_block_n(config.block_n),
                                                static_cast<int>(cd.stride(1)), num_groups,
                                                config.smem_config.swizzle_cd_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512,
                                                  config.block_m, config.block_k, num_groups, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512,
                                                  config.block_n, config.block_k, num_groups, 0);

    // The accumulator is never copied here: it must be `d` itself
    if (c.has_value()) {
        DG_HOST_ASSERT(c->data_ptr() == d.data_ptr());
        DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
    }

    // Launch kernel
    const SM100FP8Gemm1D1DRuntime::Args& args = {
        .m = m, .n = n, .k = sum_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_c = tensor_map_c,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
242
csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp
Normal file
242
csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp
Normal file
@@ -0,0 +1,242 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../heuristics/sm100.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the SM100 FP8 1D2D GEMM kernel: renders the source that instantiates one
// `sm100_fp8_gemm_1d2d_impl` specialization and launches the compiled kernel.
class SM100FP8Gemm1D2DRuntime final: public LaunchRuntime<SM100FP8Gemm1D2DRuntime> {
public:
    // Inputs shared by code generation and launch
    // NOTES: keep the member order, call sites fill this struct with designated initializers
    struct Args {
        // Problem shape and group count
        int m;
        int n;
        int k;
        int num_groups;
        // Forwarded to `get_compiled_dim` for each of `m`, `n` and `k`
        // NOTES: held by reference, the referenced string must outlive this struct
        const std::string& compiled_dims;

        GemmConfig gemm_config;
        LaunchArgs launch_args;

        // Raw pointers passed as the first two kernel arguments
        void* sfb;
        void* grouped_layout;
        // Tensor maps passed by value as the trailing kernel arguments
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_d;
        CUtensorMap tensor_map_sfa;
    };

    // Render the translation unit; every `{}` below is one template argument of the kernel
    static std::string generate_impl(const Args& args) {
        const auto& gemm = args.gemm_config;
        const auto& smem = gemm.smem_config;
        const auto& threads = gemm.thread_config;
        const auto& multicast = gemm.multicast_config;
        const auto& dims = args.compiled_dims;

        return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif

#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d2d_impl<
        {}, {},
        {}, {}, {},
        {}, {}, {},
        {},
        {}, {}, {},
        {}, {},
        {}, {},
        {}, {},
        {}, {}
    >);
}};
)",
        // Operand majors
        to_string(gemm.major_a), to_string(gemm.major_b),
        // Shape, as rendered by `get_compiled_dim`
        get_compiled_dim(args.m, 'm', dims), get_compiled_dim(args.n, 'n', dims), get_compiled_dim(args.k, 'k', dims),
        // Block sizes
        gemm.block_m, gemm.block_n, gemm.block_k,
        args.num_groups,
        // Swizzle modes
        smem.swizzle_a_mode, smem.swizzle_b_mode, smem.swizzle_cd_mode,
        // Stage counts
        gemm.num_stages, gemm.num_last_stages,
        // Thread counts
        threads.num_non_epilogue_threads, threads.num_epilogue_threads,
        // Multicast settings
        multicast.num_multicast, multicast.is_multicast_on_a,
        // GEMM type and output dtype
        to_string(gemm.gemm_type),
        to_string(gemm.cd_dtype));
    }

    // Launch the compiled kernel; the argument order here is the kernel's parameter order
    static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
            args.sfb, args.grouped_layout, args.m, args.n, args.k,
            args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.tensor_map_sfa));
    }
};
|
||||
|
||||
// Dense (single-group) FP8 GEMM on SM100 using the 1D2D kernel.
// `c` must be empty: this path asserts that no accumulation input is given.
static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                const torch::Tensor& b, const torch::Tensor& sfb,
                                const std::optional<torch::Tensor>& c,
                                const torch::Tensor& d,
                                const int& m, const int& n, const int& k,
                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                const std::string& compiled_dims) {
    DG_HOST_ASSERT(not c.has_value());

    // The heuristic and the TMA descriptors use the raw `k`, the kernel gets `align(k, 128)`
    const auto padded_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D2D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());
    const auto& smem = cfg.smem_config;

    // Operand A
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
                                        cfg.block_k,
                                        stride_a, 1,
                                        smem.swizzle_a_mode);

    // Operand B
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
                                        cfg.block_k,
                                        stride_b, 1,
                                        smem.swizzle_b_mode);

    // Output D
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto desc_d = make_tma_cd_desc(d, m, n,
                                         SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
                                         SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
                                         stride_d, 1,
                                         smem.swizzle_cd_mode);

    // Scaling factors of A go through TMA, those of B are handed over as a raw pointer
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);

    // Generate, build and launch
    const SM100FP8Gemm1D2DRuntime::Args launch_params = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  smem.smem_size,
                                  cfg.multicast_config.num_multicast),
        .sfb = sfb.data_ptr(),
        .grouped_layout = nullptr,
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };
    const auto source = SM100FP8Gemm1D2DRuntime::generate(launch_params);
    const auto kernel_runtime = compiler->build("sm100_fp8_gemm_1d2d", source);
    SM100FP8Gemm1D2DRuntime::launch(kernel_runtime, launch_params);
}
|
||||
|
||||
// M-grouped (contiguous layout) FP8 GEMM on SM100 using the 1D2D kernel.
// A, D and A's scaling factors are described as a single group, only B carries `num_groups`;
// `m_indices` is forwarded to the kernel as its grouped layout pointer.
static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                     const torch::Tensor& b, const torch::Tensor& sfb,
                                                     const torch::Tensor& d,
                                                     const torch::Tensor& m_indices,
                                                     const int& num_groups, const int& m, const int& n, const int& k,
                                                     const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                     const std::string& compiled_dims) {
    // The heuristic and the TMA descriptors use the raw `k`, the kernel gets `align(k, 128)`
    const auto padded_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
        m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());
    const auto& smem = cfg.smem_config;

    // Operand A (one group)
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
                                        cfg.block_k,
                                        stride_a, 1,
                                        smem.swizzle_a_mode);

    // Operand B (`num_groups` groups)
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
                                        cfg.block_k,
                                        stride_b, num_groups,
                                        smem.swizzle_b_mode);

    // Output D (one group)
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto desc_d = make_tma_cd_desc(d, m, n,
                                         SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
                                         SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
                                         stride_d, 1,
                                         smem.swizzle_cd_mode);

    // Scaling factors of A (one group)
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);

    // Generate, build and launch
    const SM100FP8Gemm1D2DRuntime::Args launch_params = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  smem.smem_size,
                                  cfg.multicast_config.num_multicast),
        .sfb = sfb.data_ptr(),
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };
    const auto source = SM100FP8Gemm1D2DRuntime::generate(launch_params);
    const auto kernel_runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", source);
    SM100FP8Gemm1D2DRuntime::launch(kernel_runtime, launch_params);
}
|
||||
|
||||
// M-grouped (masked layout) FP8 GEMM on SM100 using the 1D2D kernel.
// The config heuristic is driven by `expected_m`, while descriptors and the kernel use `m`;
// every descriptor carries `num_groups` and `masked_m` is forwarded as the grouped layout pointer.
static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                 const std::string& compiled_dims) {
    // The heuristic and the TMA descriptors use the raw `k`, the kernel gets `align(k, 128)`
    const auto padded_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedMasked, KernelType::Kernel1D2D,
        expected_m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());
    const auto& smem = cfg.smem_config;

    // Operand A
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
                                        cfg.block_k,
                                        stride_a, num_groups,
                                        smem.swizzle_a_mode);

    // Operand B
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
                                        cfg.block_k,
                                        stride_b, num_groups,
                                        smem.swizzle_b_mode);

    // Output D
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto desc_d = make_tma_cd_desc(d, m, n,
                                         SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
                                         SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
                                         stride_d, num_groups,
                                         smem.swizzle_cd_mode);

    // Scaling factors of A
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, num_groups, 0);

    // Generate, build and launch
    const SM100FP8Gemm1D2DRuntime::Args launch_params = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  smem.smem_size,
                                  cfg.multicast_config.num_multicast),
        .sfb = sfb.data_ptr(),
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };
    const auto source = SM100FP8Gemm1D2DRuntime::generate(launch_params);
    const auto kernel_runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", source);
    SM100FP8Gemm1D2DRuntime::launch(kernel_runtime, launch_params);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
255
csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
Normal file
255
csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp
Normal file
@@ -0,0 +1,255 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/compiler.hpp"
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../heuristics/sm90.hpp"
|
||||
#include "runtime_utils.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the SM90 FP8 1D2D GEMM kernel: renders the source that instantiates one
// `sm90_fp8_gemm_1d2d_impl` specialization and launches the compiled kernel.
class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime> {
public:
    // Inputs shared by code generation and launch
    // NOTES: keep the member order, call sites fill this struct with designated initializers
    struct Args {
        // Problem shape and group count
        int m;
        int n;
        int k;
        int num_groups;
        // Forwarded to `get_compiled_dim` for each of `m`, `n` and `k`
        // NOTES: held by reference, the referenced string must outlive this struct
        const std::string& compiled_dims;

        GemmConfig gemm_config;
        LaunchArgs launch_args;

        // Raw pointers passed as the first two kernel arguments
        void* sfb;
        void* grouped_layout;
        // Tensor maps passed by value as the trailing kernel arguments
        CUtensorMap tensor_map_a;
        CUtensorMap tensor_map_b;
        CUtensorMap tensor_map_d;
        CUtensorMap tensor_map_sfa;
    };

    // Render the translation unit; every `{}` below is one template argument of the kernel
    static std::string generate_impl(const Args& args) {
        const auto& gemm = args.gemm_config;
        const auto& threads = gemm.thread_config;
        const auto& multicast = gemm.multicast_config;
        const auto& dims = args.compiled_dims;

        return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif

#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
        {}, {}, {},
        {},
        {}, {}, {},
        {},
        {}, {},
        {}, {},
        {}, {},
        {}
    >);
}};
)",
        // TODO: add CD dtype
        // Shape, as rendered by `get_compiled_dim`
        get_compiled_dim(args.m, 'm', dims), get_compiled_dim(args.n, 'n', dims), get_compiled_dim(args.k, 'k', dims),
        args.num_groups,
        // Block sizes
        gemm.block_m, gemm.block_n, gemm.block_k,
        // Only the CD swizzle mode is a template argument on this kernel
        gemm.smem_config.swizzle_cd_mode,
        // Stage counts
        gemm.num_stages, gemm.num_last_stages,
        // Thread counts
        threads.num_tma_threads, threads.num_math_threads,
        // Multicast settings
        multicast.num_multicast, multicast.is_multicast_on_a,
        to_string(gemm.gemm_type));
    }

    // Launch the compiled kernel; the argument order here is the kernel's parameter order
    static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
        // TODO: optimize `args` copy
        DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
            args.sfb, args.grouped_layout, args.m, args.n, args.k,
            args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.tensor_map_sfa));
    }
};
|
||||
|
||||
// Dense (single-group) FP8 GEMM on SM90 using the 1D2D kernel.
// Restricted to BF16 output, no accumulation input `c`, and K-major A and B.
static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                               const torch::Tensor& b, const torch::Tensor& sfb,
                               const std::optional<torch::Tensor>& c,
                               const torch::Tensor& d,
                               const int& m, const int& n, const int& k,
                               const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                               const std::string& compiled_dims) {
    DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // The heuristic and the TMA descriptors use the raw `k`, the kernel gets `align(k, 128)`
    const auto padded_k = align(k, 128);
    const auto config = get_best_config<SM90ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D2D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // Operand A
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                        config.block_k,
                                        stride_a, 1,
                                        config.smem_config.swizzle_a_mode);

    // Operand B
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                        config.block_k,
                                        stride_b, 1,
                                        config.smem_config.swizzle_b_mode);

    // Output D
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto desc_d = make_tma_cd_desc(d, m, n,
                                         SM90ArchSpec::get_cd_store_block_m(config.block_m),
                                         SM90ArchSpec::get_cd_store_block_n(config.block_n),
                                         stride_d, 1,
                                         config.smem_config.swizzle_cd_mode);

    // Scaling factors of A go through TMA, those of B are handed over as a raw pointer
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           config.block_m, config.block_k, 1, 0);

    // Generate, build and launch
    const SM90FP8Gemm1D2DRuntime::Args launch_params = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .sfb = sfb.data_ptr(),
        .grouped_layout = nullptr,
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };
    const auto source = SM90FP8Gemm1D2DRuntime::generate(launch_params);
    const auto kernel_runtime = compiler->build("sm90_fp8_gemm_1d2d", source);
    SM90FP8Gemm1D2DRuntime::launch(kernel_runtime, launch_params);
}
|
||||
|
||||
// M-grouped (contiguous layout) FP8 GEMM on SM90 using the 1D2D kernel.
// Restricted to BF16 output and K-major A and B. A, D and A's scaling factors are described as a
// single group, only B carries `num_groups`; `m_indices` is forwarded as the grouped layout pointer.
static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                    const torch::Tensor& b, const torch::Tensor& sfb,
                                                    const torch::Tensor& d,
                                                    const torch::Tensor& m_indices,
                                                    const int& num_groups, const int& m, const int& n, const int& k,
                                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                    const std::string& compiled_dims) {
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // The heuristic and the TMA descriptors use the raw `k`, the kernel gets `align(k, 128)`
    const auto padded_k = align(k, 128);
    const auto config = get_best_config<SM90ArchSpec>(
        GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
        m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // Operand A (one group)
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                        config.block_k,
                                        stride_a, 1,
                                        config.smem_config.swizzle_a_mode);

    // Operand B (`num_groups` groups)
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                        config.block_k,
                                        stride_b, num_groups,
                                        config.smem_config.swizzle_b_mode);

    // Output D (one group)
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto desc_d = make_tma_cd_desc(d, m, n,
                                         SM90ArchSpec::get_cd_store_block_m(config.block_m),
                                         SM90ArchSpec::get_cd_store_block_n(config.block_n),
                                         stride_d, 1,
                                         config.smem_config.swizzle_cd_mode);

    // Scaling factors of A (one group)
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           config.block_m, config.block_k, 1, 0);

    // Generate, build and launch
    const SM90FP8Gemm1D2DRuntime::Args launch_params = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .sfb = sfb.data_ptr(),
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };
    const auto source = SM90FP8Gemm1D2DRuntime::generate(launch_params);
    const auto kernel_runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", source);
    SM90FP8Gemm1D2DRuntime::launch(kernel_runtime, launch_params);
}
|
||||
|
||||
// M-grouped (masked layout) FP8 GEMM on SM90 using the 1D2D kernel.
// Restricted to BF16 output and K-major A and B. The config heuristic is driven by `expected_m`,
// while descriptors and the kernel use `m`; every descriptor carries `num_groups` and `masked_m`
// is forwarded as the grouped layout pointer.
static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                const torch::Tensor& b, const torch::Tensor& sfb,
                                                const torch::Tensor& d,
                                                const torch::Tensor& masked_m,
                                                const int& num_groups, const int& m, const int& n, const int& k,
                                                const int& expected_m,
                                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                const std::string& compiled_dims) {
    // The heuristic and the TMA descriptors use the raw `k`, the kernel gets `align(k, 128)`
    const auto padded_k = align(k, 128);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto config = get_best_config<SM90ArchSpec>(
        GemmType::MGroupedMasked, KernelType::Kernel1D2D,
        expected_m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
    DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);

    // Operand A
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                        config.block_k,
                                        stride_a, num_groups,
                                        config.smem_config.swizzle_a_mode);

    // Operand B
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                        config.block_k,
                                        stride_b, num_groups,
                                        config.smem_config.swizzle_b_mode);

    // Output D
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto desc_d = make_tma_cd_desc(d, m, n,
                                         SM90ArchSpec::get_cd_store_block_m(config.block_m),
                                         SM90ArchSpec::get_cd_store_block_n(config.block_n),
                                         stride_d, num_groups,
                                         config.smem_config.swizzle_cd_mode);

    // Scaling factors of A
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           config.block_m, config.block_k, num_groups, 0);

    // Generate, build and launch
    const SM90FP8Gemm1D2DRuntime::Args launch_params = {
        .m = m, .n = n, .k = padded_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .sfb = sfb.data_ptr(),
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };
    const auto source = SM90FP8Gemm1D2DRuntime::generate(launch_params);
    const auto kernel_runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", source);
    SM90FP8Gemm1D2DRuntime::launch(kernel_runtime, launch_params);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
199
csrc/jit_kernels/impls/smxx_layout.hpp
Normal file
199
csrc/jit_kernels/impls/smxx_layout.hpp
Normal file
@@ -0,0 +1,199 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "../../jit/kernel_runtime.hpp"
|
||||
#include "../../utils/exception.hpp"
|
||||
#include "../../utils/format.hpp"
|
||||
#include "../../utils/math.hpp"
|
||||
#include "../../utils/layout.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// JIT runtime for the `transpose_and_pack_fp32_into_ue8m0` layout kernel.
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
public:
    // NOTES: keep the member order, call sites fill this struct with designated initializers
    struct Args {
        // `mn` is a runtime kernel argument, `sf_k` and `block_mn` are compile-time template arguments
        int mn;
        int sf_k;
        int block_mn;
        // Input scaling factors and output buffer
        void* sf;
        void* out;

        LaunchArgs launch_args;
    };

    // Render the translation unit; the template arguments are (thread count, `block_mn`, `sf_k`)
    static std::string generate_impl(const Args& args) {
        const auto& thread_count = args.launch_args.num_threads;
        return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif

#include <deep_gemm/impls/smxx_layout.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&transpose_and_pack_fp32_into_ue8m0<
        {}, {}, {}
    >);
}};
)", thread_count, args.block_mn, args.sf_k);
    }

    // Launch with (sf, out, mn); `mn` is passed to the kernel as `uint32_t`
    static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
        DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast<uint32_t>(args.mn)));
    }
};
|
||||
|
||||
// JIT runtime for the `pack_fp32_into_ue8m0` layout kernel.
class PackFP32IntoUE8M0Runtime final: public LaunchRuntime<PackFP32IntoUE8M0Runtime> {
public:
    // NOTES: keep the member order, call sites fill this struct with designated initializers
    struct Args {
        // `num_groups` is a template argument; `mn`, `sf_k` and `packed_sf_k` are runtime kernel arguments
        int num_groups;
        int mn;
        int sf_k;
        int packed_sf_k;
        // Tile sizes, both compile-time template arguments
        int block_mn;
        int block_packed_sf_k;
        // Input scaling factors, output buffer and the `ks` pointer (callers in this file pass
        // `nullptr` for the non-grouped case)
        void* sf;
        void* out;
        void* ks;

        LaunchArgs launch_args;
    };

    // Render the translation unit; the template arguments are
    // (`num_groups`, thread count, `block_mn`, `block_packed_sf_k`)
    static std::string generate_impl(const Args& args) {
        const auto& thread_count = args.launch_args.num_threads;
        return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif

#include <deep_gemm/impls/smxx_layout.cuh>

using namespace deep_gemm;

static void __instantiate_kernel() {{
    auto ptr = reinterpret_cast<void*>(&pack_fp32_into_ue8m0<
        {}, {}, {}, {}
    >);
}};
)", args.num_groups, thread_count, args.block_mn, args.block_packed_sf_k);
    }

    // Launch with (sf, out, ks, mn, sf_k, packed_sf_k)
    static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
        DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
            args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k));
    }
};
|
||||
|
||||
// Validate a scaling-factor tensor (2D or 3D, FP32) and view it as 3D.
// Returns (original dim, num_groups, mn, sf_k, TMA-aligned mn, 3D view); a 2D input gets a
// leading group dimension of size 1.
static std::tuple<int, int, int, int, int, torch::Tensor> preprocess_sf(const torch::Tensor& sf) {
    // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
    const auto dim = sf.dim();
    DG_HOST_ASSERT(dim == 2 or dim == 3);
    DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat);

    // Always work on a `(groups, mn, sf_k)` tensor
    torch::Tensor sf_3d = sf;
    if (dim == 2)
        sf_3d = sf.unsqueeze(0);

    const auto& [groups, rows, cols] = get_shape<3>(sf_3d);
    const auto element_bytes = static_cast<int>(sf.element_size());
    const auto aligned_rows = get_tma_aligned_size(rows, element_bytes);
    return {dim, groups, rows, cols, aligned_rows, sf_3d};
}
|
||||
|
||||
// Return the scaling factors in an MN-major layout whose MN stride is TMA-aligned.
// An input that already has that layout is returned as-is; otherwise the data is copied into
// a freshly allocated buffer with the required strides. The result has the input's rank.
static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
    const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
    const bool is_2d = (dim == 2);

    // The last kernel already gives a column-major TMA aligned layout
    // NOTES: the group stride is irrelevant for a 2D input
    const bool group_stride_ok = batched_sf.stride(0) == tma_aligned_mn * sf_k or is_2d;
    const bool mn_is_innermost = batched_sf.stride(1) == 1;
    const bool k_stride_ok = batched_sf.stride(2) == tma_aligned_mn;
    if (group_stride_ok and mn_is_innermost and k_stride_ok)
        return is_2d ? batched_sf.squeeze(0) : batched_sf;

    // Normal layout requires transposing: allocate with the padded MN extent, then copy into
    // the first `mn` rows only
    const auto padded = torch::empty_strided({num_groups, tma_aligned_mn, sf_k},
                                             {tma_aligned_mn * sf_k, 1, tma_aligned_mn},
                                             batched_sf.options());
    const auto result = padded.slice(1, 0, mn).copy_(batched_sf);
    return is_2d ? result.squeeze(0) : result;
}
|
||||
|
||||
// Convert FP32 scaling factors into an INT tensor with `ceil_div(sf_k, 4)` entries along K,
// stored MN-major with a TMA-aligned MN stride. A contiguous input goes through the
// transpose-and-pack kernel; any other input must already be MN-major (single group) and goes
// through the pack-only kernel. The result has the input's rank.
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
    const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
    const auto& packed_sf_k = ceil_div(sf_k, 4);
    const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
                                           {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
                                           at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
    DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0);

    // Launch the kernel
    if (not batched_sf.is_contiguous()) {
        // Already MN-major: only packing is needed
        DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
        DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);

        constexpr int block_mn = 128;
        constexpr int block_packed_sf_k = 16;
        constexpr int num_threads = 512;
        const PackFP32IntoUE8M0Runtime::Args pack_params = {
            .num_groups = 1,
            .mn = mn,
            .sf_k = sf_k,
            .packed_sf_k = packed_sf_k,
            .block_mn = block_mn,
            .block_packed_sf_k = block_packed_sf_k,
            .sf = batched_sf.data_ptr(),
            .out = out.data_ptr(),
            .ks = nullptr,
            .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
        };

        const auto source = PackFP32IntoUE8M0Runtime::generate(pack_params);
        const auto kernel_runtime = compiler->build("pack_fp32_into_ue8m0", source);
        PackFP32IntoUE8M0Runtime::launch(kernel_runtime, pack_params);
    } else {
        // Row-major input: transpose and pack in one kernel
        constexpr int block_mn = 48;
        constexpr int num_threads = 512;
        const TransposeAndPackFP32IntoUE8M0Runtime::Args transpose_params = {
            .mn = mn,
            .sf_k = sf_k,
            .block_mn = block_mn,
            .sf = batched_sf.data_ptr(),
            .out = out.data_ptr(),
            // The third launch argument is `block_mn * sf_k` FP32 values in bytes
            .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4)
        };

        const auto source = TransposeAndPackFP32IntoUE8M0Runtime::generate(transpose_params);
        const auto kernel_runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", source);
        TransposeAndPackFP32IntoUE8M0Runtime::launch(kernel_runtime, transpose_params);
    }

    // Drop the group dimension again for 2D inputs
    if (dim == 2)
        return out.squeeze(0);
    return out;
}
|
||||
|
||||
// K-grouped variant of the FP32 -> packed INT scaling-factor conversion.
// `sf` is a contiguous `(sf_k, mn)` tensor whose `sf_k` must equal the sum of `ceil_div(k, 128)`
// over `ks`; the output is `(sum of ceil_div(k, 512), mn)` INT. `ks_tensor` is handed to the
// kernel as a raw pointer (presumably the device copy of `ks`, confirm against callers).
static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf,
                                                                            const torch::Tensor& ks_tensor,
                                                                            const std::vector<int>& ks) {
    const auto& [sf_k, mn] = get_shape<2>(sf);
    const auto& num_groups = static_cast<int>(ks.size());

    // Expected input rows and packed output rows, accumulated group by group
    int ref_sf_k = 0, packed_sf_k = 0;
    for (const auto& group_k: ks) {
        ref_sf_k += ceil_div(group_k, 128);
        packed_sf_k += ceil_div(group_k, 512);
    }
    DG_HOST_ASSERT(sf.is_contiguous());
    DG_HOST_ASSERT(ref_sf_k == sf_k);
    DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0);

    const auto out_options = at::TensorOptions().device(sf.device()).dtype(torch::kInt);
    const auto out = torch::empty({packed_sf_k, mn}, out_options);

    constexpr int block_mn = 128;
    constexpr int block_packed_sf_k = 16;
    constexpr int num_threads = 512;
    const PackFP32IntoUE8M0Runtime::Args pack_params = {
        .num_groups = num_groups,
        .mn = mn,
        .sf_k = sf_k,
        .packed_sf_k = packed_sf_k,
        .block_mn = block_mn,
        .block_packed_sf_k = block_packed_sf_k,
        .sf = sf.data_ptr(),
        .out = out.data_ptr(),
        .ks = ks_tensor.data_ptr(),
        .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
    };

    const auto source = PackFP32IntoUE8M0Runtime::generate(pack_params);
    const auto kernel_runtime = compiler->build("pack_fp32_into_ue8m0", source);
    PackFP32IntoUE8M0Runtime::launch(kernel_runtime, pack_params);
    return out;
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
402
csrc/python_api.cpp
Normal file
402
csrc/python_api.cpp
Normal file
@@ -0,0 +1,402 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "jit/compiler.hpp"
|
||||
#include "jit/device_runtime.hpp"
|
||||
#include "utils/layout.hpp"
|
||||
|
||||
#include "jit_kernels/impls/smxx_layout.hpp"
|
||||
#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME deep_gemm_cpp
|
||||
#endif
|
||||
|
||||
namespace deep_gemm {
|
||||
// Bring a scaling-factor tensor into the layout required for the current device.
// The MN granularity is `recipe[0]` for SFA and `recipe[1]` for SFB, the K granularity is
// `recipe[2]`. Dispatch depends on the SF dtype, the granularity, the device's major
// architecture and `disable_ue8m0_cast`; unsupported combinations hit `DG_HOST_UNREACHABLE`.
torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
                                                const int& mn, const int& k,
                                                const std::optional<int>& num_groups,
                                                const std::tuple<int, int, int>& recipe,
                                                const bool& is_sfa,
                                                const bool& disable_ue8m0_cast) {
    const auto& [recipe_m, recipe_n, gran_k] = recipe;
    const int gran_mn = is_sfa ? recipe_m : recipe_n;
    const auto arch_major = device_runtime->get_arch_major();

    // Pre-transform checks
    check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);

    const bool is_fp32 = sf.scalar_type() == torch::kFloat;
    // FP32 scaling factors stay FP32 on SM90, or anywhere the UE8M0 cast is disabled
    const bool stays_fp32 = arch_major == 9 or disable_ue8m0_cast;
    const bool on_sm100 = arch_major == 10;

    if (is_fp32 and gran_mn == 1 and gran_k == 128) {
        // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major
        if (stays_fp32)
            return get_mn_major_tma_aligned_tensor(sf);

        // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
        if (on_sm100) {
            DG_HOST_ASSERT(not disable_ue8m0_cast);
            return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
        }
    }

    if (is_fp32 and gran_mn == 128 and gran_k == 128) {
        // (FP32, 128, 128) on SM90: no need to transform, check shape and contiguous
        if (stays_fp32)
            return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);

        // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
        if (on_sm100) {
            DG_HOST_ASSERT(not disable_ue8m0_cast);
            // Expand to one row per MN index: row `i` takes the factors of block `i / 128`
            auto block_of_row = torch::arange(mn, at::TensorOptions().device(sf.device()));
            block_of_row.floor_divide_(128);
            const auto per_row_sf = sf.index_select(-2, block_of_row);
            return get_mn_major_tma_aligned_packed_ue8m0_tensor(per_row_sf);
        }
    }

    // (INT, 1, 128) on SM100: no transformation here, only `check_sf_layout` is applied
    if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and on_sm100)
        return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);

    DG_HOST_UNREACHABLE("Unknown SF transformation");
}
|
||||
|
||||
// Converts the scaling factors of a K-grouped GEMM operand into the layout required by the kernel.
//   sf:        2D scaling-factor tensor
//   ks:        per-group K sizes (host copy)
//   ks_tensor: per-group K sizes as a tensor
//   recipe:    must be (1, 1, 128)
// Only FP32 factors on SM100 are implemented; every other combination throws `DGException`.
torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
                                                          const std::vector<int>& ks,
                                                          const torch::Tensor& ks_tensor,
                                                          const std::tuple<int, int, int>& recipe) {
    DG_HOST_ASSERT(sf.dim() == 2);
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
    const auto& arch_major = device_runtime->get_arch_major();

    // FP32 on SM90
    if (sf.scalar_type() == torch::kFloat and arch_major == 9)
        DG_HOST_UNREACHABLE("Unimplemented");

    // FP32 on SM100
    if (sf.scalar_type() == torch::kFloat and arch_major == 10)
        return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);

    // INT on SM100
    // NOTES: this branch previously re-tested `torch::kFloat`, which made it unreachable
    if (sf.scalar_type() == torch::kInt and arch_major == 10)
        DG_HOST_UNREACHABLE("Unimplemented");

    DG_HOST_UNREACHABLE("Unknown cases");
}
|
||||
|
||||
// FP8 GEMM with shapes `[M, K] @ [N, K].T`, writing into `d` (optionally with an FP32 `c`).
//   a, b:               (values, scaling factors) pairs
//   d:                  output, BF16 or FP32, N-major
//   c:                  optional FP32 tensor; requires `d` to be FP32 as well
//   recipe:             scaling-factor granularity; derived from the SF dtypes when absent
//   compiled_dims:      forwarded to the kernel implementation
//   disable_ue8m0_cast: forwarded to the scaling-factor layout transform
// Throws `DGException` on any failed check or when no kernel matches.
void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 std::optional<std::tuple<int, int, int>> recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // C/D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto& [m , k ] = get_shape<2>(a.first);
    const auto& [n , k_] = get_shape<2>(b.first);
    const auto& [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check C as well
    if (c.has_value()) {
        check_major_type_cd(c.value());
        DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
    }

    // An empty problem needs no launch
    if (m == 0)
        return;

    // Resolve the recipe, then bring SFA and SFB into the compute-required layout
    const auto resolved_recipe = recipe.has_value()
        ? recipe.value()
        : get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    const auto sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, resolved_recipe, true, disable_ue8m0_cast);
    const auto sfb = transform_sf_into_required_layout(b.second, n, k, std::nullopt, resolved_recipe, false, disable_ue8m0_cast);

    // Pick the implementation from the architecture and the transformed SFA dtype
    const auto arch_major = device_runtime->get_arch_major();
    const auto sfa_dtype = sfa.scalar_type();
    if (arch_major == 9 and sfa_dtype == torch::kFloat) {
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sfa_dtype == torch::kInt) {
        sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sfa_dtype == torch::kFloat) {
        sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
}
|
||||
|
||||
// NN variant: transposes both tensors of `b` (dims 0 and 1) and forwards to `fp8_gemm_nt`.
void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> b_transposed{
        b.first.transpose(0, 1), b.second.transpose(0, 1)};
    fp8_gemm_nt(a, b_transposed, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// TN variant: transposes both tensors of `a` and of `b` (dims 0 and 1) and forwards to `fp8_gemm_nt`.
void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> a_transposed{
        a.first.transpose(0, 1), a.second.transpose(0, 1)};
    const std::pair<torch::Tensor, torch::Tensor> b_transposed{
        b.first.transpose(0, 1), b.second.transpose(0, 1)};
    fp8_gemm_nt(a_transposed, b_transposed, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// TT variant: transposes both tensors of `a` (dims 0 and 1), passes `b` through, and forwards to `fp8_gemm_nt`.
void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> a_transposed{
        a.first.transpose(0, 1), a.second.transpose(0, 1)};
    fp8_gemm_nt(a_transposed, b, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// M-grouped FP8 GEMM on a contiguous layout, shapes `[M, K] @ [G, N, K].mT`.
//   a, b:               (values, scaling factors) pairs
//   d:                  BF16 output, N-major
//   m_indices:          contiguous INT tensor with one entry per row of `a`
//   recipe:             scaling-factor granularity; derived from the SF dtypes when absent
//   compiled_dims:      forwarded to the kernel implementation
//   disable_ue8m0_cast: forwarded to the scaling-factor layout transform
// Throws `DGException` on any failed check or when no kernel matches.
void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const torch::Tensor& m_indices,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [G, N, K].mT`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }
    DG_HOST_ASSERT(m_indices.is_contiguous());

    // Type and shape checks
    const auto& [m, k] = get_shape<2>(a.first);
    const auto& [num_groups, n, k_] = get_shape<3>(b.first);
    const auto& [m_, n_] = get_shape<2>(d);
    const auto& m__ = static_cast<int>(m_indices.numel());
    DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);

    // D must be N-major
    check_major_type_cd(d);

    // An empty problem needs no launch
    if (m == 0)
        return;

    // Resolve the recipe, then bring SFA and SFB into the compute-required layout
    const auto resolved_recipe = recipe.has_value()
        ? recipe.value()
        : get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    const auto sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, resolved_recipe, true, disable_ue8m0_cast);
    const auto sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, resolved_recipe, false, disable_ue8m0_cast);

    // Pick the implementation from the architecture and the transformed SFA dtype
    const auto arch_major = device_runtime->get_arch_major();
    const auto sfa_dtype = sfa.scalar_type();
    if (arch_major == 9 and sfa_dtype == torch::kFloat) {
        sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
                                                num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sfa_dtype == torch::kInt) {
        sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
                                                 num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sfa_dtype == torch::kFloat) {
        sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
                                                 num_groups, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
}
|
||||
|
||||
// NN variant of the contiguous M-grouped GEMM: swaps dims 1 and 2 of both tensors of `b`
// and forwards to `m_grouped_fp8_gemm_nt_contiguous`.
void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const torch::Tensor& m_indices,
                                      const std::optional<std::tuple<int, int, int>>& recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> b_transposed{
        b.first.transpose(1, 2), b.second.transpose(1, 2)};
    m_grouped_fp8_gemm_nt_contiguous(a, b_transposed, d, m_indices,
                                     recipe, compiled_dims, disable_ue8m0_cast);
}
|
||||
|
||||
// Masked M-grouped FP8 GEMM, shapes `[G, M, K] @ [G, N, K].mT`.
//   a, b:               (values, scaling factors) pairs, both K-major
//   d:                  BF16 output, N-major
//   masked_m:           contiguous INT tensor with one entry per group
//   expected_m:         positive hint forwarded to the kernel implementation
//   recipe:             scaling-factor granularity; derived from the SF dtypes when absent
//   compiled_dims:      forwarded to the kernel implementation
//   disable_ue8m0_cast: forwarded to the scaling-factor layout transform
// Throws `DGException` on any failed check or when no kernel matches.
void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                  const std::pair<torch::Tensor, torch::Tensor>& b,
                                  const torch::Tensor& d,
                                  const torch::Tensor& masked_m,
                                  const int& expected_m,
                                  std::optional<std::tuple<int, int, int>> recipe,
                                  const std::string& compiled_dims,
                                  const bool& disable_ue8m0_cast) {
    // Shape must be `[G, M, K] @ [G, N, K].mT`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());

    // Type and shape checks
    const auto& [num_groups, m, k] = get_shape<3>(a.first);
    const auto& [num_groups_, n, k_] = get_shape<3>(b.first);
    const auto& [num_groups__, m_, n_] = get_shape<3>(d);
    const auto& num_groups___ = static_cast<int>(masked_m.numel());
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

    // D must be N-major
    check_major_type_cd(d);

    // Resolve the recipe, then transform the scaling factors
    const auto resolved_recipe = recipe.has_value()
        ? recipe.value()
        : get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
    const auto sfa = transform_sf_into_required_layout(a.second, m, k, num_groups, resolved_recipe, true, disable_ue8m0_cast);
    const auto sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, resolved_recipe, false, disable_ue8m0_cast);

    // Pick the implementation from the architecture and the transformed SFA dtype
    const auto arch_major = device_runtime->get_arch_major();
    const auto sfa_dtype = sfa.scalar_type();
    if (arch_major == 9 and sfa_dtype == torch::kFloat) {
        sm90_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
                                            num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sfa_dtype == torch::kInt) {
        sm100_fp8_m_grouped_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
                                             num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10 and sfa_dtype == torch::kFloat) {
        sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
                                             num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported kernel or scaling factor types");
}
|
||||
|
||||
// K-grouped FP8 GEMM on contiguous inputs; only the 1D1D recipe (1, 1, 128) is accepted.
//   a, b:          (values, scaling factors) pairs; values must be contiguous
//   d:             contiguous output
//   ks, ks_tensor: per-group K sizes, as a host vector and as a tensor
//   c:             optional contiguous FP32 tensor
//   recipe:        must equal (1, 1, 128)
//   compiled_dims: forwarded to the kernel implementation
// Throws `DGException` on any failed check or on an architecture other than SM100.
void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::vector<int>& ks,
                                      const torch::Tensor& ks_tensor,
                                      const std::optional<torch::Tensor>& c,
                                      const std::tuple<int, int, int>& recipe,
                                      const std::string& compiled_dims) {
    // Must be 1D1D kernel
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

    // Contiguity checks
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    if (c.has_value()) {
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().is_contiguous());
    }

    // Nothing to compute when all groups are empty
    const int total_k = std::accumulate(ks.begin(), ks.end(), 0);
    if (total_k == 0)
        return;

    // The second dimension of each value tensor is the M / N extent
    const int m = std::get<1>(get_shape<2>(a.first));
    const int n = std::get<1>(get_shape<2>(b.first));

    // Transform SF with padding
    const auto sfa = transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto sfb = transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Only SM100 has an implementation
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major != 10)
        DG_HOST_UNREACHABLE("Unsupported architecture");

    fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
                            cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
// ReSharper disable once CppParameterMayBeConstPtrOrRef
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    using namespace deep_gemm;

    m.doc() = "DeepGEMM C++ library";

    // Runtime: query / override the SM count held by the device runtime
    m.def("get_num_sms", []() { return device_runtime->get_num_sms(); });
    m.def("set_num_sms", [](const int& new_num_sms) { device_runtime->set_num_sms(new_num_sms); });

    // JIT: install the NVCC compiler and tell the kernel runtime where CUDA lives
    m.def("init", [](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) {
        DG_HOST_ASSERT(get_env("DG_JIT_USE_NVRTC", 0) == 0 and "Currently only support NVCC");
        compiler = std::make_shared<NVCCCompiler>(library_root_path, cuda_home_path_by_torch);
        KernelRuntime::set_cuda_home(cuda_home_path_by_torch);
    });

    // Stable kernel APIs with automatic arch/layout dispatch
    // Dense GEMMs: NT / NN default to compiled dims "nk", TN / TT to "mn"
    m.def("fp8_gemm_nt", &fp8_gemm_nt,
          py::arg("a"), py::arg("b"), py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_gemm_nn", &fp8_gemm_nn,
          py::arg("a"), py::arg("b"), py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_gemm_tn", &fp8_gemm_tn,
          py::arg("a"), py::arg("b"), py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "mn",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_gemm_tt", &fp8_gemm_tt,
          py::arg("a"), py::arg("b"), py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "mn",
          py::arg("disable_ue8m0_cast") = false);

    // Grouped GEMMs
    m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
          py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
          py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("fp8_m_grouped_gemm_nt_masked", &fp8_m_grouped_gemm_nt_masked,
          py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
          py::arg("expected_m"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false);
    m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
          py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
          py::arg("ks_tensor"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::make_tuple(1, 1, 128),
          py::arg("compiled_dims") = "mn");
    m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout);

    // Raw kernels or functions
    m.def("get_tma_aligned_size", &get_tma_aligned_size);
    m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
    m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
    m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
    m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
}
|
||||
58
csrc/utils/exception.hpp
Normal file
58
csrc/utils/exception.hpp
Normal file
@@ -0,0 +1,58 @@
|
||||
#pragma once
|
||||
|
||||
#include <exception>
|
||||
#include <string>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Exception type thrown by all host-side checks in this library.
// The message has the form `Failed: <name> error <file>:<line> '<error>'`.
class DGException final : public std::exception {
    std::string message = {};

public:
    // name:  category of the failure (e.g. "Assertion")
    // file:  source file that raised the failure
    // line:  source line that raised the failure
    // error: free-form detail text
    explicit DGException(const char *name, const char* file, const int line, const std::string& error) {
        message.append("Failed: ").append(name);
        message.append(" error ").append(file);
        message.append(":").append(std::to_string(line));
        message.append(" '").append(error).append("'");
    }

    // Returns the formatted message; the pointer stays valid as long as this object lives
    const char *what() const noexcept override {
        return message.c_str();
    }
};
|
||||
|
||||
// Compile-time assertion, a direct alias of `static_assert`
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif

// Host-side runtime assertion: throws `DGException` carrying the stringified condition
#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond)                                            \
do {                                                                    \
    if (!(cond))                                                        \
        throw DGException("Assertion", __FILE__, __LINE__, #cond);      \
} while (0)
#endif

// Marks a code path that must not be taken: unconditionally throws `DGException` with `reason`.
// NOTES: kept as an expression so it can be used as the single statement of an unbraced `if`
#ifndef DG_HOST_UNREACHABLE
#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason))
#endif
|
||||
|
||||
// Evaluates a CUDA driver API call once and throws `DGException` unless it returned `CUDA_SUCCESS`.
// The numeric result code is placed in the exception message, mirroring `DG_CUDA_RUNTIME_CHECK`
// (previously the detail string was empty, so the failing code was lost).
#ifndef DG_CUDA_DRIVER_CHECK
#define DG_CUDA_DRIVER_CHECK(cmd) \
do { \
    const auto& e = (cmd); \
    if (e != CUDA_SUCCESS) { \
        throw DGException("CUDA driver", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
    } \
} while (0)
#endif
|
||||
|
||||
// Evaluates a CUDA runtime API call once and throws `DGException` unless it returned `cudaSuccess`.
// The numeric error code is placed in the exception message.
#ifndef DG_CUDA_RUNTIME_CHECK
#define DG_CUDA_RUNTIME_CHECK(cmd)                                                                              \
do {                                                                                                            \
    const auto& dg_runtime_status_ = (cmd);                                                                     \
    if (dg_runtime_status_ != cudaSuccess)                                                                      \
        throw DGException("CUDA runtime", __FILE__, __LINE__,                                                   \
                          std::to_string(static_cast<int>(dg_runtime_status_)));                                \
} while (0)
#endif
|
||||
|
||||
} // namespace deep_gemm
|
||||
6
csrc/utils/format.hpp
Normal file
6
csrc/utils/format.hpp
Normal file
@@ -0,0 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
// Just a wrapper for the `fmt` headers
|
||||
#define FMT_HEADER_ONLY
|
||||
#include <fmt/base.h>
|
||||
#include <fmt/format.h>
|
||||
35
csrc/utils/hash.hpp
Normal file
35
csrc/utils/hash.hpp
Normal file
@@ -0,0 +1,35 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// 64-bit FNV-1a style hash over the bytes of `data`, starting from `seed` instead of a fixed basis.
// Each byte is XOR-ed into the state, which is then multiplied by the 64-bit FNV prime.
static uint64_t fnv1a(const std::string& data, const uint64_t& seed) {
    constexpr uint64_t kFnvPrime = 0x100000001b3ull;

    uint64_t state = seed;
    for (const char byte: data)
        state = (state ^ static_cast<uint8_t>(byte)) * kFnvPrime;
    return state;
}
|
||||
|
||||
static std::string get_hex_digest(const std::string& data) {
|
||||
const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull);
|
||||
const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull);
|
||||
|
||||
// Split-mix 64
|
||||
const auto& split_mix = [](uint64_t z) {
|
||||
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull;
|
||||
z = (z ^ (z >> 27)) * 0x94d049bb133111ebull;
|
||||
return z ^ (z >> 31);
|
||||
};
|
||||
|
||||
std::ostringstream oss;
|
||||
oss << std::hex << std::setfill('0')
|
||||
<< std::setw(16) << split_mix(state_0)
|
||||
<< std::setw(16) << split_mix(state_1);
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
100
csrc/utils/layout.hpp
Normal file
100
csrc/utils/layout.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
#pragma once
|
||||
|
||||
#include <cute/arch/mma_sm100_umma.hpp>
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "math.hpp"
|
||||
#include "exception.hpp"
|
||||
#include "../jit/device_runtime.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Major-ness stuffs
|
||||
// Validates that `t` is a 2D or 3D tensor whose last two dimensions contain a unit stride.
// For 3D tensors, the leading stride must equal the product of the last two sizes.
// Throws `DGException` otherwise.
static void major_check(const torch::Tensor& t) {
    const auto dim = t.dim();
    DG_HOST_ASSERT(dim == 2 or dim == 3);

    // Leading dimension of a 3D tensor must step over exactly one full matrix
    if (dim == 3) {
        DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1));
    }

    // One of the two innermost dimensions must be unit-strided
    DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1);
}
|
||||
|
||||
// Classifies an A/B operand after running `major_check` on it:
// K-major when the last dimension has unit stride, MN-major otherwise.
static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) {
    major_check(t);
    if (t.stride(-1) == 1)
        return cute::UMMA::Major::K;
    return cute::UMMA::Major::MN;
}
|
||||
|
||||
// Validates a C/D tensor: it must pass `major_check` and have a unit stride in its last dimension.
// Throws `DGException` otherwise.
static void check_major_type_cd(const torch::Tensor& t) {
    major_check(t);

    // NOTES: the library only supports row-major output layouts
    DG_HOST_ASSERT(t.stride(-1) == 1);
}
|
||||
|
||||
static bool fp8_requires_k_major() {
|
||||
return device_runtime->get_arch_major() == 9;
|
||||
}
|
||||
|
||||
// Tensor utils
|
||||
template <int N>
|
||||
static auto get_shape(const torch::Tensor& t) {
|
||||
return [&t] <size_t... Is> (std::index_sequence<Is...>) {
|
||||
return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
|
||||
}(std::make_index_sequence<N>());
|
||||
}
|
||||
|
||||
// Recipe
|
||||
static std::tuple<int, int, int>
|
||||
get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) {
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
if (arch_major == 9) {
|
||||
DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
|
||||
return {1, 128, 128};
|
||||
} else if (arch_major == 10) {
|
||||
DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt);
|
||||
return sfb_dtype == torch::kFloat ?
|
||||
std::make_tuple(1, 128, 128): // Legacy format or 1D2D kernels
|
||||
std::make_tuple(1, 1, 128); // 1D1D kernels
|
||||
}
|
||||
DG_HOST_UNREACHABLE("Unknown recipe");
|
||||
}
|
||||
|
||||
// SF layouts
|
||||
// Validates a scaling-factor tensor and returns it unchanged.
//   sf:               the tensor to validate (FP32 or INT)
//   mn, k:            extents of the value matrix the factors belong to
//   gran_mn, gran_k:  granularity along MN and K
//   num_groups:       expected leading group count; when set, `sf` must have one extra dimension
//   tma_stride_check: additionally require an MN-major, TMA-aligned stride pattern
//   contiguous_check: additionally require `sf` to be contiguous
//   type_check:       additionally require this exact dtype
// Throws `DGException` on the first failed check.
static torch::Tensor check_sf_layout(const torch::Tensor& sf,
                                     const int& mn, const int& k,
                                     const int& gran_mn, const int& gran_k,
                                     const std::optional<int>& num_groups,
                                     const bool& tma_stride_check = false,
                                     const bool& contiguous_check = false,
                                     const std::optional<torch::ScalarType>& type_check = std::nullopt) {
    // Optional exact dtype requirement
    if (type_check.has_value()) {
        DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
    }

    // Shape checks always run
    const auto sf_dtype = sf.scalar_type();
    DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
    DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
    if (num_groups.has_value()) {
        DG_HOST_ASSERT(sf.size(-3) == num_groups.value());
    }
    DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn));
    // Non-FP32 (INT) factors cover 4x the K granularity per element
    DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)));

    // TMA stride checks: TMA aligned and MN-major
    if (tma_stride_check) {
        if (num_groups.has_value()) {
            DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1));
        }
        DG_HOST_ASSERT(sf.stride(-2) == 1);
        DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()));
    }

    // Hopper SFB must be contiguous
    if (contiguous_check) {
        DG_HOST_ASSERT(sf.is_contiguous());
    }
    return sf;
}
|
||||
|
||||
// Value matrix layout
|
||||
// Returns the alignment value reported for the contiguous layout: a fixed 128.
static int get_mk_alignment_for_contiguous_layout() {
    constexpr int kContiguousLayoutAlignment = 128;
    return kContiguousLayoutAlignment;
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
25
csrc/utils/math.hpp
Normal file
25
csrc/utils/math.hpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/python.h>
|
||||
|
||||
#include "exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// Integer division rounded up: the smallest `q` with `q * b >= a`, for non-negative `a` and positive `b`.
// Declared `constexpr` so that `constexpr` callers (such as `align`) can be evaluated at compile time.
template <typename T>
static constexpr T ceil_div(const T& a, const T& b) {
    return (a + b - 1) / b;
}
|
||||
|
||||
// Rounds `a` up to the nearest multiple of `b`, for non-negative `a` and positive `b`.
// The ceiling division is written inline so this function is genuinely `constexpr`
// (it previously called a non-`constexpr` helper and could never be constant-evaluated).
template <typename T>
static constexpr T align(const T& a, const T& b) {
    return (a + b - 1) / b * b;
}
|
||||
|
||||
static int get_tma_aligned_size(const int& x, const int& element_size) {
|
||||
constexpr int kNumTMAAlignmentBytes = 16;
|
||||
DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0);
|
||||
return align(x, kNumTMAAlignmentBytes / element_size);
|
||||
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
70
csrc/utils/system.hpp
Normal file
70
csrc/utils/system.hpp
Normal file
@@ -0,0 +1,70 @@
|
||||
#pragma once
|
||||
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
// ReSharper disable once CppNotAllPathsReturnValue
|
||||
template <typename dtype_t>
|
||||
static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) {
|
||||
const auto& c_str = std::getenv(name.c_str());
|
||||
if (c_str == nullptr)
|
||||
return default_value;
|
||||
|
||||
// Read the env and convert to the desired type
|
||||
if constexpr (std::is_same_v<dtype_t, std::string>) {
|
||||
return std::string(c_str);
|
||||
} else if constexpr (std::is_same_v<dtype_t, int>) {
|
||||
int value;
|
||||
std::sscanf(c_str, "%d", &value);
|
||||
return value;
|
||||
} else {
|
||||
DG_HOST_ASSERT(false and "Unexpected type");
|
||||
}
|
||||
}
|
||||
|
||||
// Runs `command` through `popen` with stderr redirected into stdout.
// Returns (exit code, captured output). The exit code is:
//   - the child's exit status when it terminated normally,
//   - 128 + signal number when it was killed by a signal,
//   - -1 when `pclose` failed or the status is otherwise not an exit.
// Previously `WEXITSTATUS` was applied unconditionally, so a signal-killed child could be
// reported with exit code 0.
// Throws `DGException` if the pipe cannot be opened.
static std::tuple<int, std::string> call_external_command(std::string command) {
    command = command + " 2>&1";
    auto deleter = [](FILE* f) { if (f) pclose(f); };
    std::unique_ptr<FILE, decltype(deleter)> pipe(popen(command.c_str(), "r"), deleter);
    DG_HOST_ASSERT(pipe != nullptr);

    // Drain the pipe chunk by chunk
    std::array<char, 512> buffer;
    std::string output;
    while (fgets(buffer.data(), static_cast<int>(buffer.size()), pipe.get()))
        output += buffer.data();

    // Release before closing so the deleter does not close the stream a second time
    const int status = pclose(pipe.release());
    if (status == -1)
        return {-1, output};
    if (WIFEXITED(status))
        return {WEXITSTATUS(status), output};
    if (WIFSIGNALED(status))
        return {128 + WTERMSIG(status), output};
    return {-1, output};
}
|
||||
|
||||
static std::filesystem::path make_dirs(const std::filesystem::path& path) {
|
||||
// OK if existed
|
||||
std::error_code capture;
|
||||
const bool& created = std::filesystem::create_directories(path, capture);
|
||||
DG_HOST_ASSERT(created or capture.value() == 0);
|
||||
if (created and get_env<int>("DG_JIT_DEBUG"))
|
||||
printf("Create directory: %s\n", path.c_str());
|
||||
return path;
|
||||
}
|
||||
|
||||
// Builds an identifier of the form `<pid>-<8 hex>-<8 hex>-<8 hex>`.
// The three hex fields are 32-bit draws from a Mersenne Twister that is seeded once
// from `std::random_device` XOR-ed with the steady clock.
static std::string get_uuid() {
    static std::random_device rd;
    static std::mt19937 gen([]() {
        return rd() ^ std::chrono::steady_clock::now().time_since_epoch().count();
    }());
    static std::uniform_int_distribution<uint32_t> dist;

    constexpr int kNumRandomFields = 3;

    std::stringstream ss;
    ss << getpid();
    ss << std::hex << std::setfill('0');
    for (int field = 0; field < kNumRandomFields; ++field) {
        // The width only applies to the next insertion, so the separator is never padded
        ss << "-";
        ss << std::setw(8) << dist(gen);
    }
    return ss.str();
}
|
||||
|
||||
} // namespace deep_gemm
|
||||
Reference in New Issue
Block a user