From d9c363f86f20b38dd852191d6680b10059fb61d5 Mon Sep 17 00:00:00 2001 From: Ray Wang Date: Sat, 2 Aug 2025 19:52:22 -0700 Subject: [PATCH] Make various updates and fixes: - Add support for legacy CUDA versions; now compatible with CUDA 12.3 and newer - Add support for NVRTC compilation - Other fixes and code refactoring --- CMakeLists.txt | 4 +- README.md | 7 +- build.sh | 12 ++ csrc/jit/cache.hpp | 2 +- csrc/jit/compiler.hpp | 149 +++++++++++++++--- csrc/jit/device_runtime.hpp | 14 +- csrc/jit/handle.hpp | 135 ++++++++++++++++ csrc/jit/kernel_runtime.hpp | 62 +++----- csrc/jit_kernels/heuristics/common.hpp | 16 +- csrc/jit_kernels/heuristics/sm100.hpp | 9 +- .../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp | 24 +-- .../jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp | 16 +- csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp | 15 +- csrc/jit_kernels/impls/smxx_layout.hpp | 57 +------ csrc/python_api.cpp | 44 ++++-- csrc/utils/exception.hpp | 10 ++ csrc/utils/lazy_init.hpp | 27 ++++ csrc/utils/system.hpp | 4 +- deep_gemm/__init__.py | 9 +- .../include/deep_gemm/common/cute_tie.cuh | 48 ++++++ .../include/deep_gemm/common/scheduler.cuh | 23 ++- .../include/deep_gemm/common/sm100_utils.cuh | 2 +- .../include/deep_gemm/common/sm90_utils.cuh | 7 +- deep_gemm/include/deep_gemm/common/types.hpp | 1 + deep_gemm/include/deep_gemm/common/utils.cuh | 9 ++ .../deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh | 36 +++-- .../deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh | 26 +-- .../deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh | 28 ++-- .../include/deep_gemm/impls/smxx_layout.cuh | 2 - deep_gemm/include/deep_gemm/nvrtc_std.cuh | 103 ------------ deep_gemm/utils/math.py | 9 ++ develop.sh | 2 +- install.sh | 2 +- setup.py | 11 +- tests/generators.py | 23 ++- tests/{test_core.py => test_fp8.py} | 6 +- 36 files changed, 592 insertions(+), 362 deletions(-) create mode 100755 build.sh create mode 100644 csrc/jit/handle.hpp create mode 100644 csrc/utils/lazy_init.hpp create mode 100644 deep_gemm/include/deep_gemm/common/cute_tie.cuh delete mode 100644 deep_gemm/include/deep_gemm/nvrtc_std.cuh rename tests/{test_core.py => test_fp8.py} (97%) diff --git a/CMakeLists.txt b/CMakeLists.txt index ab20d62..6f12a96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -18,8 +18,8 @@ find_package(CUDAToolkit REQUIRED) find_package(pybind11 REQUIRED) find_package(Torch REQUIRED) -set(CMAKE_CXX_STANDARD 20) -set(CMAKE_CUDA_STANDARD 20) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include) include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) diff --git a/README.md b/README.md index 9e03fc6..491574c 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - [x] MoE scheduler with TMA multicast compatibility - [x] Fix TMA multicast compatibility for indivisible shapes - [x] Skip useless computation on M -- [ ] NVRTC as a faster compiler +- [x] NVRTC as a faster compiler - [ ] Sanitizer for testing - [x] Weight gradient kernels for dense models - [x] Weight gradient kernels for MoE models @@ -46,8 +46,7 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - Python 3.8 or higher - Compilers with C++20 support - CUDA Toolkit: - - Currently, CUDA 12.8 or higher is required, but support for older versions may be added in the future - - CUDA 12.8 or higher for SM90 + - CUDA 12.3 or higher for SM90 - **We highly recommend 12.9 or higher for the best performance** - CUDA 12.9 or higher for SM100 - PyTorch 2.1 or higher @@ -114,6 +113,8 @@ The library provides some utility functions besides the above kernels: - `deep_gemm.set_num_sms`: set the maximum SM count to use - `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set) +- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio +- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio - `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout - `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size - `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout diff --git a/build.sh b/build.sh new file mode 100755 index 0000000..abdfc40 --- /dev/null +++ b/build.sh @@ -0,0 +1,12 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build files, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel + +# Open users' original directory +cd "$original_dir" diff --git a/csrc/jit/cache.hpp b/csrc/jit/cache.hpp index fde9aab..1e8659f 100644 --- a/csrc/jit/cache.hpp +++ b/csrc/jit/cache.hpp @@ -9,7 +9,7 @@ namespace deep_gemm { class KernelRuntimeCache { - std::unordered_map> cache; + std::unordered_map> cache; public: // TODO: consider cache capacity diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index ea77a16..0e84b48 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -1,14 +1,17 @@ #pragma once #include +#include #include #include +#include #include #include #include "../utils/exception.hpp" #include "../utils/format.hpp" #include "../utils/hash.hpp" +#include "../utils/lazy_init.hpp" #include "../utils/system.hpp" #include "cache.hpp" #include "device_runtime.hpp" @@ -16,10 +19,13 @@ namespace deep_gemm { class Compiler { - std::string library_version; - std::filesystem::path library_root_path; +public: + static std::filesystem::path library_root_path; + static std::filesystem::path library_include_path; + static std::filesystem::path cuda_home; + static std::string library_version; - std::string get_library_version() const { + static std::string get_library_version() { std::stringstream ss; for (const auto& f: collect_files(library_include_path / "deep_gemm")) { std::ifstream in(f, std::ios::binary); @@ -28,16 +34,23 @@ class Compiler { return get_hex_digest(ss.str()); } -public: + static void prepare_init(const std::string& library_root_path, + const std::string& cuda_home_path_by_torch) { + Compiler::library_root_path = library_root_path; + Compiler::library_include_path = Compiler::library_root_path / "include"; + Compiler::cuda_home = cuda_home_path_by_torch; + Compiler::library_version = get_library_version(); + } + std::string signature, flags; - std::filesystem::path library_include_path; std::filesystem::path cache_dir_path; - explicit Compiler(const std::filesystem::path& library_root_path) { - // Static library paths - this->library_root_path = library_root_path; - this->library_include_path = library_root_path / "include"; - this->library_version = get_library_version(); + Compiler() { + // Check `prepare_init` + DG_HOST_ASSERT(not library_root_path.empty()); + DG_HOST_ASSERT(not library_include_path.empty()); + DG_HOST_ASSERT(not cuda_home.empty()); + DG_HOST_ASSERT(not library_version.empty()); // Cache settings cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; @@ -46,10 +59,11 @@ public: // The compiler flags applied to all derived compilers signature = "unknown-compiler"; - std::string ptxas_flags = "--ptxas-options=--register-usage-level=10"; - if (get_env("DG_JIT_PTXAS_VERBOSE", 0)) - ptxas_flags += ",--verbose"; - flags = fmt::format("-std=c++20 --diag-suppress=39,161,174,177,186,940 {}", ptxas_flags); + flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 " + "--ptxas-options=--register-usage-level=10", + get_env("DG_JIT_CPP_STANDARD", 20)); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) + flags += " --ptxas-options=--verbose"; } virtual ~Compiler() = default; @@ -102,6 +116,11 @@ public: virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0; }; +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version); + class NVCCCompiler final: public Compiler { std::filesystem::path nvcc_path; @@ -125,11 +144,9 @@ class NVCCCompiler final: public Compiler { } public: - NVCCCompiler(const std::filesystem::path& library_root_path, - const std::filesystem::path& cuda_home_path_by_torch): - Compiler(library_root_path) { + NVCCCompiler() { // Override the compiler signature - nvcc_path = cuda_home_path_by_torch / "bin" / "nvcc"; + nvcc_path = cuda_home / "bin" / "nvcc"; if (const auto& env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) nvcc_path = env_nvcc_path; const auto& [nvcc_major, nvcc_minor] = get_nvcc_version(); @@ -150,10 +167,10 @@ public: // Compile const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) - printf("Running NVCC command: %s", command.c_str()); + printf("Running NVCC command: %s\n", command.c_str()); const auto& [return_code, output] = call_external_command(command); if (return_code != 0) { - printf("NVCC compilation failed: %s", output.c_str()); + printf("NVCC compilation failed: %s\n", output.c_str()); DG_HOST_ASSERT(false and "NVCC compilation failed"); } @@ -163,6 +180,96 @@ public: } }; -static std::shared_ptr compiler = nullptr; +class NVRTCCompiler final: public Compiler { +public: + NVRTCCompiler() { + // Override the compiler signature + int major, minor; + DG_NVRTC_CHECK(nvrtcVersion(&major, &minor)); + signature = fmt::format("NVRTC{}.{}", major, minor); + + // Build include directories list + std::string include_dirs; + include_dirs += fmt::format("-I{} ", library_include_path.string()); + include_dirs += fmt::format("-I{} ", (cuda_home / "include").string()); + + // Add PCH support for version 12.8 and above + // NOTES: PCH is vital for compilation speed + std::string pch_flags; + if (major > 12 or (major == 12 and minor >= 8)) { + pch_flags = "--pch "; + if (get_env("DG_JIT_DEBUG", 0)) + pch_flags += "--pch-verbose=true "; + } + + // Override the compiler flags + flags = fmt::format("{} {}--gpu-architecture=sm_{}a -default-device {}", + flags, include_dirs, device_runtime->get_arch(), pch_flags); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + // Write the code into the cache directory + const auto& code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Parse compilation options + std::istringstream iss(flags); + std::vector options; + std::string option; + while (iss >> option) + options.push_back(option); + + // Convert to C-style string array for NVRTC + std::vector option_cstrs; + for (const auto& opt: options) + option_cstrs.push_back(opt.c_str()); + + // Print compiler command if requested + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) { + printf("Compiling JIT runtime with NVRTC options: "); + for (const auto& opt: options) + printf("%s ", opt.c_str()); + printf("\n"); + } + + // Create NVRTC program and compile + nvrtcProgram program; + DG_NVRTC_CHECK(nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr)); + const auto& compile_result = nvrtcCompileProgram(program, static_cast(option_cstrs.size()), option_cstrs.data()); + + // Get and print compiler log + size_t log_size; + DG_NVRTC_CHECK(nvrtcGetProgramLogSize(program, &log_size)); + if (get_env("DG_JIT_DEBUG", 0) or compile_result != NVRTC_SUCCESS) { + if (compile_result != NVRTC_SUCCESS) + DG_HOST_ASSERT(log_size > 1); + if (log_size > 1) { + std::string compilation_log(log_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetProgramLog(program, compilation_log.data())); + printf("NVRTC log: %s\n", compilation_log.c_str()); + } + } + + // Get CUBIN size and data + size_t cubin_size; + DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size)); + std::string cubin_data(cubin_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetCUBIN(program, cubin_data.data())); + + // Write into the file system + put(cubin_path, cubin_data); + + // Cleanup + DG_NVRTC_CHECK(nvrtcDestroyProgram(&program)); + } +}; + +static auto compiler = LazyInit([]() -> std::shared_ptr { + if (get_env("DG_JIT_USE_NVRTC", 0)) { + return std::make_shared(); + } else { + return std::make_shared(); + } +}); } // namespace deep_gemm diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index c3237da..7cd1882 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -3,11 +3,12 @@ #include #include "../utils/exception.hpp" +#include "../utils/lazy_init.hpp" namespace deep_gemm { class DeviceRuntime { - int num_sms = 0; + int num_sms = 0, tc_util = 0; std::shared_ptr cached_prop; public: @@ -43,8 +44,17 @@ public: num_sms = get_prop()->multiProcessorCount; return num_sms; } + + void set_tc_util(const int& new_tc_util) { + DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100); + tc_util = new_tc_util; + } + + int get_tc_util() const { + return tc_util == 0 ? 100 : tc_util; + } }; -static auto device_runtime = std::make_shared(); +static auto device_runtime = LazyInit([](){ return std::make_shared(); }); } // namespace deep_gemm diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp new file mode 100644 index 0000000..754b299 --- /dev/null +++ b/csrc/jit/handle.hpp @@ -0,0 +1,135 @@ +#pragma once + +#include +#include +#include + +#include "../utils/exception.hpp" + +namespace deep_gemm { + +#if CUDART_VERSION >= 12080 or defined(DG_JIT_USE_DRIVER_API) + +// Use CUDA runtime API +using LibraryHandle = cudaLibrary_t; +using KernelHandle = cudaKernel_t; +using LaunchConfigHandle = cudaLaunchConfig_t; +using LaunchAttrHandle = cudaLaunchAttribute; + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_RUNTIME_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel{}; + DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, func_name.c_str())); + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { + const auto& error = cudaLibraryUnload(library); + DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + if (smem_size > 0) + DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + LaunchConfigHandle config; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + config.numAttrs = 0; + config.attrs = nullptr; + + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attr; + if (cluster_dim > 1) { + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; + config.attrs = &attr; + config.numAttrs = 1; + } + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return cudaLaunchKernelExC(&config, kernel, ptr_args); +} + +#else + +// Use CUDA driver API +using LibraryHandle = CUmodule; +using KernelHandle = CUfunction; +using LaunchConfigHandle = CUlaunchConfig; +using LaunchAttrHandle = CUlaunchAttribute; + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel; + DG_CUDA_DRIVER_CHECK(cuModuleLoad(&library, cubin_path.c_str())); + DG_CUDA_DRIVER_CHECK(cuModuleGetFunction(&kernel, library, func_name.c_str())); + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { + const auto& error = cuModuleUnload(library); + DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) { + if (smem_size > 0) + DG_CUDA_DRIVER_CHECK(cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); + + LaunchConfigHandle config; + config.gridDimX = grid_dim.x; + config.gridDimY = grid_dim.y; + config.gridDimZ = grid_dim.z; + config.blockDimX = block_dim.x; + config.blockDimY = block_dim.y; + config.blockDimZ = block_dim.z; + config.sharedMemBytes = smem_size; + config.hStream = stream; + config.numAttrs = 0; + config.attrs = nullptr; + + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attr; + if (cluster_dim > 1) { + attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attr.value.clusterDim.x = cluster_dim; + attr.value.clusterDim.y = 1; + attr.value.clusterDim.z = 1; + config.attrs = &attr; + config.numAttrs = 1; + } + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return cuLaunchKernelEx(&config, kernel, ptr_args, nullptr); +} + +#endif + +} // namespace deep_gemm diff --git a/csrc/jit/kernel_runtime.hpp b/csrc/jit/kernel_runtime.hpp index ac95f99..5a6022c 100644 --- a/csrc/jit/kernel_runtime.hpp +++ b/csrc/jit/kernel_runtime.hpp @@ -1,12 +1,10 @@ #pragma once -#include -#include - #include "../utils/exception.hpp" #include "../utils/format.hpp" #include "../utils/system.hpp" #include "device_runtime.hpp" +#include "handle.hpp" namespace deep_gemm { @@ -23,19 +21,17 @@ struct LaunchArgs { grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {} }; -template -concept HasLaunchArgs = requires (const T& t) { - { t.launch_args } -> std::convertible_to; -}; - class KernelRuntime final { public: static std::filesystem::path cuda_home; - cudaLibrary_t library; - cudaKernel_t kernel; + LibraryHandle library; + KernelHandle kernel; explicit KernelRuntime(const std::filesystem::path& dir_path) { + // Check `prepare_init` + DG_HOST_ASSERT(not cuda_home.empty()); + // NOLINT(*-pro-type-member-init) const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump"; const auto& cubin_path = dir_path / "kernel.cubin"; @@ -50,7 +46,8 @@ public: std::istringstream iss(symbols); std::vector symbol_names; for (std::string line; std::getline(iss, line); ) { - if (line.find("STT_FUNC") == 0 and std::ranges::none_of(illegal_names, [&](const auto& name) { return line.find(name) != std::string::npos; })) { + if (line.find("STT_FUNC") == 0 and std::none_of(illegal_names.begin(), illegal_names.end(), + [&](const auto& name) { return line.find(name) != std::string::npos; })) { const auto& last_space = line.rfind(' '); symbol_names.push_back(line.substr(last_space + 1)); } @@ -64,11 +61,10 @@ public: // Load from the library DG_HOST_ASSERT(symbol_names.size() == 1); - DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); - DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, symbol_names[0].c_str())); + kernel = load_kernel(cubin_path, symbol_names[0], &library); } - static void set_cuda_home(const std::string& cuda_home_path_by_torch) { + static void prepare_init(const std::string& cuda_home_path_by_torch) { cuda_home = cuda_home_path_by_torch; } @@ -78,18 +74,16 @@ public: } ~KernelRuntime() noexcept(false) { - const auto& error = cudaLibraryUnload(library); - DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading); + unload_library(library); } }; -// Declare after defining -decltype(KernelRuntime::cuda_home) KernelRuntime::cuda_home; +DG_DECLARE_STATIC_VAR_IN_CLASS(KernelRuntime, cuda_home); template class LaunchRuntime { public: - template requires HasLaunchArgs + template static std::string generate(const Args& args) { const auto& code = Derived::generate_impl(args); if (get_env("DG_JIT_DEBUG", 0)) @@ -97,34 +91,18 @@ public: return code; } - template requires HasLaunchArgs + template static void launch(const std::shared_ptr& kernel_runtime, const Args& args) { const auto& kernel = kernel_runtime->kernel; const auto& stream = at::cuda::getCurrentCUDAStream(); const LaunchArgs& launch_args = args.launch_args; - // Set dynamic shared memory size - if (launch_args.smem_size > 0) - DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, launch_args.smem_size)); - - // Launch config - cudaLaunchConfig_t config; - config.gridDim = {static_cast(launch_args.grid_dim.first), - static_cast(launch_args.grid_dim.second), - 1}; - config.blockDim = {static_cast(launch_args.num_threads), 1, 1}; - config.dynamicSmemBytes = launch_args.smem_size; - config.stream = stream; - config.numAttrs = 0; - - // Clusters - cudaLaunchAttribute attr; - if (launch_args.cluster_dim > 1) { - attr.id = cudaLaunchAttributeClusterDimension; - attr.val.clusterDim = {static_cast(launch_args.cluster_dim), 1, 1}; - config.attrs = &attr; - config.numAttrs = 1; - } + const dim3& grid_dim = {static_cast(launch_args.grid_dim.first), + static_cast(launch_args.grid_dim.second), + 1}; + const dim3& block_dim = {static_cast(launch_args.num_threads), 1, 1}; + auto config = construct_launch_config(kernel, stream, launch_args.smem_size, + grid_dim, block_dim, launch_args.cluster_dim); // Launch in the derived class if (get_env("DG_JIT_DEBUG")) { diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index b5a8b61..7b8318d 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -62,8 +62,9 @@ struct GemmConfig { int block_m, block_n, block_k; int num_stages, num_last_stages; - // Runtime configs + // Templated device configs int num_sms; + int tc_util; // Structured configs MulticastConfig multicast_config; @@ -265,30 +266,35 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k .num_stages = best_num_stages, .num_last_stages = ceil_div(k, block_k) % best_num_stages, .num_sms = num_min_sms, + .tc_util = device_runtime->get_tc_util(), .multicast_config = best_multicast_config, // ReSharper disable once CppLocalVariableMightNotBeInitialized .smem_config = best_smem_config, .thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n) }; + // Only SM100 BF16 kernels support tensor core control + if (config.tc_util < 100) + DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and ab_dtype == torch::kBFloat16); + // Print configs for the first time if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b, ab_dtype, cd_dtype, with_accumulation, num_sms); static std::set printed; - if (not printed.contains(key)) { - printf("Gemm type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " + if (printed.count(key) == 0) { + printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " "A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, " "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " - "swizzle B: %d, swizzle CD: %d, threads: %d\n", + "swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n", static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, static_cast(major_a), static_cast(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype), static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, static_cast(best_multicast_config.is_multicast_on_a), best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode, - best_smem_config.swizzle_cd_mode, config.thread_config.num_threads); + best_smem_config.swizzle_cd_mode, config.num_sms, config.thread_config.num_threads, config.tc_util); printed.insert(key); } } diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index 722c3d1..e26b69f 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -44,6 +44,11 @@ struct SM100ArchSpec { const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, const int& block_m, const int& block_n) { + // TODO: consider more carefully for BF16 GEMMs + // 2SM BF16 UMMA does not support `N % 32 != 0` + if (ab_dtype == torch::kBFloat16 and block_n % 32 != 0) + return false; + // Layout A/D does not support `block_m == 64` and `block_n % 16 != 0` if (block_m == 64 or block_n % 16 != 0) return false; @@ -68,7 +73,7 @@ struct SM100ArchSpec { // NOTES: when B is MN-major, we restrict `block_n` to multiples of 64, // since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA - return major_b == cute::UMMA::Major::K or block_n % 64 == 0; + return major_b == cute::UMMA::Major::K or (block_n * c10::elementSize(ab_dtype)) % 64 == 0; } static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, @@ -93,7 +98,7 @@ struct SM100ArchSpec { static ThreadConfig get_thread_config(const KernelType& kernel_type, const int& block_m, const int& block_n) { - return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D1D ? 128 : block_m); + return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D2D ? block_m : 128); } static int get_smem_cd_size(const KernelType& kernel_type, diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index fe8887e..d4e573b 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -32,13 +32,6 @@ public: static std::string generate_impl(const Args& args) { return fmt::format(R"( -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - #include using namespace deep_gemm; @@ -54,7 +47,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, - {}, {} + {}, {}, {} >); }}; )", @@ -66,14 +59,13 @@ static void __instantiate_kernel() {{ args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - to_string(args.gemm_config.gemm_type), - args.gemm_config.with_accumulation, - to_string(args.gemm_config.cd_dtype)); + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype)); } - static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy - DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.grouped_layout, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, args.tensor_map_sfa, args.tensor_map_sfb, @@ -286,7 +278,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& const auto& num_groups = static_cast(ks.size()); // Get config using max K for better performance - const auto& max_k = *std::ranges::max_element(ks); + const auto& max_k = *std::max_element(ks.begin(), ks.end()); const auto& config = get_best_config( GemmType::KGroupedContiguous, KernelType::Kernel1D1D, m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, @@ -316,9 +308,9 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& static_cast(cd.stride(1)), num_groups, config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512, - config.block_m, config.block_k, num_groups, 0); + config.block_m, config.block_k, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512, - config.block_n, config.block_k, num_groups, 0); + config.block_n, config.block_k, 1, 0); // Duplicate the accumulator if necessary if (c.has_value()) { diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp index 02478a0..c33a450 100644 --- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp @@ -30,13 +30,6 @@ public: static std::string generate_impl(const Args& args) { return fmt::format(R"( -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - #include using namespace deep_gemm; @@ -51,6 +44,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, + {}, {}, {} >); }}; @@ -63,13 +57,13 @@ static void __instantiate_kernel() {{ args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - to_string(args.gemm_config.gemm_type), - to_string(args.gemm_config.cd_dtype)); + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), to_string(args.gemm_config.cd_dtype)); } - static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy - DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sfb, args.grouped_layout, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index 2909ef3..088bf1a 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -29,13 +29,6 @@ public: static std::string generate_impl(const Args& args) { return fmt::format(R"( -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - #include using namespace deep_gemm; @@ -49,7 +42,7 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, {}, - {} + {}, {} >); }}; )", @@ -61,12 +54,12 @@ static void __instantiate_kernel() {{ args.gemm_config.num_stages, args.gemm_config.num_last_stages, args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, - to_string(args.gemm_config.gemm_type)); + args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type)); } - static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { // TODO: optimize `args` copy - DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sfb, args.grouped_layout, args.m, args.n, args.k, args.tensor_map_a, args.tensor_map_b, diff --git a/csrc/jit_kernels/impls/smxx_layout.hpp b/csrc/jit_kernels/impls/smxx_layout.hpp index eda8c1b..49021d3 100644 --- a/csrc/jit_kernels/impls/smxx_layout.hpp +++ b/csrc/jit_kernels/impls/smxx_layout.hpp @@ -22,13 +22,6 @@ public: static std::string generate_impl(const Args& args) { return fmt::format(R"( -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - #include using namespace deep_gemm; @@ -41,8 +34,8 @@ static void __instantiate_kernel() {{ )", args.launch_args.num_threads, args.block_mn, args.sf_k); } - static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { - DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast(args.mn))); + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); } }; @@ -58,13 +51,6 @@ public: static std::string generate_impl(const Args& args) { return fmt::format(R"( -#ifdef __CUDACC_RTC__ -#include -#else -#include -#include -#endif - #include using namespace deep_gemm; @@ -77,8 +63,8 @@ static void __instantiate_kernel() {{ )", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k); } - static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) { - DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k)); } }; @@ -108,43 +94,16 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { return (dim == 2) ? aligned_sf.squeeze(0) : aligned_sf; } -static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) { - const auto& sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf; - - // First, convert into UE8M0 `uint8_t` - const auto& ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8); - - // Second, make padded packed tensors - const auto& [num_groups, mn, k] = get_shape<3>(sf_reshaped); - const auto& aligned_mn = get_tma_aligned_size(mn, 4); - const auto& aligned_k = align(k, 4); - - const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8); - auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options); - // ReSharper disable once CppExpressionWithoutSideEffects - padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor); - padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4}); - - // Finally, transpose - auto out = torch::empty_strided({num_groups, aligned_mn, aligned_k / 4}, - {aligned_mn * (aligned_k / 4), 1, aligned_mn}, - at::TensorOptions().device(sf.device()).dtype(torch::kInt32)); - out = out.copy_(padded).slice(1, 0, mn); - return (sf.dim() == 2) ? out.squeeze(0) : out; -} - static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) { const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); const auto& packed_sf_k = ceil_div(sf_k, 4); const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k}, {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn}, at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); + DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0); + // Launch the kernel if (batched_sf.is_contiguous()) { - // Fallback to slow PyTorch impl for non-supported cases - if ((mn * sf_k) % 4 != 0 and num_groups > 1) - return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); - constexpr int block_mn = 48; constexpr int num_threads = 512; const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = { @@ -160,10 +119,6 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); } else { - // Fallback to slow PyTorch impl for non-supported cases - if (mn % 4 != 0 or num_groups > 1) - return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); - DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1); DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn); diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp index e1e916f..134f272 100644 --- a/csrc/python_api.cpp +++ b/csrc/python_api.cpp @@ -5,10 +5,10 @@ #include "jit/device_runtime.hpp" #include "utils/layout.hpp" -#include "jit_kernels/impls/smxx_layout.hpp" #include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" #include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp" #include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp" +#include "jit_kernels/impls/smxx_layout.hpp" #ifndef TORCH_EXTENSION_NAME #define TORCH_EXTENSION_NAME deep_gemm_cpp @@ -17,8 +17,8 @@ namespace deep_gemm { torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, const int& mn, const int& k, - const std::optional& num_groups, const std::tuple& recipe, + const std::optional& num_groups, const bool& is_sfa, const bool& disable_ue8m0_cast) { const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe); @@ -121,8 +121,8 @@ void fp8_gemm_nt(const std::pair& a, // Transform SFA and SFB into compute-required layout if (not recipe.has_value()) recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast); - const auto& sfb = transform_sf_into_required_layout(b.second, n, k, std::nullopt, recipe.value(), false, disable_ue8m0_cast); + const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); + const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast); // Dispatch into different implements const auto& arch_major = device_runtime->get_arch_major(); @@ -133,7 +133,7 @@ void fp8_gemm_nt(const std::pair& a, } else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) { sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); } else { - DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types"); + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } } @@ -208,8 +208,8 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pairget_arch_major(); @@ -223,7 +223,7 @@ void m_grouped_fp8_gemm_nt_contiguous(const std::pair& // Transform scaling factors if (not recipe.has_value()) recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - const auto& sfa = transform_sf_into_required_layout(a.second, m, k, num_groups, recipe.value(), true, disable_ue8m0_cast); - const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast); + const auto& sfa = transform_sf_into_required_layout(a.second, m, k, recipe.value(), num_groups, true, disable_ue8m0_cast); + const auto& sfb = transform_sf_into_required_layout(b.second, n, k, recipe.value(), num_groups, false, disable_ue8m0_cast); // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); @@ -286,7 +286,7 @@ void fp8_m_grouped_gemm_nt_masked(const std::pair& sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); } else { - DG_HOST_UNREACHABLE("Unsupported kernel or scaling factor types"); + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } } @@ -339,18 +339,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepGEMM C++ library"; // Runtime + m.def("set_num_sms", [&](const int& new_num_sms) { + device_runtime->set_num_sms(new_num_sms); + }); m.def("get_num_sms", [&]() { return device_runtime->get_num_sms(); }); - m.def("set_num_sms", [&](const int& new_num_sms) { - device_runtime->set_num_sms(new_num_sms); + m.def("set_tc_util", [&](const int& new_tc_util) { + device_runtime->set_tc_util(new_tc_util); + }); + m.def("get_tc_util", [&]() { + return device_runtime->get_tc_util(); }); // JIT m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) { - DG_HOST_ASSERT(get_env("DG_JIT_USE_NVRTC", 0) == 0 and "Currently only support NVCC"); - compiler = std::make_shared(library_root_path, cuda_home_path_by_torch); - KernelRuntime::set_cuda_home(cuda_home_path_by_torch); + Compiler::prepare_init(library_root_path, cuda_home_path_by_torch); + KernelRuntime::prepare_init(cuda_home_path_by_torch); }); // Stable kernel APIs with automatic arch/layout dispatch @@ -391,7 +396,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("ks_tensor"), py::arg("c") = std::nullopt, py::arg("recipe") = std::make_tuple(1, 1, 128), py::arg("compiled_dims") = "mn"); - m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout); + + // Layout kernels + m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, + py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"), + py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false, + py::arg("disable_ue8m0_cast") = false); // Raw kernels or functions m.def("get_tma_aligned_size", &get_tma_aligned_size); diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp index 493e480..10dedc0 100644 --- a/csrc/utils/exception.hpp +++ b/csrc/utils/exception.hpp @@ -35,6 +35,16 @@ do { \ #define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason)) #endif +#ifndef DG_NVRTC_CHECK +#define DG_NVRTC_CHECK(cmd) \ +do { \ + const auto& e = (cmd); \ + if (e != NVRTC_SUCCESS) { \ + throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \ + } \ +} while (0) +#endif + #ifndef DG_CUDA_DRIVER_CHECK #define DG_CUDA_DRIVER_CHECK(cmd) \ do { \ diff --git a/csrc/utils/lazy_init.hpp b/csrc/utils/lazy_init.hpp new file mode 100644 index 0000000..386b1b4 --- /dev/null +++ b/csrc/utils/lazy_init.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +#define DG_DECLARE_STATIC_VAR_IN_CLASS(cls, name) decltype(cls::name) cls::name + +namespace deep_gemm { + +template +class LazyInit { +public: + explicit LazyInit(std::function()> factory) + : factory(std::move(factory)) {} + + T* operator -> () { + if (ptr == nullptr) + ptr = factory(); + return ptr.get(); + } + +private: + std::shared_ptr ptr; + std::function()> factory; +}; + +} // namespace deep_gemm diff --git a/csrc/utils/system.hpp b/csrc/utils/system.hpp index 4835640..91dee12 100644 --- a/csrc/utils/system.hpp +++ b/csrc/utils/system.hpp @@ -38,8 +38,8 @@ static std::tuple call_external_command(std::string command) { std::string output; while (fgets(buffer.data(), buffer.size(), pipe.get())) output += buffer.data(); - const auto exit_code = pclose(pipe.release()); - return {WEXITSTATUS(exit_code), output}; + const auto& exit_code = WEXITSTATUS(pclose(pipe.release())); + return {exit_code, output}; } static std::vector collect_files(const std::filesystem::path& root) { diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 17e7a33..e546a30 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -22,17 +22,22 @@ deep_gemm_cpp.init( # Configs from deep_gemm_cpp import ( set_num_sms, - get_num_sms + get_num_sms, + set_tc_util, + get_tc_util, ) # Kernels from deep_gemm_cpp import ( + # FP8 GEMMs fp8_gemm_nt, fp8_gemm_nn, fp8_gemm_tn, fp8_gemm_tt, m_grouped_fp8_gemm_nt_contiguous, m_grouped_fp8_gemm_nn_contiguous, fp8_m_grouped_gemm_nt_masked, - k_grouped_fp8_gemm_tn_contiguous + k_grouped_fp8_gemm_tn_contiguous, + # Layout kernels + transform_sf_into_required_layout ) # Some utils diff --git a/deep_gemm/include/deep_gemm/common/cute_tie.cuh b/deep_gemm/include/deep_gemm/common/cute_tie.cuh new file mode 100644 index 0000000..cd2aace --- /dev/null +++ b/deep_gemm/include/deep_gemm/common/cute_tie.cuh @@ -0,0 +1,48 @@ +#pragma once + +namespace cute { + +struct ignore_t { + template + constexpr const ignore_t& operator=(T&&) const noexcept { + return *this; + } +}; + +inline constexpr ignore_t ignore{}; + +} // namespace cute + +#define CUTE_TIE_CONCAT_IMPL(A, B) A##B +#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B) + +#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define CUTE_TIE_COUNT_ARGS(...) \ + CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get(TUPLE) +#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get(TUPLE) + +#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1); +#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2); +#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); +#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); +#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5); + +#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_DECL, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ) + +#define CUTE_TIE(TUPLE_EXPR, ...) \ + do { \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_ASSIGN, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ); \ + } while (0) diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index d114381..bada914 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -11,14 +11,28 @@ enum class KGroupedIndexType { SF_K, }; +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto& candidate: {8u, 16u}) { + const auto& usage = isMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + #pragma clang diagnostic push #pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" template + uint32_t kNumSMs, + uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group()> struct Scheduler { int current_iter = -1; @@ -88,6 +102,7 @@ struct Scheduler { #endif // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N if constexpr (kIsMulticastOnA) { m_block_idx = in_group_idx / num_blocks_in_group; n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; @@ -103,7 +118,7 @@ struct Scheduler { if constexpr (kGemmType == GemmType::Normal) { return block_idx * block_size; } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { - const auto offset = kWithGroupOffset ? std::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; return offset * shape_dim + block_idx * block_size; } else if constexpr (kGemmType == GemmType::MGroupedMasked) { const auto offset = kWithGroupOffset ? current_group_idx : 0; @@ -123,7 +138,7 @@ struct Scheduler { } __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { - const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x; + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; if constexpr (kGemmType == GemmType::MGroupedMasked) { while (true) { diff --git a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh index 2016a79..b208302 100644 --- a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -101,7 +101,7 @@ constexpr uint32_t get_umma_desc_stride_k() { template __device__ __forceinline__ uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { - return base + ((offset + k_idx * get_umma_desc_stride_k()) >> 4u); + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); } template diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index e016063..879abda 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -10,13 +9,13 @@ template struct FP8MMA { template - __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence) { + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { using namespace cute::SM90::GMMA; MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); } __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { - call_fma_impl(desc_a, desc_b, d, scale_d, std::make_index_sequence{}); + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); } static constexpr int M = 64; @@ -139,7 +138,7 @@ __device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout __device__ __forceinline__ void tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr, - const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast) { + const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast = 1) { constexpr auto cache_hint = static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL); if (num_tma_multicast == 1) { cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1); diff --git a/deep_gemm/include/deep_gemm/common/types.hpp b/deep_gemm/include/deep_gemm/common/types.hpp index 7e87953..23e7342 100644 --- a/deep_gemm/include/deep_gemm/common/types.hpp +++ b/deep_gemm/include/deep_gemm/common/types.hpp @@ -12,6 +12,7 @@ enum class GemmType { enum class KernelType { Kernel1D1D = 0, Kernel1D2D = 1, + KernelNoSF = 2 }; } // namespace deep_gemm diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index a4ab6a3..7851327 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -2,6 +2,11 @@ #include #include +#include +#include +#include + +#include "cute_tie.cuh" #ifdef __CLION_IDE__ @@ -135,4 +140,8 @@ __device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { return *reinterpret_cast(&bf16x2); } +__device__ __forceinline__ void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + } // namespace `deep_gemm` diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index 360719a..85c01ab 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -20,22 +20,23 @@ template __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d1d_impl(int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_sfa, - const __grid_constant__ CUtensorMap tensor_map_sfb, - const __grid_constant__ CUtensorMap tensor_map_c, - const __grid_constant__ CUtensorMap tensor_map_d) { + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_c, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; // GEMM with accumulation must have FP32 output if constexpr (kWithAccumulation) - DG_STATIC_ASSERT(std::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); // Configs constexpr uint32_t LAYOUT_AD_M = 128; @@ -63,7 +64,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // 2-CTA MMA constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t STORE_BLOCK_M = std::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); @@ -95,6 +96,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Prefetch TMA descriptors at the very beginning if (threadIdx.x == 0) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_sfa); @@ -173,7 +175,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); // For pipeline unrolling struct DivisibleK {}; @@ -207,7 +209,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { - constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; #pragma unroll @@ -329,7 +331,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Launch MMAs launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) { - constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; #pragma unroll @@ -342,7 +344,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad; if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { - using cute_utccp_t = std::conditional_t; // SFA and SFB copy @@ -363,7 +365,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, __syncwarp(); // Issue UMMA in the leader CTA - using cute_mma_t = std::conditional_t, cute::SM100_MMA_MXF8F6F4_2x1SM_SS; + constexpr bool kHasDivisibleStages = cute::is_same_v; const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages; #pragma unroll @@ -530,7 +532,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Load from tensor memory, store into shared memory uint32_t values[kNumElemsPerBankGroup]; - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { // For FP32 output, read and store DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, @@ -539,7 +541,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, st_shared(smem_ptr, values[0], values[1], values[2], values[3]); } else { // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v, "Invalid type"); + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3], values[4], values[5], values[6], values[7]); @@ -564,7 +566,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cute::tma_store_fence(); cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync(); if (epilogue_thread_idx == 0) { - using cute_tma_t = std::conditional_t; cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); cute::tma_store_arrive(); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh index dcfeed9..a78a7b1 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh @@ -21,14 +21,15 @@ template __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_d, - const __grid_constant__ CUtensorMap tensor_map_sfa) { + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) using Barrier = cutlass::arch::ClusterTransactionBarrier; @@ -61,7 +62,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // 2-CTA MMA constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); - constexpr uint32_t STORE_BLOCK_M = std::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D"); @@ -87,6 +88,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Prefetch TMA descriptors at the very beginning if (threadIdx.x == 0) { + // NOTES: `reinterpret_cast` must be here, or NVRTC will fail cute::prefetch_tma_descriptor(&tensor_map_a); cute::prefetch_tma_descriptor(&tensor_map_b); cute::prefetch_tma_descriptor(&tensor_map_d); @@ -171,7 +173,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); // Register configurations constexpr uint32_t kNumNonEpilogueRegisters = 64; @@ -187,7 +189,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { launch_k_iterations([&](uint32_t k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); @@ -276,7 +278,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, while (scheduler.get_next_block(m_block_idx, n_block_idx)) { // Launch MMAs launch_k_iterations([&](uint32_t k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); @@ -293,7 +295,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Issue UMMA in the leader CTA if (s < kNumInnerStages) { - using cute_mma_t = std::conditional_t; tcgen05_after_thread_sync(); #pragma unroll @@ -366,7 +368,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Launch promotion float accum[BLOCK_N] = {0}; launch_k_iterations([&](uint32_t k_iter, auto type) { - constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages"); @@ -474,7 +476,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Load from tensor memory, store into shared memory // NOTES: if you want to do accumulation, please notice that you need two accumulation barriers const auto offset = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; - if constexpr (std::is_same_v) { + if constexpr (cute::is_same_v) { // For FP32 output, read and store DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); st_shared(smem_ptr, @@ -484,7 +486,7 @@ sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, *reinterpret_cast(&accum[offset + 3])); } else { // For BF16 output, read, cast and store - DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v, "Invalid type"); + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); st_shared(smem_ptr, cast_into_bf16_and_pack(accum[offset + 0], accum[offset + 1]), cast_into_bf16_and_pack(accum[offset + 2], accum[offset + 3]), diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 6fff025..c78f72d 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -36,14 +36,14 @@ template -__global__ void __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) + uint32_t kNumSMs, GemmType kGemmType> +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, - const __grid_constant__ CUtensorMap tensor_map_a, - const __grid_constant__ CUtensorMap tensor_map_b, - const __grid_constant__ CUtensorMap tensor_map_d, - const __grid_constant__ CUtensorMap tensor_map_sfa) { + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) // Scaling checks DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); @@ -77,10 +77,10 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Prefetch TMA descriptors at the very beginning if (threadIdx.x == kNumMathThreads) { // NOTES: `reinterpret_cast` must be here, or NVRTC will fail - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_a)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_b)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_sfa)); - cute::prefetch_tma_descriptor(reinterpret_cast(&tensor_map_d)); + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); } __syncwarp(); @@ -168,7 +168,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Block scheduler uint32_t m_block_idx, n_block_idx; - auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); + auto scheduler = Scheduler(shape_m, shape_n, grouped_layout); if (threadIdx.x >= kNumMathThreads) { // TMA warp-group for loading data @@ -179,7 +179,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Persistently schedule over blocks while (scheduler.get_next_block(m_block_idx, n_block_idx)) { launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) { - constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages; // Assign TMA multicast number into A and B @@ -278,8 +278,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Launch MMAs launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) { - constexpr bool kSkipComputation = std::is_same_v; - constexpr bool kHasDivisibleStages = std::is_same_v; + constexpr bool kSkipComputation = cute::is_same_v; + constexpr bool kHasDivisibleStages = cute::is_same_v; constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages); #pragma unroll diff --git a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh index 5b979a8..7385f91 100644 --- a/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh +++ b/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -1,7 +1,5 @@ #pragma once -#include - #include namespace deep_gemm { diff --git a/deep_gemm/include/deep_gemm/nvrtc_std.cuh b/deep_gemm/include/deep_gemm/nvrtc_std.cuh deleted file mode 100644 index 00ce734..0000000 --- a/deep_gemm/include/deep_gemm/nvrtc_std.cuh +++ /dev/null @@ -1,103 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. - * All rights reserved. SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef __CUDACC_RTC__ - -using int8_t = signed char; -using uint8_t = unsigned char; -using int16_t = signed short; -using uint16_t = unsigned short; -using int32_t = signed int; -using uint32_t = unsigned int; -using int64_t = signed long long; -using uint64_t = unsigned long long; -using cuuint64_t = unsigned long long; - -#ifndef CU_TENSOR_MAP_NUM_QWORDS -#define CU_TENSOR_MAP_NUM_QWORDS 16 - -struct CUtensorMap_st { -#if defined(__cplusplus) && (__cplusplus >= 201103L) - alignas(64) -#elif __STDC_VERSION__ >= 201112L - _Alignas(64) -#endif - cuuint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; -}; - -using CUtensorMap = CUtensorMap_st; -#endif - -namespace std { - -template struct integral_constant { - static constexpr T value = v; - - using value_type = T; - using type = integral_constant; - - __device__ constexpr operator value_type() const noexcept { return value; } - - __device__ constexpr value_type operator()() const noexcept { return value; } -}; - -using false_type = integral_constant; -using true_type = integral_constant; - -template struct is_same : false_type {}; - -template struct is_same : true_type {}; - -template -inline constexpr bool is_same_v = is_same::value; - -namespace index_sequence_impl { - -// Based on https://stackoverflow.com/a/32223343/11717224 -template struct index_sequence { - using type = index_sequence; - using value_type = size_t; - static constexpr size_t size() noexcept { return sizeof...(Ints); } -}; - -template struct _merge_and_renumber; - -template -struct _merge_and_renumber, index_sequence> - : index_sequence {}; - -template -struct make_index_sequence - : _merge_and_renumber::type, - typename make_index_sequence::type> {}; - -template <> struct make_index_sequence<0> : index_sequence<> {}; -template <> struct make_index_sequence<1> : index_sequence<0> {}; - -} // namespace index_sequence_impl - -template -using index_sequence = index_sequence_impl::index_sequence; - -template -using make_index_sequence = index_sequence_impl::make_index_sequence; - -} // namespace std - -#endif diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py index 884a711..46804e7 100644 --- a/deep_gemm/utils/math.py +++ b/deep_gemm/utils/math.py @@ -46,3 +46,12 @@ def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tenso sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() \ No newline at end of file diff --git a/develop.sh b/develop.sh index 5879861..3a71e24 100755 --- a/develop.sh +++ b/develop.sh @@ -7,7 +7,7 @@ cd "$script_dir" ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include -# Remove old dist file, build, and build +# Remove old dist file, build files, and build rm -rf build dist rm -rf *.egg-info python setup.py build diff --git a/install.sh b/install.sh index 6b675d6..6e5e6f2 100755 --- a/install.sh +++ b/install.sh @@ -3,7 +3,7 @@ original_dir=$(pwd) script_dir=$(realpath "$(dirname "$0")") cd "$script_dir" -# Remove old dist file, build, and install +# Remove old dist file, build files, and install rm -rf build dist rm -rf *.egg-info python setup.py bdist_wheel diff --git a/setup.py b/setup.py index db83734..1c1e618 100644 --- a/setup.py +++ b/setup.py @@ -2,12 +2,14 @@ import os import setuptools import shutil import subprocess +import torch from setuptools import find_packages from setuptools.command.build_py import build_py from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME current_dir = os.path.dirname(os.path.realpath(__file__)) -cxx_flags = ['-std=c++20', '-O3', '-fPIC', '-Wno-psabi'] +cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations', + f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}'] sources = ['csrc/python_api.cpp'] build_include_dirs = [ f'{CUDA_HOME}/include', @@ -15,7 +17,7 @@ build_include_dirs = [ 'third-party/cutlass/include', 'third-party/fmt/include', ] -build_libraries = ['cuda', 'cudart'] +build_libraries = ['cuda', 'cudart', 'nvrtc'] build_library_dirs = [ f'{CUDA_HOME}/lib64', f'{CUDA_HOME}/lib64/stubs' @@ -40,7 +42,7 @@ class CustomBuildPy(build_py): def generate_default_envs(self): code = '# Pre-installed environment variables\n' code += 'persistent_envs = dict()\n' - for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_DISABLE_SHORTCUT_CACHE'): + for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_CPP_STANDARD'): code += f"persistent_envs['{name}'] = '{os.environ[name]}'\n" if name in os.environ else '' with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f: @@ -78,9 +80,6 @@ if __name__ == '__main__': name='deep_gemm', version='2.0.0' + revision, packages=find_packages('.'), - install_requires=[ - 'torch>=2.1.0', - ], package_data={ 'deep_gemm': [ 'include/deep_gemm/**/*', diff --git a/tests/generators.py b/tests/generators.py index a0597ad..21c050a 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -14,6 +14,7 @@ class KernelType(enum.Enum): # For SM100 GEMMs Kernel1D1D = 0 Kernel1D2D = 1 + KernelNoSF = 2 def is_1d1d(self): return self.value == 0 @@ -21,6 +22,9 @@ class KernelType(enum.Enum): def is_1d2d(self): return self.value == 1 + def is_nosf(self): + return self.value == 2 + class MajorTypeAB(enum.Enum): KMajor = 0 @@ -44,7 +48,9 @@ def get_ue8m0_usage(kernel_type: KernelType) -> bool: return kernel_type.is_1d1d() -def get_kernel_types() -> tuple: +def get_kernel_types(use_bf16: bool = False) -> tuple: + if use_bf16: + return (KernelType.KernelNoSF, ) return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, KernelType.Kernel1D2D) @@ -61,13 +67,13 @@ def get_major_ab(freeze_a: bool) -> tuple: (MajorTypeAB.MNMajor, MajorTypeAB.KMajor), (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) -def enumerate_normal() -> Generator: - for kernel_type in get_kernel_types(): +def enumerate_normal(use_bf16: bool = False) -> Generator: + for kernel_type in get_kernel_types(use_bf16): for m in (128, 4096): - for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]: + for n, k in [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048), (129280, 7168)]: for major_a, major_b in get_major_ab(False): for out_dtype in get_out_dtype(): - for accumulate in (False, ) if out_dtype == torch.bfloat16 or kernel_type.is_1d2d() else (False, True): + for accumulate in (False, ) if out_dtype == torch.bfloat16 or not kernel_type.is_1d1d() else (False, True): yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype @@ -123,7 +129,7 @@ def enumerate_k_grouped_sf_layout(): def generate_normal(m: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, accumulate: bool, out_dtype: torch.dtype, - use_ue8m0: bool): + use_ue8m0: bool = False, use_bf16: bool = False): a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ @@ -131,6 +137,11 @@ def generate_normal(m: int, n: int, k: int, c = d if accumulate else None ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype) + if use_bf16: + a = a if major_a.is_k_major() else a.T.contiguous().T + b = b if major_b.is_k_major() else b.T.contiguous().T + return a, b, c, d, ref_d + a_fp8, b_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0), per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1]) b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1]) diff --git a/tests/test_core.py b/tests/test_fp8.py similarity index 97% rename from tests/test_core.py rename to tests/test_fp8.py index d9ddc75..8f5aa38 100644 --- a/tests/test_core.py +++ b/tests/test_fp8.py @@ -51,10 +51,10 @@ def test_gemm() -> None: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast) t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}):' - f' launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | ' + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s') print()