Make various updates and fixes:
- Add support for legacy CUDA versions; now compatible with CUDA 12.3 and newer - Add support for NVRTC compilation - Other fixes and code refactoring
This commit is contained in:
@@ -18,8 +18,8 @@ find_package(CUDAToolkit REQUIRED)
|
||||
find_package(pybind11 REQUIRED)
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
set(CMAKE_CUDA_STANDARD 20)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
|
||||
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
|
||||
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
|
||||
|
||||
@@ -24,7 +24,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
|
||||
- [x] MoE scheduler with TMA multicast compatibility
|
||||
- [x] Fix TMA multicast compatibility for indivisible shapes
|
||||
- [x] Skip useless computation on M
|
||||
- [ ] NVRTC as a faster compiler
|
||||
- [x] NVRTC as a faster compiler
|
||||
- [ ] Sanitizer for testing
|
||||
- [x] Weight gradient kernels for dense models
|
||||
- [x] Weight gradient kernels for MoE models
|
||||
@@ -46,8 +46,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
|
||||
- Python 3.8 or higher
|
||||
- Compilers with C++20 support
|
||||
- CUDA Toolkit:
|
||||
- Currently, CUDA 12.8 or higher is required, but support for older versions may be added in the future
|
||||
- CUDA 12.8 or higher for SM90
|
||||
- CUDA 12.3 or higher for SM90
|
||||
- **We highly recommend 12.9 or higher for the best performance**
|
||||
- CUDA 12.9 or higher for SM100
|
||||
- PyTorch 2.1 or higher
|
||||
@@ -114,6 +113,8 @@ The library provides some utility functions besides the above kernels:
|
||||
|
||||
- `deep_gemm.set_num_sms`: set the maximum SM count to use
|
||||
- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set)
|
||||
- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio
|
||||
- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio
|
||||
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout
|
||||
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
|
||||
- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
|
||||
|
||||
12
build.sh
Executable file
12
build.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
# Change current directory into project root
|
||||
original_dir=$(pwd)
|
||||
script_dir=$(realpath "$(dirname "$0")")
|
||||
cd "$script_dir"
|
||||
|
||||
# Remove old dist file, build files, and install
|
||||
rm -rf build dist
|
||||
rm -rf *.egg-info
|
||||
python setup.py bdist_wheel
|
||||
|
||||
# Open users' original directory
|
||||
cd "$original_dir"
|
||||
@@ -9,7 +9,7 @@
|
||||
namespace deep_gemm {
|
||||
|
||||
class KernelRuntimeCache {
|
||||
std::unordered_map<std::filesystem::path, std::shared_ptr<KernelRuntime>> cache;
|
||||
std::unordered_map<std::string, std::shared_ptr<KernelRuntime>> cache;
|
||||
|
||||
public:
|
||||
// TODO: consider cache capacity
|
||||
|
||||
@@ -1,14 +1,17 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <filesystem>
|
||||
#include <fstream>
|
||||
#include <nvrtc.h>
|
||||
#include <regex>
|
||||
#include <string>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
#include "../utils/format.hpp"
|
||||
#include "../utils/hash.hpp"
|
||||
#include "../utils/lazy_init.hpp"
|
||||
#include "../utils/system.hpp"
|
||||
#include "cache.hpp"
|
||||
#include "device_runtime.hpp"
|
||||
@@ -16,10 +19,13 @@
|
||||
namespace deep_gemm {
|
||||
|
||||
class Compiler {
|
||||
std::string library_version;
|
||||
std::filesystem::path library_root_path;
|
||||
public:
|
||||
static std::filesystem::path library_root_path;
|
||||
static std::filesystem::path library_include_path;
|
||||
static std::filesystem::path cuda_home;
|
||||
static std::string library_version;
|
||||
|
||||
std::string get_library_version() const {
|
||||
static std::string get_library_version() {
|
||||
std::stringstream ss;
|
||||
for (const auto& f: collect_files(library_include_path / "deep_gemm")) {
|
||||
std::ifstream in(f, std::ios::binary);
|
||||
@@ -28,16 +34,23 @@ class Compiler {
|
||||
return get_hex_digest(ss.str());
|
||||
}
|
||||
|
||||
public:
|
||||
static void prepare_init(const std::string& library_root_path,
|
||||
const std::string& cuda_home_path_by_torch) {
|
||||
Compiler::library_root_path = library_root_path;
|
||||
Compiler::library_include_path = Compiler::library_root_path / "include";
|
||||
Compiler::cuda_home = cuda_home_path_by_torch;
|
||||
Compiler::library_version = get_library_version();
|
||||
}
|
||||
|
||||
std::string signature, flags;
|
||||
std::filesystem::path library_include_path;
|
||||
std::filesystem::path cache_dir_path;
|
||||
|
||||
explicit Compiler(const std::filesystem::path& library_root_path) {
|
||||
// Static library paths
|
||||
this->library_root_path = library_root_path;
|
||||
this->library_include_path = library_root_path / "include";
|
||||
this->library_version = get_library_version();
|
||||
Compiler() {
|
||||
// Check `prepare_init`
|
||||
DG_HOST_ASSERT(not library_root_path.empty());
|
||||
DG_HOST_ASSERT(not library_include_path.empty());
|
||||
DG_HOST_ASSERT(not cuda_home.empty());
|
||||
DG_HOST_ASSERT(not library_version.empty());
|
||||
|
||||
// Cache settings
|
||||
cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
|
||||
@@ -46,10 +59,11 @@ public:
|
||||
|
||||
// The compiler flags applied to all derived compilers
|
||||
signature = "unknown-compiler";
|
||||
std::string ptxas_flags = "--ptxas-options=--register-usage-level=10";
|
||||
if (get_env<int>("DG_JIT_PTXAS_VERBOSE", 0))
|
||||
ptxas_flags += ",--verbose";
|
||||
flags = fmt::format("-std=c++20 --diag-suppress=39,161,174,177,186,940 {}", ptxas_flags);
|
||||
flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 "
|
||||
"--ptxas-options=--register-usage-level=10",
|
||||
get_env<int>("DG_JIT_CPP_STANDARD", 20));
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
|
||||
flags += " --ptxas-options=--verbose";
|
||||
}
|
||||
|
||||
virtual ~Compiler() = default;
|
||||
@@ -102,6 +116,11 @@ public:
|
||||
virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0;
|
||||
};
|
||||
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path);
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path);
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home);
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version);
|
||||
|
||||
class NVCCCompiler final: public Compiler {
|
||||
std::filesystem::path nvcc_path;
|
||||
|
||||
@@ -125,11 +144,9 @@ class NVCCCompiler final: public Compiler {
|
||||
}
|
||||
|
||||
public:
|
||||
NVCCCompiler(const std::filesystem::path& library_root_path,
|
||||
const std::filesystem::path& cuda_home_path_by_torch):
|
||||
Compiler(library_root_path) {
|
||||
NVCCCompiler() {
|
||||
// Override the compiler signature
|
||||
nvcc_path = cuda_home_path_by_torch / "bin" / "nvcc";
|
||||
nvcc_path = cuda_home / "bin" / "nvcc";
|
||||
if (const auto& env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty())
|
||||
nvcc_path = env_nvcc_path;
|
||||
const auto& [nvcc_major, nvcc_minor] = get_nvcc_version();
|
||||
@@ -150,10 +167,10 @@ public:
|
||||
// Compile
|
||||
const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
|
||||
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
|
||||
printf("Running NVCC command: %s", command.c_str());
|
||||
printf("Running NVCC command: %s\n", command.c_str());
|
||||
const auto& [return_code, output] = call_external_command(command);
|
||||
if (return_code != 0) {
|
||||
printf("NVCC compilation failed: %s", output.c_str());
|
||||
printf("NVCC compilation failed: %s\n", output.c_str());
|
||||
DG_HOST_ASSERT(false and "NVCC compilation failed");
|
||||
}
|
||||
|
||||
@@ -163,6 +180,96 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
static std::shared_ptr<Compiler> compiler = nullptr;
|
||||
class NVRTCCompiler final: public Compiler {
|
||||
public:
|
||||
NVRTCCompiler() {
|
||||
// Override the compiler signature
|
||||
int major, minor;
|
||||
DG_NVRTC_CHECK(nvrtcVersion(&major, &minor));
|
||||
signature = fmt::format("NVRTC{}.{}", major, minor);
|
||||
|
||||
// Build include directories list
|
||||
std::string include_dirs;
|
||||
include_dirs += fmt::format("-I{} ", library_include_path.string());
|
||||
include_dirs += fmt::format("-I{} ", (cuda_home / "include").string());
|
||||
|
||||
// Add PCH support for version 12.8 and above
|
||||
// NOTES: PCH is vital for compilation speed
|
||||
std::string pch_flags;
|
||||
if (major > 12 or (major == 12 and minor >= 8)) {
|
||||
pch_flags = "--pch ";
|
||||
if (get_env<int>("DG_JIT_DEBUG", 0))
|
||||
pch_flags += "--pch-verbose=true ";
|
||||
}
|
||||
|
||||
// Override the compiler flags
|
||||
flags = fmt::format("{} {}--gpu-architecture=sm_{}a -default-device {}",
|
||||
flags, include_dirs, device_runtime->get_arch(), pch_flags);
|
||||
}
|
||||
|
||||
void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
|
||||
// Write the code into the cache directory
|
||||
const auto& code_path = dir_path / "kernel.cu";
|
||||
put(code_path, code);
|
||||
|
||||
// Parse compilation options
|
||||
std::istringstream iss(flags);
|
||||
std::vector<std::string> options;
|
||||
std::string option;
|
||||
while (iss >> option)
|
||||
options.push_back(option);
|
||||
|
||||
// Convert to C-style string array for NVRTC
|
||||
std::vector<const char*> option_cstrs;
|
||||
for (const auto& opt: options)
|
||||
option_cstrs.push_back(opt.c_str());
|
||||
|
||||
// Print compiler command if requested
|
||||
if (get_env<int>("DG_JIT_DEBUG", 0) or get_env<int>("DG_JIT_PRINT_COMPILER_COMMAND", 0)) {
|
||||
printf("Compiling JIT runtime with NVRTC options: ");
|
||||
for (const auto& opt: options)
|
||||
printf("%s ", opt.c_str());
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
// Create NVRTC program and compile
|
||||
nvrtcProgram program;
|
||||
DG_NVRTC_CHECK(nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr));
|
||||
const auto& compile_result = nvrtcCompileProgram(program, static_cast<int>(option_cstrs.size()), option_cstrs.data());
|
||||
|
||||
// Get and print compiler log
|
||||
size_t log_size;
|
||||
DG_NVRTC_CHECK(nvrtcGetProgramLogSize(program, &log_size));
|
||||
if (get_env<int>("DG_JIT_DEBUG", 0) or compile_result != NVRTC_SUCCESS) {
|
||||
if (compile_result != NVRTC_SUCCESS)
|
||||
DG_HOST_ASSERT(log_size > 1);
|
||||
if (log_size > 1) {
|
||||
std::string compilation_log(log_size, '\0');
|
||||
DG_NVRTC_CHECK(nvrtcGetProgramLog(program, compilation_log.data()));
|
||||
printf("NVRTC log: %s\n", compilation_log.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// Get CUBIN size and data
|
||||
size_t cubin_size;
|
||||
DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size));
|
||||
std::string cubin_data(cubin_size, '\0');
|
||||
DG_NVRTC_CHECK(nvrtcGetCUBIN(program, cubin_data.data()));
|
||||
|
||||
// Write into the file system
|
||||
put(cubin_path, cubin_data);
|
||||
|
||||
// Cleanup
|
||||
DG_NVRTC_CHECK(nvrtcDestroyProgram(&program));
|
||||
}
|
||||
};
|
||||
|
||||
static auto compiler = LazyInit<Compiler>([]() -> std::shared_ptr<Compiler> {
|
||||
if (get_env<int>("DG_JIT_USE_NVRTC", 0)) {
|
||||
return std::make_shared<NVRTCCompiler>();
|
||||
} else {
|
||||
return std::make_shared<NVCCCompiler>();
|
||||
}
|
||||
});
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
#include "../utils/lazy_init.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
class DeviceRuntime {
|
||||
int num_sms = 0;
|
||||
int num_sms = 0, tc_util = 0;
|
||||
std::shared_ptr<cudaDeviceProp> cached_prop;
|
||||
|
||||
public:
|
||||
@@ -43,8 +44,17 @@ public:
|
||||
num_sms = get_prop()->multiProcessorCount;
|
||||
return num_sms;
|
||||
}
|
||||
|
||||
void set_tc_util(const int& new_tc_util) {
|
||||
DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100);
|
||||
tc_util = new_tc_util;
|
||||
}
|
||||
|
||||
int get_tc_util() const {
|
||||
return tc_util == 0 ? 100 : tc_util;
|
||||
}
|
||||
};
|
||||
|
||||
static auto device_runtime = std::make_shared<DeviceRuntime>();
|
||||
static auto device_runtime = LazyInit<DeviceRuntime>([](){ return std::make_shared<DeviceRuntime>(); });
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
135
csrc/jit/handle.hpp
Normal file
135
csrc/jit/handle.hpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
#if CUDART_VERSION >= 12080 or defined(DG_JIT_USE_DRIVER_API)
|
||||
|
||||
// Use CUDA runtime API
|
||||
using LibraryHandle = cudaLibrary_t;
|
||||
using KernelHandle = cudaKernel_t;
|
||||
using LaunchConfigHandle = cudaLaunchConfig_t;
|
||||
using LaunchAttrHandle = cudaLaunchAttribute;
|
||||
|
||||
#define DG_CUDA_UNIFIED_CHECK DG_CUDA_RUNTIME_CHECK
|
||||
|
||||
static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name,
|
||||
LibraryHandle *library_opt = nullptr) {
|
||||
LibraryHandle library;
|
||||
KernelHandle kernel{};
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, func_name.c_str()));
|
||||
|
||||
if (library_opt != nullptr)
|
||||
*library_opt = library;
|
||||
return kernel;
|
||||
}
|
||||
|
||||
static void unload_library(const LibraryHandle& library) {
|
||||
const auto& error = cudaLibraryUnload(library);
|
||||
DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading);
|
||||
}
|
||||
|
||||
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
|
||||
const cudaStream_t& stream, const int& smem_size,
|
||||
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
|
||||
if (smem_size > 0)
|
||||
DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
|
||||
LaunchConfigHandle config;
|
||||
config.gridDim = grid_dim;
|
||||
config.blockDim = block_dim;
|
||||
config.dynamicSmemBytes = smem_size;
|
||||
config.stream = stream;
|
||||
config.numAttrs = 0;
|
||||
config.attrs = nullptr;
|
||||
|
||||
// NOTES: must use `static` or the `attr` will be deconstructed
|
||||
static LaunchAttrHandle attr;
|
||||
if (cluster_dim > 1) {
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {static_cast<unsigned>(cluster_dim), 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
template<typename... ActTypes>
|
||||
static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) {
|
||||
void *ptr_args[] = { &args... };
|
||||
return cudaLaunchKernelExC(&config, kernel, ptr_args);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
// Use CUDA driver API
|
||||
using LibraryHandle = CUmodule;
|
||||
using KernelHandle = CUfunction;
|
||||
using LaunchConfigHandle = CUlaunchConfig;
|
||||
using LaunchAttrHandle = CUlaunchAttribute;
|
||||
|
||||
#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK
|
||||
|
||||
static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name,
|
||||
LibraryHandle *library_opt = nullptr) {
|
||||
LibraryHandle library;
|
||||
KernelHandle kernel;
|
||||
DG_CUDA_DRIVER_CHECK(cuModuleLoad(&library, cubin_path.c_str()));
|
||||
DG_CUDA_DRIVER_CHECK(cuModuleGetFunction(&kernel, library, func_name.c_str()));
|
||||
|
||||
if (library_opt != nullptr)
|
||||
*library_opt = library;
|
||||
return kernel;
|
||||
}
|
||||
|
||||
static void unload_library(const LibraryHandle& library) {
|
||||
const auto& error = cuModuleUnload(library);
|
||||
DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED);
|
||||
}
|
||||
|
||||
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
|
||||
const cudaStream_t& stream, const int& smem_size,
|
||||
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
|
||||
if (smem_size > 0)
|
||||
DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size));
|
||||
|
||||
LaunchConfigHandle config;
|
||||
config.gridDimX = grid_dim.x;
|
||||
config.gridDimY = grid_dim.y;
|
||||
config.gridDimZ = grid_dim.z;
|
||||
config.blockDimX = block_dim.x;
|
||||
config.blockDimY = block_dim.y;
|
||||
config.blockDimZ = block_dim.z;
|
||||
config.sharedMemBytes = smem_size;
|
||||
config.hStream = stream;
|
||||
config.numAttrs = 0;
|
||||
config.attrs = nullptr;
|
||||
|
||||
// NOTES: must use `static` or the `attr` will be deconstructed
|
||||
static LaunchAttrHandle attr;
|
||||
if (cluster_dim > 1) {
|
||||
attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
|
||||
attr.value.clusterDim.x = cluster_dim;
|
||||
attr.value.clusterDim.y = 1;
|
||||
attr.value.clusterDim.z = 1;
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
}
|
||||
return config;
|
||||
}
|
||||
|
||||
template<typename... ActTypes>
|
||||
static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) {
|
||||
void *ptr_args[] = { &args... };
|
||||
return cuLaunchKernelEx(&config, kernel, ptr_args, nullptr);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -1,12 +1,10 @@
|
||||
#pragma once
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <filesystem>
|
||||
|
||||
#include "../utils/exception.hpp"
|
||||
#include "../utils/format.hpp"
|
||||
#include "../utils/system.hpp"
|
||||
#include "device_runtime.hpp"
|
||||
#include "handle.hpp"
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -23,19 +21,17 @@ struct LaunchArgs {
|
||||
grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept HasLaunchArgs = requires (const T& t) {
|
||||
{ t.launch_args } -> std::convertible_to<decltype(t.launch_args)>;
|
||||
};
|
||||
|
||||
class KernelRuntime final {
|
||||
public:
|
||||
static std::filesystem::path cuda_home;
|
||||
|
||||
cudaLibrary_t library;
|
||||
cudaKernel_t kernel;
|
||||
LibraryHandle library;
|
||||
KernelHandle kernel;
|
||||
|
||||
explicit KernelRuntime(const std::filesystem::path& dir_path) {
|
||||
// Check `prepare_init`
|
||||
DG_HOST_ASSERT(not cuda_home.empty());
|
||||
|
||||
// NOLINT(*-pro-type-member-init)
|
||||
const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump";
|
||||
const auto& cubin_path = dir_path / "kernel.cubin";
|
||||
@@ -50,7 +46,8 @@ public:
|
||||
std::istringstream iss(symbols);
|
||||
std::vector<std::string> symbol_names;
|
||||
for (std::string line; std::getline(iss, line); ) {
|
||||
if (line.find("STT_FUNC") == 0 and std::ranges::none_of(illegal_names, [&](const auto& name) { return line.find(name) != std::string::npos; })) {
|
||||
if (line.find("STT_FUNC") == 0 and std::none_of(illegal_names.begin(), illegal_names.end(),
|
||||
[&](const auto& name) { return line.find(name) != std::string::npos; })) {
|
||||
const auto& last_space = line.rfind(' ');
|
||||
symbol_names.push_back(line.substr(last_space + 1));
|
||||
}
|
||||
@@ -64,11 +61,10 @@ public:
|
||||
|
||||
// Load from the library
|
||||
DG_HOST_ASSERT(symbol_names.size() == 1);
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, symbol_names[0].c_str()));
|
||||
kernel = load_kernel(cubin_path, symbol_names[0], &library);
|
||||
}
|
||||
|
||||
static void set_cuda_home(const std::string& cuda_home_path_by_torch) {
|
||||
static void prepare_init(const std::string& cuda_home_path_by_torch) {
|
||||
cuda_home = cuda_home_path_by_torch;
|
||||
}
|
||||
|
||||
@@ -78,18 +74,16 @@ public:
|
||||
}
|
||||
|
||||
~KernelRuntime() noexcept(false) {
|
||||
const auto& error = cudaLibraryUnload(library);
|
||||
DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading);
|
||||
unload_library(library);
|
||||
}
|
||||
};
|
||||
|
||||
// Declare after defining
|
||||
decltype(KernelRuntime::cuda_home) KernelRuntime::cuda_home;
|
||||
DG_DECLARE_STATIC_VAR_IN_CLASS(KernelRuntime, cuda_home);
|
||||
|
||||
template <typename Derived>
|
||||
class LaunchRuntime {
|
||||
public:
|
||||
template <typename Args> requires HasLaunchArgs<Args>
|
||||
template <typename Args>
|
||||
static std::string generate(const Args& args) {
|
||||
const auto& code = Derived::generate_impl(args);
|
||||
if (get_env<int>("DG_JIT_DEBUG", 0))
|
||||
@@ -97,34 +91,18 @@ public:
|
||||
return code;
|
||||
}
|
||||
|
||||
template <typename Args> requires HasLaunchArgs<Args>
|
||||
template <typename Args>
|
||||
static void launch(const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) {
|
||||
const auto& kernel = kernel_runtime->kernel;
|
||||
const auto& stream = at::cuda::getCurrentCUDAStream();
|
||||
const LaunchArgs& launch_args = args.launch_args;
|
||||
|
||||
// Set dynamic shared memory size
|
||||
if (launch_args.smem_size > 0)
|
||||
DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, launch_args.smem_size));
|
||||
|
||||
// Launch config
|
||||
cudaLaunchConfig_t config;
|
||||
config.gridDim = {static_cast<unsigned>(launch_args.grid_dim.first),
|
||||
static_cast<unsigned>(launch_args.grid_dim.second),
|
||||
1};
|
||||
config.blockDim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
|
||||
config.dynamicSmemBytes = launch_args.smem_size;
|
||||
config.stream = stream;
|
||||
config.numAttrs = 0;
|
||||
|
||||
// Clusters
|
||||
cudaLaunchAttribute attr;
|
||||
if (launch_args.cluster_dim > 1) {
|
||||
attr.id = cudaLaunchAttributeClusterDimension;
|
||||
attr.val.clusterDim = {static_cast<unsigned>(launch_args.cluster_dim), 1, 1};
|
||||
config.attrs = &attr;
|
||||
config.numAttrs = 1;
|
||||
}
|
||||
const dim3& grid_dim = {static_cast<unsigned>(launch_args.grid_dim.first),
|
||||
static_cast<unsigned>(launch_args.grid_dim.second),
|
||||
1};
|
||||
const dim3& block_dim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
|
||||
auto config = construct_launch_config(kernel, stream, launch_args.smem_size,
|
||||
grid_dim, block_dim, launch_args.cluster_dim);
|
||||
|
||||
// Launch in the derived class
|
||||
if (get_env<int>("DG_JIT_DEBUG")) {
|
||||
|
||||
@@ -62,8 +62,9 @@ struct GemmConfig {
|
||||
int block_m, block_n, block_k;
|
||||
int num_stages, num_last_stages;
|
||||
|
||||
// Runtime configs
|
||||
// Templated device configs
|
||||
int num_sms;
|
||||
int tc_util;
|
||||
|
||||
// Structured configs
|
||||
MulticastConfig multicast_config;
|
||||
@@ -265,30 +266,35 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k
|
||||
.num_stages = best_num_stages,
|
||||
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
|
||||
.num_sms = num_min_sms,
|
||||
.tc_util = device_runtime->get_tc_util(),
|
||||
.multicast_config = best_multicast_config,
|
||||
// ReSharper disable once CppLocalVariableMightNotBeInitialized
|
||||
.smem_config = best_smem_config,
|
||||
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
|
||||
};
|
||||
|
||||
// Only SM100 BF16 kernels support tensor core control
|
||||
if (config.tc_util < 100)
|
||||
DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and ab_dtype == torch::kBFloat16);
|
||||
|
||||
// Print configs for the first time
|
||||
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
|
||||
auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b,
|
||||
ab_dtype, cd_dtype, with_accumulation, num_sms);
|
||||
static std::set<decltype(key)> printed;
|
||||
if (not printed.contains(key)) {
|
||||
printf("Gemm type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
|
||||
if (printed.count(key) == 0) {
|
||||
printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
|
||||
"A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, "
|
||||
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
|
||||
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
|
||||
"swizzle B: %d, swizzle CD: %d, threads: %d\n",
|
||||
"swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n",
|
||||
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
|
||||
static_cast<int>(major_a), static_cast<int>(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype),
|
||||
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
|
||||
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,
|
||||
static_cast<int>(best_multicast_config.is_multicast_on_a),
|
||||
best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode,
|
||||
best_smem_config.swizzle_cd_mode, config.thread_config.num_threads);
|
||||
best_smem_config.swizzle_cd_mode, config.num_sms, config.thread_config.num_threads, config.tc_util);
|
||||
printed.insert(key);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,6 +44,11 @@ struct SM100ArchSpec {
|
||||
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
|
||||
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
const int& block_m, const int& block_n) {
|
||||
// TODO: consider more carefully for BF16 GEMMs
|
||||
// 2SM BF16 UMMA does not support `N % 32 != 0`
|
||||
if (ab_dtype == torch::kBFloat16 and block_n % 32 != 0)
|
||||
return false;
|
||||
|
||||
// Layout A/D does not support `block_m == 64` and `block_n % 16 != 0`
|
||||
if (block_m == 64 or block_n % 16 != 0)
|
||||
return false;
|
||||
@@ -68,7 +73,7 @@ struct SM100ArchSpec {
|
||||
|
||||
// NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
|
||||
// since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA
|
||||
return major_b == cute::UMMA::Major::K or block_n % 64 == 0;
|
||||
return major_b == cute::UMMA::Major::K or (block_n * c10::elementSize(ab_dtype)) % 64 == 0;
|
||||
}
|
||||
|
||||
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
|
||||
@@ -93,7 +98,7 @@ struct SM100ArchSpec {
|
||||
|
||||
static ThreadConfig get_thread_config(const KernelType& kernel_type,
|
||||
const int& block_m, const int& block_n) {
|
||||
return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D1D ? 128 : block_m);
|
||||
return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D2D ? block_m : 128);
|
||||
}
|
||||
|
||||
static int get_smem_cd_size(const KernelType& kernel_type,
|
||||
|
||||
@@ -32,13 +32,6 @@ public:
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
@@ -54,7 +47,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
{}, {}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -66,14 +59,13 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
to_string(args.gemm_config.gemm_type),
|
||||
args.gemm_config.with_accumulation,
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.grouped_layout, args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
args.tensor_map_sfa, args.tensor_map_sfb,
|
||||
@@ -286,7 +278,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
const auto& num_groups = static_cast<int>(ks.size());
|
||||
|
||||
// Get config using max K for better performance
|
||||
const auto& max_k = *std::ranges::max_element(ks);
|
||||
const auto& max_k = *std::max_element(ks.begin(), ks.end());
|
||||
const auto& config = get_best_config<SM100ArchSpec>(
|
||||
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
|
||||
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
|
||||
@@ -316,9 +308,9 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
|
||||
static_cast<int>(cd.stride(1)), num_groups,
|
||||
config.smem_config.swizzle_cd_mode);
|
||||
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512,
|
||||
config.block_m, config.block_k, num_groups, 0);
|
||||
config.block_m, config.block_k, 1, 0);
|
||||
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512,
|
||||
config.block_n, config.block_k, num_groups, 0);
|
||||
config.block_n, config.block_k, 1, 0);
|
||||
|
||||
// Duplicate the accumulator if necessary
|
||||
if (c.has_value()) {
|
||||
|
||||
@@ -30,13 +30,6 @@ public:
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
@@ -51,6 +44,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{},
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
@@ -63,13 +57,13 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
to_string(args.gemm_config.gemm_type),
|
||||
to_string(args.gemm_config.cd_dtype));
|
||||
args.gemm_config.num_sms,
|
||||
to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype));
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.sfb, args.grouped_layout,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
|
||||
@@ -29,13 +29,6 @@ public:
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
@@ -49,7 +42,7 @@ static void __instantiate_kernel() {{
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}, {},
|
||||
{}
|
||||
{}, {}
|
||||
>);
|
||||
}};
|
||||
)",
|
||||
@@ -61,12 +54,12 @@ static void __instantiate_kernel() {{
|
||||
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
|
||||
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
|
||||
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
|
||||
to_string(args.gemm_config.gemm_type));
|
||||
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type));
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
// TODO: optimize `args` copy
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.sfb, args.grouped_layout,
|
||||
args.m, args.n, args.k,
|
||||
args.tensor_map_a, args.tensor_map_b,
|
||||
|
||||
@@ -22,13 +22,6 @@ public:
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
@@ -41,8 +34,8 @@ static void __instantiate_kernel() {{
|
||||
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast<uint32_t>(args.mn)));
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast<uint32_t>(args.mn)));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -58,13 +51,6 @@ public:
|
||||
|
||||
static std::string generate_impl(const Args& args) {
|
||||
return fmt::format(R"(
|
||||
#ifdef __CUDACC_RTC__
|
||||
#include <deep_gemm/nvrtc_std.cuh>
|
||||
#else
|
||||
#include <cuda.h>
|
||||
#include <string>
|
||||
#endif
|
||||
|
||||
#include <deep_gemm/impls/smxx_layout.cuh>
|
||||
|
||||
using namespace deep_gemm;
|
||||
@@ -77,8 +63,8 @@ static void __instantiate_kernel() {{
|
||||
)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k);
|
||||
}
|
||||
|
||||
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
|
||||
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
|
||||
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
|
||||
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
|
||||
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k));
|
||||
}
|
||||
};
|
||||
@@ -108,43 +94,16 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
|
||||
return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf;
|
||||
}
|
||||
|
||||
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) {
|
||||
const auto& sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf;
|
||||
|
||||
// First, convert into UE8M0 `uint8_t`
|
||||
const auto& ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8);
|
||||
|
||||
// Second, make padded packed tensors
|
||||
const auto& [num_groups, mn, k] = get_shape<3>(sf_reshaped);
|
||||
const auto& aligned_mn = get_tma_aligned_size(mn, 4);
|
||||
const auto& aligned_k = align(k, 4);
|
||||
|
||||
const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8);
|
||||
auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options);
|
||||
// ReSharper disable once CppExpressionWithoutSideEffects
|
||||
padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor);
|
||||
padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4});
|
||||
|
||||
// Finally, transpose
|
||||
auto out = torch::empty_strided({num_groups, aligned_mn, aligned_k / 4},
|
||||
{aligned_mn * (aligned_k / 4), 1, aligned_mn},
|
||||
at::TensorOptions().device(sf.device()).dtype(torch::kInt32));
|
||||
out = out.copy_(padded).slice(1, 0, mn);
|
||||
return (sf.dim() == 2) ? out.squeeze(0) : out;
|
||||
}
|
||||
|
||||
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
|
||||
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
|
||||
const auto& packed_sf_k = ceil_div(sf_k, 4);
|
||||
const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
|
||||
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
|
||||
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
|
||||
DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0);
|
||||
|
||||
// Launch the kernel
|
||||
if (batched_sf.is_contiguous()) {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if ((mn * sf_k) % 4 != 0 and num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
constexpr int block_mn = 48;
|
||||
constexpr int num_threads = 512;
|
||||
const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = {
|
||||
@@ -160,10 +119,6 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
|
||||
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
|
||||
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
|
||||
} else {
|
||||
// Fallback to slow PyTorch impl for non-supported cases
|
||||
if (mn % 4 != 0 or num_groups > 1)
|
||||
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
|
||||
|
||||
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
|
||||
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
|
||||
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
#include "jit/device_runtime.hpp"
|
||||
#include "utils/layout.hpp"
|
||||
|
||||
#include "jit_kernels/impls/smxx_layout.hpp"
|
||||
#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
|
||||
#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
|
||||
#include "jit_kernels/impls/smxx_layout.hpp"
|
||||
|
||||
#ifndef TORCH_EXTENSION_NAME
|
||||
#define TORCH_EXTENSION_NAME deep_gemm_cpp
|
||||
@@ -17,8 +17,8 @@
|
||||
namespace deep_gemm {
|
||||
torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
|
||||
const int& mn, const int& k,
|
||||
const std::optional<int>& num_groups,
|
||||
const std::tuple<int, int, int>& recipe,
|
||||
const std::optional<int>& num_groups,
|
||||
const bool& is_sfa,
|
||||
const bool& disable_ue8m0_cast) {
|
||||
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
|
||||
@@ -121,8 +121,8 @@ void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, std::nullopt, recipe.value(), false, disable_ue8m0_cast);
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch into different implements
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
@@ -133,7 +133,7 @@ void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
|
||||
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
|
||||
sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,8 +208,8 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tens
|
||||
// Transform SFA and SFB into compute-required layout
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
@@ -223,7 +223,7 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tens
|
||||
sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
|
||||
num_groups, m, n, k, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,8 +271,8 @@ void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>&
|
||||
// Transform scaling factors
|
||||
if (not recipe.has_value())
|
||||
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, num_groups, recipe.value(), true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
|
||||
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast);
|
||||
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast);
|
||||
|
||||
// Dispatch implementation
|
||||
const auto& arch_major = device_runtime->get_arch_major();
|
||||
@@ -286,7 +286,7 @@ void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>&
|
||||
sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
|
||||
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
|
||||
} else {
|
||||
DG_HOST_UNREACHABLE("Unsupported kernel or scaling factor types");
|
||||
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -339,18 +339,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.doc() = "DeepGEMM C++ library";
|
||||
|
||||
// Runtime
|
||||
m.def("set_num_sms", [&](const int& new_num_sms) {
|
||||
device_runtime->set_num_sms(new_num_sms);
|
||||
});
|
||||
m.def("get_num_sms", [&]() {
|
||||
return device_runtime->get_num_sms();
|
||||
});
|
||||
m.def("set_num_sms", [&](const int& new_num_sms) {
|
||||
device_runtime->set_num_sms(new_num_sms);
|
||||
m.def("set_tc_util", [&](const int& new_tc_util) {
|
||||
device_runtime->set_tc_util(new_tc_util);
|
||||
});
|
||||
m.def("get_tc_util", [&]() {
|
||||
return device_runtime->get_tc_util();
|
||||
});
|
||||
|
||||
// JIT
|
||||
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) {
|
||||
DG_HOST_ASSERT(get_env("DG_JIT_USE_NVRTC", 0) == 0 and "Currently only support NVCC");
|
||||
compiler = std::make_shared<NVCCCompiler>(library_root_path, cuda_home_path_by_torch);
|
||||
KernelRuntime::set_cuda_home(cuda_home_path_by_torch);
|
||||
Compiler::prepare_init(library_root_path, cuda_home_path_by_torch);
|
||||
KernelRuntime::prepare_init(cuda_home_path_by_torch);
|
||||
});
|
||||
|
||||
// Stable kernel APIs with automatic arch/layout dispatch
|
||||
@@ -391,7 +396,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
|
||||
py::arg("recipe") = std::make_tuple(1, 1, 128),
|
||||
py::arg("compiled_dims") = "mn");
|
||||
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout);
|
||||
|
||||
// Layout kernels
|
||||
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
|
||||
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
|
||||
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
|
||||
py::arg("disable_ue8m0_cast") = false);
|
||||
|
||||
// Raw kernels or functions
|
||||
m.def("get_tma_aligned_size", &get_tma_aligned_size);
|
||||
|
||||
@@ -35,6 +35,16 @@ do { \
|
||||
#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason))
|
||||
#endif
|
||||
|
||||
#ifndef DG_NVRTC_CHECK
|
||||
#define DG_NVRTC_CHECK(cmd) \
|
||||
do { \
|
||||
const auto& e = (cmd); \
|
||||
if (e != NVRTC_SUCCESS) { \
|
||||
throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \
|
||||
} \
|
||||
} while (0)
|
||||
#endif
|
||||
|
||||
#ifndef DG_CUDA_DRIVER_CHECK
|
||||
#define DG_CUDA_DRIVER_CHECK(cmd) \
|
||||
do { \
|
||||
|
||||
27
csrc/utils/lazy_init.hpp
Normal file
27
csrc/utils/lazy_init.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#define DG_DECLARE_STATIC_VAR_IN_CLASS(cls, name) decltype(cls::name) cls::name
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
template <typename T>
|
||||
class LazyInit {
|
||||
public:
|
||||
explicit LazyInit(std::function<std::shared_ptr<T>()> factory)
|
||||
: factory(std::move(factory)) {}
|
||||
|
||||
T* operator -> () {
|
||||
if (ptr == nullptr)
|
||||
ptr = factory();
|
||||
return ptr.get();
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<T> ptr;
|
||||
std::function<std::shared_ptr<T>()> factory;
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
@@ -38,8 +38,8 @@ static std::tuple<int, std::string> call_external_command(std::string command) {
|
||||
std::string output;
|
||||
while (fgets(buffer.data(), buffer.size(), pipe.get()))
|
||||
output += buffer.data();
|
||||
const auto exit_code = pclose(pipe.release());
|
||||
return {WEXITSTATUS(exit_code), output};
|
||||
const auto& exit_code = WEXITSTATUS(pclose(pipe.release()));
|
||||
return {exit_code, output};
|
||||
}
|
||||
|
||||
static std::vector<std::filesystem::path> collect_files(const std::filesystem::path& root) {
|
||||
|
||||
@@ -22,17 +22,22 @@ deep_gemm_cpp.init(
|
||||
# Configs
|
||||
from deep_gemm_cpp import (
|
||||
set_num_sms,
|
||||
get_num_sms
|
||||
get_num_sms,
|
||||
set_tc_util,
|
||||
get_tc_util,
|
||||
)
|
||||
|
||||
# Kernels
|
||||
from deep_gemm_cpp import (
|
||||
# FP8 GEMMs
|
||||
fp8_gemm_nt, fp8_gemm_nn,
|
||||
fp8_gemm_tn, fp8_gemm_tt,
|
||||
m_grouped_fp8_gemm_nt_contiguous,
|
||||
m_grouped_fp8_gemm_nn_contiguous,
|
||||
fp8_m_grouped_gemm_nt_masked,
|
||||
k_grouped_fp8_gemm_tn_contiguous
|
||||
k_grouped_fp8_gemm_tn_contiguous,
|
||||
# Layout kernels
|
||||
transform_sf_into_required_layout
|
||||
)
|
||||
|
||||
# Some utils
|
||||
|
||||
48
deep_gemm/include/deep_gemm/common/cute_tie.cuh
Normal file
48
deep_gemm/include/deep_gemm/common/cute_tie.cuh
Normal file
@@ -0,0 +1,48 @@
|
||||
#pragma once
|
||||
|
||||
namespace cute {
|
||||
|
||||
struct ignore_t {
|
||||
template <typename T>
|
||||
constexpr const ignore_t& operator=(T&&) const noexcept {
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
inline constexpr ignore_t ignore{};
|
||||
|
||||
} // namespace cute
|
||||
|
||||
#define CUTE_TIE_CONCAT_IMPL(A, B) A##B
|
||||
#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B)
|
||||
|
||||
#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N
|
||||
#define CUTE_TIE_COUNT_ARGS(...) \
|
||||
CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
|
||||
|
||||
#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get<I>(TUPLE)
|
||||
#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get<I>(TUPLE)
|
||||
|
||||
#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1);
|
||||
#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2);
|
||||
#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3);
|
||||
#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4);
|
||||
#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5);
|
||||
|
||||
#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \
|
||||
auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
|
||||
CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
|
||||
CUTE_TIE_OP_DECL, \
|
||||
CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
|
||||
__VA_ARGS__ \
|
||||
)
|
||||
|
||||
#define CUTE_TIE(TUPLE_EXPR, ...) \
|
||||
do { \
|
||||
auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
|
||||
CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
|
||||
CUTE_TIE_OP_ASSIGN, \
|
||||
CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
|
||||
__VA_ARGS__ \
|
||||
); \
|
||||
} while (0)
|
||||
@@ -11,14 +11,28 @@ enum class KGroupedIndexType {
|
||||
SF_K,
|
||||
};
|
||||
|
||||
template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool isMulticastOnA>
|
||||
static constexpr uint32_t get_num_1d_blocks_per_group() {
|
||||
// Select the best from candidates
|
||||
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
|
||||
for (const auto& candidate: {8u, 16u}) {
|
||||
const auto& usage = isMulticastOnA ?
|
||||
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
|
||||
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
|
||||
if (usage < min_usage)
|
||||
min_usage = usage, num_best_blocks = candidate;
|
||||
}
|
||||
return num_best_blocks;
|
||||
}
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
|
||||
template <GemmType kGemmType,
|
||||
uint32_t BLOCK_M, uint32_t BLOCK_N,
|
||||
uint32_t kNumGroups,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
// TODO: refactor this by other values
|
||||
uint32_t kNum1DBlocksPerGroup = 16>
|
||||
uint32_t kNumSMs,
|
||||
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
|
||||
struct Scheduler {
|
||||
int current_iter = -1;
|
||||
|
||||
@@ -88,6 +102,7 @@ struct Scheduler {
|
||||
#endif
|
||||
|
||||
// Convert to final M/N block indices
|
||||
// `kIsMulticastOnA == true` leads to groups on N
|
||||
if constexpr (kIsMulticastOnA) {
|
||||
m_block_idx = in_group_idx / num_blocks_in_group;
|
||||
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
|
||||
@@ -103,7 +118,7 @@ struct Scheduler {
|
||||
if constexpr (kGemmType == GemmType::Normal) {
|
||||
return block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
|
||||
const auto offset = kWithGroupOffset ? std::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
|
||||
const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
|
||||
return offset * shape_dim + block_idx * block_size;
|
||||
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
const auto offset = kWithGroupOffset ? current_group_idx : 0;
|
||||
@@ -123,7 +138,7 @@ struct Scheduler {
|
||||
}
|
||||
|
||||
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
|
||||
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
|
||||
const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;
|
||||
|
||||
if constexpr (kGemmType == GemmType::MGroupedMasked) {
|
||||
while (true) {
|
||||
|
||||
@@ -101,7 +101,7 @@ constexpr uint32_t get_umma_desc_stride_k() {
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
|
||||
__device__ __forceinline__
|
||||
uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) {
|
||||
return base + ((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) >> 4u);
|
||||
return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
|
||||
}
|
||||
|
||||
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <cute/arch/mma_sm90_gmma.hpp>
|
||||
#include <cute/arch/mma_sm90_gmma_ext.hpp>
|
||||
|
||||
@@ -10,13 +9,13 @@ template <int N_, typename MMA>
|
||||
struct FP8MMA {
|
||||
|
||||
template <size_t ...Idx>
|
||||
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence<Idx...>) {
|
||||
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
|
||||
using namespace cute::SM90::GMMA;
|
||||
MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
|
||||
call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence<N_/2>{});
|
||||
call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
|
||||
}
|
||||
|
||||
static constexpr int M = 64;
|
||||
@@ -139,7 +138,7 @@ __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout
|
||||
|
||||
__device__ __forceinline__ void
|
||||
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
|
||||
const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast) {
|
||||
const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast = 1) {
|
||||
constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);
|
||||
if (num_tma_multicast == 1) {
|
||||
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
|
||||
|
||||
@@ -12,6 +12,7 @@ enum class GemmType {
|
||||
enum class KernelType {
|
||||
Kernel1D1D = 0,
|
||||
Kernel1D2D = 1,
|
||||
KernelNoSF = 2
|
||||
};
|
||||
|
||||
} // namespace deep_gemm
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda/std/cstdint>
|
||||
#include <cuda/std/utility>
|
||||
#include <cute/container/tuple.hpp>
|
||||
|
||||
#include "cute_tie.cuh"
|
||||
|
||||
#ifdef __CLION_IDE__
|
||||
|
||||
@@ -135,4 +140,8 @@ __device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
|
||||
return *reinterpret_cast<int*>(&bf16x2);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void prefetch_l1(void *ptr) {
|
||||
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
|
||||
}
|
||||
|
||||
} // namespace `deep_gemm`
|
||||
|
||||
@@ -20,22 +20,23 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfa,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfb,
|
||||
const __grid_constant__ CUtensorMap tensor_map_c,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d) {
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_c,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
// GEMM with accumulation must have FP32 output
|
||||
if constexpr (kWithAccumulation)
|
||||
DG_STATIC_ASSERT(std::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
||||
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
|
||||
|
||||
// Configs
|
||||
constexpr uint32_t LAYOUT_AD_M = 128;
|
||||
@@ -63,7 +64,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
// 2-CTA MMA
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t STORE_BLOCK_M = std::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
||||
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
|
||||
@@ -95,6 +96,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
||||
@@ -173,7 +175,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
// For pipeline unrolling
|
||||
struct DivisibleK {};
|
||||
@@ -207,7 +209,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
@@ -329,7 +331,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
@@ -342,7 +344,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
|
||||
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
|
||||
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
|
||||
using cute_utccp_t = std::conditional_t<kNumMulticast == 1,
|
||||
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
|
||||
|
||||
// SFA and SFB copy
|
||||
@@ -363,7 +365,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
__syncwarp();
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
using cute_mma_t = std::conditional_t<kNumMulticast == 1,
|
||||
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_MXF8F6F4_SS <cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
||||
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>,
|
||||
cute::SM100_MMA_MXF8F6F4_2x1SM_SS<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
|
||||
@@ -416,7 +418,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
|
||||
|
||||
#pragma unroll
|
||||
@@ -530,7 +532,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
|
||||
// Load from tensor memory, store into shared memory
|
||||
uint32_t values[kNumElemsPerBankGroup];
|
||||
if constexpr (std::is_same_v<cd_dtype_t, float>) {
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
|
||||
@@ -539,7 +541,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
|
||||
values[0], values[1], values[2], values[3],
|
||||
values[4], values[5], values[6], values[7]);
|
||||
@@ -564,7 +566,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
|
||||
cute::tma_store_fence();
|
||||
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
|
||||
if (epilogue_thread_idx == 0) {
|
||||
using cute_tma_t = std::conditional_t<kWithAccumulation,
|
||||
using cute_tma_t = cute::conditional_t<kWithAccumulation,
|
||||
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
|
||||
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
|
||||
cute::tma_store_arrive();
|
||||
|
||||
@@ -21,14 +21,15 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
|
||||
uint32_t kNumMulticast, bool kIsMulticastOnA,
|
||||
uint32_t kNumSMs,
|
||||
GemmType kGemmType, typename cd_dtype_t>
|
||||
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
|
||||
sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfa) {
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
|
||||
using Barrier = cutlass::arch::ClusterTransactionBarrier;
|
||||
|
||||
@@ -61,7 +62,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// 2-CTA MMA
|
||||
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
|
||||
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
|
||||
constexpr uint32_t STORE_BLOCK_M = std::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
|
||||
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
|
||||
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
|
||||
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
|
||||
@@ -87,6 +88,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == 0) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
@@ -171,7 +173,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
// Register configurations
|
||||
constexpr uint32_t kNumNonEpilogueRegisters = 64;
|
||||
@@ -187,7 +189,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
@@ -276,7 +278,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
@@ -293,7 +295,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Issue UMMA in the leader CTA
|
||||
if (s < kNumInnerStages) {
|
||||
using cute_mma_t = std::conditional_t<kNumMulticast == 1,
|
||||
using cute_mma_t = cute::conditional_t<kNumMulticast == 1,
|
||||
cute::SM100_MMA_F8F6F4_SS, cute::SM100_MMA_F8F6F4_2x1SM_SS>;
|
||||
tcgen05_after_thread_sync();
|
||||
#pragma unroll
|
||||
@@ -366,7 +368,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// Launch promotion
|
||||
float accum[BLOCK_N] = {0};
|
||||
launch_k_iterations([&](uint32_t k_iter, auto type) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
|
||||
|
||||
@@ -474,7 +476,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// Load from tensor memory, store into shared memory
|
||||
// NOTES: if you want to do accumulation, please notice that you need two accumulation barriers
|
||||
const auto offset = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup;
|
||||
if constexpr (std::is_same_v<cd_dtype_t, float>) {
|
||||
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
|
||||
// For FP32 output, read and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
|
||||
st_shared(smem_ptr,
|
||||
@@ -484,7 +486,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
*reinterpret_cast<uint32_t*>(&accum[offset + 3]));
|
||||
} else {
|
||||
// For BF16 output, read, cast and store
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
|
||||
st_shared(smem_ptr,
|
||||
cast_into_bf16_and_pack(accum[offset + 0], accum[offset + 1]),
|
||||
cast_into_bf16_and_pack(accum[offset + 2], accum[offset + 3]),
|
||||
|
||||
@@ -36,14 +36,14 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
|
||||
uint32_t kNumStages, uint32_t kNumLastStages,
|
||||
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
|
||||
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
|
||||
GemmType kGemmType>
|
||||
__global__ void __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
|
||||
uint32_t kNumSMs, GemmType kGemmType>
|
||||
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
|
||||
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
|
||||
const __grid_constant__ CUtensorMap tensor_map_a,
|
||||
const __grid_constant__ CUtensorMap tensor_map_b,
|
||||
const __grid_constant__ CUtensorMap tensor_map_d,
|
||||
const __grid_constant__ CUtensorMap tensor_map_sfa) {
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
|
||||
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
|
||||
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
|
||||
// Scaling checks
|
||||
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
|
||||
@@ -77,10 +77,10 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// Prefetch TMA descriptors at the very beginning
|
||||
if (threadIdx.x == kNumMathThreads) {
|
||||
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_sfa));
|
||||
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
|
||||
cute::prefetch_tma_descriptor(&tensor_map_a);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_b);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_sfa);
|
||||
cute::prefetch_tma_descriptor(&tensor_map_d);
|
||||
}
|
||||
__syncwarp();
|
||||
|
||||
@@ -168,7 +168,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Block scheduler
|
||||
uint32_t m_block_idx, n_block_idx;
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, shape_n, grouped_layout);
|
||||
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, grouped_layout);
|
||||
|
||||
if (threadIdx.x >= kNumMathThreads) {
|
||||
// TMA warp-group for loading data
|
||||
@@ -179,7 +179,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
// Persistently schedule over blocks
|
||||
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
|
||||
|
||||
// Assign TMA multicast number into A and B
|
||||
@@ -278,8 +278,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
|
||||
|
||||
// Launch MMAs
|
||||
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
|
||||
constexpr bool kSkipComputation = std::is_same_v<decltype(skip_type), SkipComputation>;
|
||||
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr bool kSkipComputation = cute::is_same_v<decltype(skip_type), SkipComputation>;
|
||||
constexpr bool kHasDivisibleStages = cute::is_same_v<decltype(divisible_type), DivisibleK>;
|
||||
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages);
|
||||
|
||||
#pragma unroll
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include <deep_gemm/common/utils.cuh>
|
||||
|
||||
namespace deep_gemm {
|
||||
|
||||
@@ -1,103 +0,0 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
|
||||
* All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef __CUDACC_RTC__
|
||||
|
||||
using int8_t = signed char;
|
||||
using uint8_t = unsigned char;
|
||||
using int16_t = signed short;
|
||||
using uint16_t = unsigned short;
|
||||
using int32_t = signed int;
|
||||
using uint32_t = unsigned int;
|
||||
using int64_t = signed long long;
|
||||
using uint64_t = unsigned long long;
|
||||
using cuuint64_t = unsigned long long;
|
||||
|
||||
#ifndef CU_TENSOR_MAP_NUM_QWORDS
|
||||
#define CU_TENSOR_MAP_NUM_QWORDS 16
|
||||
|
||||
struct CUtensorMap_st {
|
||||
#if defined(__cplusplus) && (__cplusplus >= 201103L)
|
||||
alignas(64)
|
||||
#elif __STDC_VERSION__ >= 201112L
|
||||
_Alignas(64)
|
||||
#endif
|
||||
cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS];
|
||||
};
|
||||
|
||||
using CUtensorMap = CUtensorMap_st;
|
||||
#endif
|
||||
|
||||
namespace std {
|
||||
|
||||
template <class T, T v> struct integral_constant {
|
||||
static constexpr T value = v;
|
||||
|
||||
using value_type = T;
|
||||
using type = integral_constant;
|
||||
|
||||
__device__ constexpr operator value_type() const noexcept { return value; }
|
||||
|
||||
__device__ constexpr value_type operator()() const noexcept { return value; }
|
||||
};
|
||||
|
||||
using false_type = integral_constant<bool, false>;
|
||||
using true_type = integral_constant<bool, true>;
|
||||
|
||||
template <class T, class U> struct is_same : false_type {};
|
||||
|
||||
template <class T> struct is_same<T, T> : true_type {};
|
||||
|
||||
template <class T, class U>
|
||||
inline constexpr bool is_same_v = is_same<T, U>::value;
|
||||
|
||||
namespace index_sequence_impl {
|
||||
|
||||
// Based on https://stackoverflow.com/a/32223343/11717224
|
||||
template <size_t... Ints> struct index_sequence {
|
||||
using type = index_sequence;
|
||||
using value_type = size_t;
|
||||
static constexpr size_t size() noexcept { return sizeof...(Ints); }
|
||||
};
|
||||
|
||||
template <class Sequence1, class Sequence2> struct _merge_and_renumber;
|
||||
|
||||
template <size_t... I1, size_t... I2>
|
||||
struct _merge_and_renumber<index_sequence<I1...>, index_sequence<I2...>>
|
||||
: index_sequence<I1..., (sizeof...(I1) + I2)...> {};
|
||||
|
||||
template <size_t N>
|
||||
struct make_index_sequence
|
||||
: _merge_and_renumber<typename make_index_sequence<N / 2>::type,
|
||||
typename make_index_sequence<N - N / 2>::type> {};
|
||||
|
||||
template <> struct make_index_sequence<0> : index_sequence<> {};
|
||||
template <> struct make_index_sequence<1> : index_sequence<0> {};
|
||||
|
||||
} // namespace index_sequence_impl
|
||||
|
||||
template <size_t... Ns>
|
||||
using index_sequence = index_sequence_impl::index_sequence<Ns...>;
|
||||
|
||||
template <size_t N>
|
||||
using make_index_sequence = index_sequence_impl::make_index_sequence<N>;
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif
|
||||
@@ -46,3 +46,12 @@ def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tenso
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2))
|
||||
|
||||
|
||||
def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
|
||||
x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
|
||||
sf = x_amax / 448.0
|
||||
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
|
||||
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled, sf.squeeze()
|
||||
@@ -7,7 +7,7 @@ cd "$script_dir"
|
||||
ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include
|
||||
ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include
|
||||
|
||||
# Remove old dist file, build, and build
|
||||
# Remove old dist file, build files, and build
|
||||
rm -rf build dist
|
||||
rm -rf *.egg-info
|
||||
python setup.py build
|
||||
|
||||
@@ -3,7 +3,7 @@ original_dir=$(pwd)
|
||||
script_dir=$(realpath "$(dirname "$0")")
|
||||
cd "$script_dir"
|
||||
|
||||
# Remove old dist file, build, and install
|
||||
# Remove old dist file, build files, and install
|
||||
rm -rf build dist
|
||||
rm -rf *.egg-info
|
||||
python setup.py bdist_wheel
|
||||
|
||||
11
setup.py
11
setup.py
@@ -2,12 +2,14 @@ import os
|
||||
import setuptools
|
||||
import shutil
|
||||
import subprocess
|
||||
import torch
|
||||
from setuptools import find_packages
|
||||
from setuptools.command.build_py import build_py
|
||||
from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME
|
||||
|
||||
current_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
cxx_flags = ['-std=c++20', '-O3', '-fPIC', '-Wno-psabi']
|
||||
cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations',
|
||||
f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}']
|
||||
sources = ['csrc/python_api.cpp']
|
||||
build_include_dirs = [
|
||||
f'{CUDA_HOME}/include',
|
||||
@@ -15,7 +17,7 @@ build_include_dirs = [
|
||||
'third-party/cutlass/include',
|
||||
'third-party/fmt/include',
|
||||
]
|
||||
build_libraries = ['cuda', 'cudart']
|
||||
build_libraries = ['cuda', 'cudart', 'nvrtc']
|
||||
build_library_dirs = [
|
||||
f'{CUDA_HOME}/lib64',
|
||||
f'{CUDA_HOME}/lib64/stubs'
|
||||
@@ -40,7 +42,7 @@ class CustomBuildPy(build_py):
|
||||
def generate_default_envs(self):
|
||||
code = '# Pre-installed environment variables\n'
|
||||
code += 'persistent_envs = dict()\n'
|
||||
for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_DISABLE_SHORTCUT_CACHE'):
|
||||
for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_CPP_STANDARD'):
|
||||
code += f"persistent_envs['{name}'] = '{os.environ[name]}'\n" if name in os.environ else ''
|
||||
|
||||
with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f:
|
||||
@@ -78,9 +80,6 @@ if __name__ == '__main__':
|
||||
name='deep_gemm',
|
||||
version='2.0.0' + revision,
|
||||
packages=find_packages('.'),
|
||||
install_requires=[
|
||||
'torch>=2.1.0',
|
||||
],
|
||||
package_data={
|
||||
'deep_gemm': [
|
||||
'include/deep_gemm/**/*',
|
||||
|
||||
@@ -14,6 +14,7 @@ class KernelType(enum.Enum):
|
||||
# For SM100 GEMMs
|
||||
Kernel1D1D = 0
|
||||
Kernel1D2D = 1
|
||||
KernelNoSF = 2
|
||||
|
||||
def is_1d1d(self):
|
||||
return self.value == 0
|
||||
@@ -21,6 +22,9 @@ class KernelType(enum.Enum):
|
||||
def is_1d2d(self):
|
||||
return self.value == 1
|
||||
|
||||
def is_nosf(self):
|
||||
return self.value == 2
|
||||
|
||||
|
||||
class MajorTypeAB(enum.Enum):
|
||||
KMajor = 0
|
||||
@@ -44,7 +48,9 @@ def get_ue8m0_usage(kernel_type: KernelType) -> bool:
|
||||
return kernel_type.is_1d1d()
|
||||
|
||||
|
||||
def get_kernel_types() -> tuple:
|
||||
def get_kernel_types(use_bf16: bool = False) -> tuple:
|
||||
if use_bf16:
|
||||
return (KernelType.KernelNoSF, )
|
||||
return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D)
|
||||
|
||||
|
||||
@@ -61,13 +67,13 @@ def get_major_ab(freeze_a: bool) -> tuple:
|
||||
(MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
|
||||
|
||||
|
||||
def enumerate_normal() -> Generator:
|
||||
for kernel_type in get_kernel_types():
|
||||
def enumerate_normal(use_bf16: bool = False) -> Generator:
|
||||
for kernel_type in get_kernel_types(use_bf16):
|
||||
for m in (128, 4096):
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]:
|
||||
for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048), (129280, 7168)]:
|
||||
for major_a, major_b in get_major_ab(False):
|
||||
for out_dtype in get_out_dtype():
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True):
|
||||
for accumulate in (False, ) if out_dtype == torch.bfloat16 or not kernel_type.is_1d1d() else (False, True):
|
||||
yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
|
||||
|
||||
|
||||
@@ -123,7 +129,7 @@ def enumerate_k_grouped_sf_layout():
|
||||
def generate_normal(m: int, n: int, k: int,
|
||||
major_a: MajorTypeAB, major_b: MajorTypeAB,
|
||||
accumulate: bool, out_dtype: torch.dtype,
|
||||
use_ue8m0: bool):
|
||||
use_ue8m0: bool = False, use_bf16: bool = False):
|
||||
a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
|
||||
b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
|
||||
d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \
|
||||
@@ -131,6 +137,11 @@ def generate_normal(m: int, n: int, k: int,
|
||||
c = d if accumulate else None
|
||||
ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype)
|
||||
|
||||
if use_bf16:
|
||||
a = a if major_a.is_k_major() else a.T.contiguous().T
|
||||
b = b if major_b.is_k_major() else b.T.contiguous().T
|
||||
return a, b, c, d, ref_d
|
||||
|
||||
a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
|
||||
a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1])
|
||||
b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1])
|
||||
|
||||
@@ -51,10 +51,10 @@ def test_gemm() -> None:
|
||||
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
|
||||
|
||||
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}):'
|
||||
f' launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | '
|
||||
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): '
|
||||
f'launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | '
|
||||
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
|
||||
f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
|
||||
f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s')
|
||||
print()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user