Add more GPU architectures support (#112)

* Add more GPU architectures support

* Update layout.py

* Optimize performance, Add SM90 support, Add 1D2D SM100 support

* Add fmtlib submodule at commit 553ec11

---------

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
Ray Wang
2025-07-18 11:32:22 +08:00
committed by GitHub
parent 03d0be3d2d
commit 9da4a23561
67 changed files with 5586 additions and 2965 deletions

5
.gitmodules vendored
View File

@@ -1,3 +1,6 @@
[submodule "third-party/cutlass"]
path = third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git
url = git@github.com:NVIDIA/cutlass.git
[submodule "third-party/fmt"]
path = third-party/fmt
url = git@github.com:fmtlib/fmt.git

View File

@@ -1,44 +1,33 @@
# NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT
# TODO: add CUDA utils' library via CMake
cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")
set(USE_SYSTEM_NVTX on)
# Default architecture list (user-overridable cache entry), mirrored into TORCH_CUDA_ARCH_LIST
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
# Probe NVCC for sm_90a support by compiling an empty kernel inside the build directory
file(WRITE ${CMAKE_BINARY_DIR}/test_cuda.cu "extern \"C\" __global__ void testKernel() { }")
execute_process(
COMMAND ${CUDA_NVCC_EXECUTABLE} ${CMAKE_CUDA_FLAGS} -gencode arch=compute_90a,code=sm_90a -o ${CMAKE_BINARY_DIR}/test_cuda.o -c ${CMAKE_BINARY_DIR}/test_cuda.cu
RESULT_VARIABLE NVCC_RESULT
OUTPUT_VARIABLE NVCC_OUTPUT
ERROR_VARIABLE NVCC_ERROR_OUTPUT
WORKING_DIRECTORY ${CMAKE_BINARY_DIR}
)
# A zero exit code from the probe is taken as "sm_90a is supported"
if (NVCC_RESULT EQUAL "0")
set(NVCC_SUPPORTS_SM90 TRUE)
message(STATUS "NVCC supports SM90")
else()
message(STATUS "NVCC does not support SM90")
endif()
# With sm_90a available: force TORCH_CUDA_ARCH_LIST to "8.6" in the cache and add the 90a gencode flag
# NOTE(review): the FORCE'd "8.6" value looks like a workaround for Torch's arch handling - confirm
if (NVCC_SUPPORTS_SM90)
set(TORCH_CUDA_ARCH_LIST "8.6" CACHE STRING "Add arch tag 90a to NVCC" FORCE)
list(APPEND CUDA_NVCC_FLAGS "-gencode;arch=compute_90a,code=sm_90a")
endif()
find_package(Torch REQUIRED)
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib)
# Language standards and optimization flags for host and CUDA compilation
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -fPIC -DNDEBUG")
# NOTE(review): this passes -std=c++17 to NVCC while CMAKE_CUDA_STANDARD above is 20 - confirm which is intended
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -O3 -std=c++17 -DNDEBUG --ptxas-options=--register-usage-level=10")
# Header search paths: project headers, CUTLASS, fmt, the CUDA toolkit, Torch and Python
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
# Library search paths: Torch plus the CUDA toolkit lib64 directory and its stubs
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
cuda_add_library(example_gemm STATIC indexing/main.cu)
# The main Python API entrance
pybind11_add_module(deep_gemm_cpp csrc/python_api.cpp)
target_link_libraries(deep_gemm_cpp PRIVATE ${TORCH_LIBRARIES} torch_python cuda)
# Enable kernel code indexing with CMake-based IDEs
cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu)

168
README.md
View File

@@ -1,13 +1,18 @@
# DeepGEMM
DeepGEMM is a library designed for clean and efficient FP8 General Matrix Multiplications (GEMMs) with fine-grained scaling, as proposed in [DeepSeek-V3](https://github.com/deepseek-ai/DeepSeek-V3). It supports both normal and Mix-of-Experts (MoE) grouped GEMMs. Written in CUDA, the library has no compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module.
DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (work in progress) for both normal and Mixture-of-Experts (MoE) grouped scenarios. Written in CUDA, the library requires no kernel compilation during installation; all kernels are compiled at runtime using a lightweight Just-In-Time (JIT) module.
Currently, DeepGEMM exclusively supports NVIDIA Hopper tensor cores. To address the imprecise FP8 tensor core accumulation, it employs CUDA-core two-level accumulation (promotion). While it leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only one core kernel function. This makes it a clean and accessible resource for learning Hopper FP8 matrix multiplication and optimization techniques.
While DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
## News
- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
- NVRTC and post-compilation SASS optimization are all disabled
- NVRTC will be supported later
- As NVCC 12.9 automatically does the FFMA interleaving, post-compilation optimizations will no longer be supported
- Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases).
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
@@ -16,57 +21,59 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- [x] More correctness tests for grouped-contiguous layout
- [x] Shared memory swizzling for output
- [ ] Larger block size on N (up to 256)
- [x] MoE scheduler with TMA multicast compatibility
- [x] Fix TMA multicast compatibility for indivisible shapes
- [x] Skip useless computation on M
- [x] NVRTC as a faster compiler
- [ ] Stolen JIT cache
- [ ] NVRTC as a faster compiler
- [ ] Sanitizer for testing
- [x] Weight gradient kernels for dense models
- [x] Weight gradient kernels for MoE models
- [ ] Better `get_best_configs` modeling
- [ ] Utility kernels for MoE models (maybe with [tile-lang](https://github.com/tile-ai/tilelang))
- [ ] CUDA PDL support
- [ ] More scaling granularity support via templates
- [ ] Larger TMA multicast size for some shapes
- [x] MMA template refactor with CUTLASS
- [ ] Optimizations for power efficiency
- [x] Remove shape limitations on N and K
- [ ] BF16 kernels
- [ ] Split/stream-k optimizations
- [ ] Ampere kernels
- [ ] Polish docs
## Quick start
### Requirements
- Hopper architecture GPUs, `sm_90a` must be supported
- Python 3.8 or above
- CUDA 12.3 or above
- **But we highly recommend 12.8 or above for the best performance**
- PyTorch 2.1 or above
- CUTLASS 3.6 or above (could be cloned by Git submodule)
- NVIDIA SM90 or SM100 architecture GPU
- Python 3.8 or higher
- Compilers with C++20 support
- CUDA Toolkit:
- CUDA 12.3 or higher for SM90
- **We highly recommend 12.9 or higher for the best performance**
- CUDA 12.9 or higher for SM100
- PyTorch 2.1 or higher
- CUTLASS 4.0 or higher (could be cloned by Git submodule)
- `{fmt}` library (could be cloned by Git submodule)
### Development
```bash
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
cd DeepGEMM
# Make symbolic links for third-party (CUTLASS and CuTe) include directories
python setup.py develop
# Link some essential includes and build the CPP JIT module
cat develop.sh
./develop.sh
# Test JIT compilation
python tests/test_jit.py
# Test all GEMM implements (normal, contiguous-grouped and masked-grouped)
# Test all GEMM implementations
python tests/test_layout.py
python tests/test_core.py
```
### Installation
```bash
python setup.py install
cat install.sh
./install.sh
```
Then, import `deep_gemm` in your Python project, and enjoy!
@@ -75,118 +82,61 @@ Then, import `deep_gemm` in your Python project, and enjoy!
#### Notices
This library exclusively contains GEMM kernels. It requires the LHS scaling factor to be TMA-aligned and transposed, and it only supports the NT format (non-transposed LHS and transposed RHS). For transposition or other FP8 casting operations, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves.
This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The input shape layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major, col-major), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` will do a `D = C + A @ B.T`.
For both architectures, the LHS scaling factor is required to have a TMA-aligned and transposed layout. And the data format for the scaling factor of SM90 and SM100 is different:
- SM90 requires scaling factors in FP32 format.
- SM100 requires scaling factors in packed [UE8M0](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) format, which packs 4 UE8M0 into a single `torch.int`.
Please note that operations like input transposition or FP8 casting must be handled separately by the user; please implement them independently or fuse them into prior kernels. While the library provides some simple PyTorch utility functions, these may result in slower performance, as our primary focus is on optimizing the GEMM kernels themselves.
#### Normal dense GEMMs (non-grouped)
To perform a basic non-grouped FP8 GEMM, call the `deep_gemm.gemm_fp8_fp8_bf16_nt` function. For more details, please refer to the function documentation.
To perform a basic non-grouped FP8 GEMM, call the `fp8_gemm_{nt, nn, tn, tt}` function. For more details, please refer to the function documentation.
#### Grouped GEMMs (contiguous layout)
Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape.
Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation.
For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_m_alignment_for_contiguous_layout()`).
For more information, please refer to the `m_grouped_gemm_fp8_fp8_bf16_nt_contiguous` function documentation.
We also provide a K-axis-grouped API for MoE weight backward (where M and N must remain fixed); please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information.
#### Grouped GEMMs (masked layout)
During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions.
Use `m_grouped_gemm_fp8_fp8_bf16_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input.
Use `fp8_m_grouped_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input.
#### Utilities
The library provides some utility functions besides the above kernels:
- `deep_gemm.set_num_sms`: set the maximum SM count to use
- `deep_gemm.get_num_sms`: get the current SM maximum count
- `deep_gemm.get_m_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set)
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
- `deep_gemm.get_col_major_tma_aligned_tensor`: get a column-major TMA-aligned tensor
- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor
- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0)
- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel
The library also provides some environment variables, which may be useful:
- General
- `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default
- `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default
- JIT cache related
- `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- `DG_JIT_DISABLE_CACHE`: `0` or `1`, disable the use of cache directory, `0` by default
- `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- NVCC/NVRTC selections
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default
- `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but may have lower performance in some cases, `0` by default
- `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default
- Compiler options
- `DG_JIT_OVERRIDE_CPP_STANDARD`: integer (e.g., `20`), support for some old version GCC compiler, `20` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default
- `DG_JIT_PRINT_REG_REUSE`: `0` or `1`, print FFMA-interleaving details, `0` by default
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default
- Post optimization
- `DG_JIT_DISABLE_FFMA_INTERLEAVE`: `0` or `1`, disable FFMA-interleaving optimization, `0` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default
- Heuristic selection
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- Testing
- `DG_NSYS_PROFILING`: `0` or `1`, Nsight-system compatible testing, `0` by default
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
## Optimizations
We indicate the techniques excluded from CUTLASS with 🐳.
#### Persistent warp-specialization
Following the CUTLASS design, the kernels in DeepGEMM are warp-specialized, enabling overlapping data movement, tensor-core MMA instructions, and CUDA-core promotion. A simplified figure illustrating this process is shown below:
![design](figures/design.png)
#### Hopper TMA features
The [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/hopper-tuning-guide/index.html#tensor-memory-accelerator) (TMA) is a new hardware feature introduced by the Hopper architecture, designed for faster and asynchronous data movement. Specifically, we utilize TMA for:
- TMA load for LHS, LHS scaling factors, and RHS matrices
- TMA store for the output matrix
- TMA multicast (automatically decide LHS or RHS to broadcast)
- TMA descriptor prefetching
#### Common detail optimizations
- Utilization of the [`stmatrix`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-store-instruction-stmatrix) PTX instruction
- [Register count control](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-setmaxnreg) tailored for different warpgroups
- Less bank conflicts via 3D TMA or swizzling
- Larger block sizes (up to 256x128 🐳)
- Overlapping as much as possible, e.g., overlapping TMA store and non-TMA RHS scaling factor load 🐳
#### A unified and optimized block scheduler
- [One scheduler](deep_gemm/include/deep_gemm/scheduler.cuh) for all non-grouped and grouped kernels
- [Rasterization](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/media/docs/efficient_gemm.md#threadblock-rasterization) to enhance L2 cache reuse
#### Fully JIT design 🐳
DeepGEMM employs a fully [Just-In-Time](deep_gemm/jit) (JIT) design, with no compilation required at installation. All kernels are compiled at runtime using a lightweight JIT implementation. This approach offers several advantages:
- GEMM shapes, block sizes, and the number of pipeline stages are treated as compile-time constants
- Saving registers
- Compilers may do more optimizations
- Automatic selection of block sizes, number of warpgroups, optimal pipeline stages, and TMA cluster size
- But without auto-tuning, the optimal one is deterministically selected
- Full unrolling of the MMA pipelines, providing compilers with more optimization opportunities
- Very important for small shapes
- Refer to `launch_k_iterations` in [the kernel file](deep_gemm/include/deep_gemm/fp8_gemm.cuh) for details
Overall, JIT significantly improves performance for small shapes, similar to the approach of the [Triton](https://github.com/triton-lang/triton/) compiler.
#### Unaligned block sizes 🐳
For certain shapes, block sizes aligned to powers of 2 can lead to underutilized SMs. For instance, with `M=256, N=7168`, a typical block size assignment of `BLOCK_M=128, BLOCK_N=128` results in only `(256 / 128) * (7168 / 128) = 112` out of 132 SMs being utilized. To address this, we support unaligned block sizes like 112, enabling `(256 / 128) * (7168 / 112) = 128` SMs to work in such scenarios. Implementing this technique alongside fine-grained scaling requires careful optimization but ultimately delivers performance gains.
#### FFMA SASS interleaving 🐳
We observe a performance improvement in [the CUTLASS FP8 kernel](https://github.com/NVIDIA/cutlass/tree/main/examples/54_hopper_fp8_warp_specialized_gemm) between NVCC 12.2 and 12.3. By comparing the compiled SASS, we discover that one bit in [a series of `FADD` instructions](https://github.com/NVIDIA/cutlass/blob/eefa171318b79cbe2e78514d4cce5cd0fe919d0c/include/cutlass/gemm/collective/fp8_accumulation.hpp#L73) is flipped in an interleaving pattern.
After referencing some open-source [CUDA assembler](https://github.com/cloudcores/CuAssembler/blob/96a9f72baf00f40b9b299653fcef8d3e2b4a3d49/CuAsm/CuControlCode.py#L46) implementations, we identified that this bit controls `yield`, which may enhance warp-level parallelism (just a guess, yielding the current warp and let other warps work).
To leverage this, we develop [a similar script](deep_gemm/jit/interleave_ffma.py) to modify the `FFMA` instructions in the compiled binary. Besides simply modifying the `yield` bit, we also flip the `reuse` bit (registers cannot be reused if the warp is yielded). This adjustment improves performance (10%+ in some cases) for fine-grained scaling FP8 GEMMs by creating more opportunities to overlap MMA instructions with promotion `FFMA` instructions.
## Acknowledgement
DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers!
@@ -194,15 +144,3 @@ DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project
## License
This code repository is released under [the MIT License](LICENSE).
## Citation
```bibtex
@misc{deepgemm2025,
title={DeepGEMM: clean and efficient FP8 GEMM kernels with fine-grained scaling},
author={Chenggang Zhao and Liang Zhao and Jiashi Li and Zhean Xu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}},
}
```

13
csrc/indexing/main.cu Normal file
View File

@@ -0,0 +1,13 @@
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
// Empty entry point: this translation unit only exists so that the kernel headers
// included above get indexed by CMake-based IDEs; it performs no work and returns 0.
int main() {
return 0;
}

31
csrc/jit/cache.hpp Normal file
View File

@@ -0,0 +1,31 @@
#pragma once
#include <filesystem>
#include <memory>
#include <unordered_map>
#include "kernel_runtime.hpp"
namespace deep_gemm {
// In-memory cache of loaded kernels, keyed by the directory that holds their build artifacts.
class KernelRuntimeCache {
    // Loaded runtimes, indexed by kernel cache directory
    std::unordered_map<std::filesystem::path, std::shared_ptr<KernelRuntime>> cache;

public:
    // TODO: consider cache capacity
    KernelRuntimeCache() = default;

    // Return the runtime for `dir_path`.
    // - If already cached, the cached instance is returned.
    // - Otherwise, if the directory passes `KernelRuntime::check_validity`, a runtime is
    //   created from it, stored in the cache and returned.
    // - Otherwise `nullptr` is returned and nothing is cached.
    std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
        const auto found = cache.find(dir_path);
        if (found != cache.end())
            return found->second;

        // Nothing usable on disk: report a miss to the caller
        if (not KernelRuntime::check_validity(dir_path))
            return nullptr;

        // Construct first, insert afterwards: a failed load leaves no entry behind
        auto loaded_runtime = std::make_shared<KernelRuntime>(dir_path);
        cache[dir_path] = loaded_runtime;
        return loaded_runtime;
    }
};
// Shared cache instance used by the JIT compiler. `static` gives it internal linkage,
// so every translation unit that includes this header gets its own instance.
static auto kernel_runtime_cache = std::make_shared<KernelRuntimeCache>();
} // namespace deep_gemm

172
csrc/jit/compiler.hpp Normal file
View File

@@ -0,0 +1,172 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <iterator>
#include <regex>
#include <sstream>
#include <string>
#include <vector>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/hash.hpp"
#include "../utils/system.hpp"
#include "cache.hpp"
#include "device_runtime.hpp"
namespace deep_gemm {
// Base class of the JIT compilers: owns the library paths, the on-disk cache location and the
// common compiler flags, and implements the "look up in cache, else compile" flow in `build`.
// Derived classes provide the actual compilation step through `compile`.
class Compiler {
    std::string library_version;
    std::filesystem::path library_root_path;

    // Compute a digest over the contents of every `.cuh` file below `<include>/deep_gemm`.
    // The digest is part of each kernel's cache key, so editing any header invalidates the cache.
    std::string get_library_version() const {
        // The iteration order of `recursive_directory_iterator` is unspecified, so collect the
        // headers and sort them first: otherwise the same sources may hash differently on
        // another filesystem or machine and miss a shared cache.
        std::vector<std::filesystem::path> header_paths;
        for (const auto& entry: std::filesystem::recursive_directory_iterator(library_include_path / "deep_gemm")) {
            if (entry.is_regular_file() and entry.path().extension() == ".cuh")
                header_paths.push_back(entry.path());
        }
        std::sort(header_paths.begin(), header_paths.end());

        // Concatenate the file contents in that stable order and hash the result
        std::stringstream ss;
        for (const auto& header_path: header_paths) {
            std::ifstream file(header_path, std::ios::binary);
            std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
            ss << content;
        }
        return get_hex_digest(ss.str());
    }

public:
    std::string signature, flags;
    std::filesystem::path library_include_path;
    std::filesystem::path cache_dir_path;

    // `library_root_path` is the directory that contains the `include` folder with the kernel headers.
    explicit Compiler(const std::filesystem::path& library_root_path) {
        // Static library paths
        this->library_root_path = library_root_path;
        this->library_include_path = library_root_path / "include";
        this->library_version = get_library_version();

        // Cache settings: `$HOME/.deep_gemm` unless `DG_JIT_CACHE_DIR` is set to a non-empty value
        cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
        if (const auto& env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty())
            cache_dir_path = env_cache_dir_path;

        // The compiler flags applied to all derived compilers
        signature = "unknown-compiler";
        std::string ptxas_flags = "--ptxas-options=--register-usage-level=10";
        if (get_env<int>("DG_JIT_PTXAS_VERBOSE", 0))
            ptxas_flags += ",--verbose";
        flags = fmt::format("-std=c++20 --diag-suppress=39,161,174,177,186,940 {}", ptxas_flags);
    }

    virtual ~Compiler() = default;

    // Create (if needed) and return the `tmp` directory inside the cache directory.
    std::filesystem::path make_tmp_dir() const {
        return make_dirs(cache_dir_path / "tmp");
    }

    // Return a fresh path inside the temporary directory, named by `get_uuid()`.
    std::filesystem::path get_tmp_file_path() const {
        return make_tmp_dir() / get_uuid();
    }

    // Write `data` to `path` by writing a temporary file first and renaming it into place,
    // so readers never observe a partially written file.
    void put(const std::filesystem::path& path, const std::string& data) const {
        const auto tmp_file_path = get_tmp_file_path();

        // Write into the temporary file
        std::ofstream out(tmp_file_path, std::ios::binary);
        DG_HOST_ASSERT(out.write(data.data(), data.size()));
        out.close();

        // Atomically replace
        std::filesystem::rename(tmp_file_path, path);
    }

    // Return a loaded runtime for the kernel `name` generated from `code`.
    // The cache directory name is derived from the kernel name, library version, compiler
    // signature, flags and the code itself; on a miss the code is compiled into that directory.
    std::shared_ptr<KernelRuntime> build(const std::string& name, const std::string& code) const {
        const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code);
        const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature));

        // Hit the runtime cache
        if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr)
            return runtime;

        // Create the kernel directory
        make_dirs(dir_path);

        // Compile into a temporary CUBIN
        const auto tmp_cubin_path = get_tmp_file_path();
        compile(code, dir_path, tmp_cubin_path);

        // Replace into the cache directory
        // NOTE(review): the directory is created a second time here, presumably in case it was
        // removed while compiling - confirm
        make_dirs(dir_path);
        std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin");

        // Put into the runtime cache
        const auto& runtime = kernel_runtime_cache->get(dir_path);
        DG_HOST_ASSERT(runtime != nullptr);
        return runtime;
    }

    // Compile `code` for the kernel directory `dir_path`, writing the resulting CUBIN to `cubin_path`.
    virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0;
};
// JIT compiler backend that shells out to the `nvcc` executable.
class NVCCCompiler final: public Compiler {
    std::filesystem::path nvcc_path;

    // Run `nvcc --version` and return the parsed {major, minor} release numbers.
    // Asserts that the executable exists, that the command succeeds and that the release is
    // at least 12.3; prints a warning when it is older than 12.9.
    std::pair<int, int> get_nvcc_version() const {
        DG_HOST_ASSERT(std::filesystem::exists(nvcc_path));

        // Call the version command
        const auto& command = nvcc_path.string() + " --version";
        const auto& [return_code, output] = call_external_command(command);
        DG_HOST_ASSERT(return_code == 0);

        // The version should be at least 12.3, for the best performance with 12.9
        int major = 0, minor = 0;
        std::smatch match;
        DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))")));
        std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor);
        DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3");
        // Terminate the line so the warning does not run into whatever is printed next
        if (major < 12 or (major == 12 and minor < 9))
            printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n");
        return {major, minor};
    }

public:
    // `cuda_home_path_by_torch` is the CUDA home used to locate `bin/nvcc`;
    // a non-empty `DG_JIT_NVCC_COMPILER` environment variable overrides the NVCC path.
    NVCCCompiler(const std::filesystem::path& library_root_path,
                 const std::filesystem::path& cuda_home_path_by_torch):
            Compiler(library_root_path) {
        // Override the compiler signature
        nvcc_path = cuda_home_path_by_torch / "bin" / "nvcc";
        if (const auto& env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty())
            nvcc_path = env_nvcc_path;
        const auto& [nvcc_major, nvcc_minor] = get_nvcc_version();
        signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);

        // Extend the base flags with the include path, the target architecture and host options
        flags = fmt::format("{} -I{} --gpu-architecture=sm_{}a "
                            "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
                            "-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda",
                            flags, library_include_path.c_str(), device_runtime->get_arch());
    }

    // Write `code` to `<dir_path>/kernel.cu` and compile it with NVCC into `cubin_path`.
    // Asserts (after printing the compiler output) when NVCC returns a non-zero exit code.
    void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override {
        // Write the code into the cache directory
        const auto& code_path = dir_path / "kernel.cu";
        put(code_path, code);

        // Compile
        const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
        // Newline-terminated so that consecutive commands and later logs stay on separate lines
        if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
            printf("Running NVCC command: %s\n", command.c_str());
        const auto& [return_code, output] = call_external_command(command);
        if (return_code != 0) {
            printf("NVCC compilation failed: %s", output.c_str());
            DG_HOST_ASSERT(false and "NVCC compilation failed");
        }

        // Print PTXAS log
        if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
            printf("%s", output.c_str());
    }
};
// Compiler instance used for JIT builds. It starts out null here.
// NOTE(review): presumably assigned during module initialization elsewhere - confirm.
static std::shared_ptr<Compiler> compiler = nullptr;
} // namespace deep_gemm

View File

@@ -0,0 +1,50 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include "../utils/exception.hpp"
namespace deep_gemm {
// Lazily queried information about the current CUDA device, plus the SM count limit to use.
class DeviceRuntime {
    // 0 means "not set yet"; replaced by the device SM count on first read
    int num_sms = 0;
    // Device properties, fetched once on first use
    std::shared_ptr<cudaDeviceProp> cached_prop;

public:
    explicit DeviceRuntime() = default;

    // Return the device properties, copying them from ATen on the first call only.
    std::shared_ptr<cudaDeviceProp> get_prop() {
        if (not cached_prop) {
            const auto* current_prop = at::cuda::getCurrentDeviceProperties();
            cached_prop = std::make_shared<cudaDeviceProp>(*current_prop);
        }
        return cached_prop;
    }

    // Compute capability as {major, minor}.
    std::pair<int, int> get_arch_pair() {
        const auto device_prop = get_prop();
        return std::pair<int, int>(device_prop->major, device_prop->minor);
    }

    // Compute capability folded into one integer: major * 10 + minor.
    int get_arch() {
        const auto arch_pair = get_arch_pair();
        return arch_pair.first * 10 + arch_pair.second;
    }

    // Major part of the compute capability.
    int get_arch_major() {
        const auto arch_pair = get_arch_pair();
        return arch_pair.first;
    }

    // Set the SM count to use; must lie within [0, device SM count].
    void set_num_sms(const int& new_num_sms) {
        DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount);
        this->num_sms = new_num_sms;
    }

    // Return the SM count to use, falling back to (and remembering) the device SM count while unset.
    int get_num_sms() {
        if (num_sms != 0)
            return num_sms;
        num_sms = get_prop()->multiProcessorCount;
        return num_sms;
    }
};
// Shared device runtime instance. `static` gives it internal linkage, so every translation
// unit that includes this header gets its own instance.
static auto device_runtime = std::make_shared<DeviceRuntime>();
} // namespace deep_gemm

139
csrc/jit/kernel_runtime.hpp Normal file
View File

@@ -0,0 +1,139 @@
#pragma once
#include <cuda_runtime.h>
#include <filesystem>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/system.hpp"
#include "device_runtime.hpp"
namespace deep_gemm {
// Launch configuration of a kernel: grid size, threads per block, shared memory size and cluster size.
struct LaunchArgs {
    std::pair<int, int> grid_dim;
    int num_threads;
    int smem_size;
    int cluster_dim;

    // 1D grid: the second grid dimension is fixed to 1
    LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
        LaunchArgs(std::pair<int, int>(grid_dim_x, 1), num_threads, smem_size, cluster_dim) {}

    // 2D grid given as a pair
    LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
        grid_dim{grid_dim},
        num_threads{num_threads},
        smem_size{smem_size},
        cluster_dim{cluster_dim} {}
};
// Satisfied when `t.launch_args` is a valid member access on a `const T&` and the resulting
// expression is convertible to the member's declared type.
// NOTE(review): the concept does not require the member to be a `LaunchArgs` - presumably it
// always is; confirm against the argument structs.
template <typename T>
concept HasLaunchArgs = requires (const T& t) {
{ t.launch_args } -> std::convertible_to<decltype(t.launch_args)>;
};
class KernelRuntime final {
public:
static std::filesystem::path cuda_home;
cudaLibrary_t library;
cudaKernel_t kernel;
explicit KernelRuntime(const std::filesystem::path& dir_path) {
// NOLINT(*-pro-type-member-init)
const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump";
const auto& cubin_path = dir_path / "kernel.cubin";
if (get_env<int>("DG_JIT_DEBUG"))
printf("Loading CUBIN: %s\n", cubin_path.c_str());
// Find the only symbol
// TODO: use kernel enumeration for newer drivers
const std::vector<std::string> illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"};
const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str()));
DG_HOST_ASSERT(exit_code == 0);
std::istringstream iss(symbols);
std::vector<std::string> symbol_names;
for (std::string line; std::getline(iss, line); ) {
if (line.find("STT_FUNC") == 0 and std::ranges::none_of(illegal_names, [&](const auto& name) { return line.find(name) != std::string::npos; })) {
const auto& last_space = line.rfind(' ');
symbol_names.push_back(line.substr(last_space + 1));
}
}
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Symbol names: ");
for (const auto& symbol: symbol_names)
printf("%s, ", symbol.c_str());
printf("\n");
}
// Load from the library
DG_HOST_ASSERT(symbol_names.size() == 1);
DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, symbol_names[0].c_str()));
}
static void set_cuda_home(const std::string& cuda_home_path_by_torch) {
cuda_home = cuda_home_path_by_torch;
}
static bool check_validity(const std::filesystem::path& dir_path) {
return std::filesystem::exists(dir_path / "kernel.cu") and
std::filesystem::exists(dir_path / "kernel.cubin");
}
~KernelRuntime() noexcept(false) {
const auto& error = cudaLibraryUnload(library);
DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading);
}
};
// Declare after defining
decltype(KernelRuntime::cuda_home) KernelRuntime::cuda_home;
// CRTP base of the JIT kernel runtimes: `Derived` provides `generate_impl` (kernel source generation)
// and `launch_impl` (passing the kernel arguments), this class provides the shared debug/launch-config logic
template <typename Derived>
class LaunchRuntime {
public:
    // Generate the kernel source via `Derived::generate_impl`, echoing it when `DG_JIT_DEBUG` is set
    template <typename Args> requires HasLaunchArgs<Args>
    static std::string generate(const Args& args) {
        const auto& code = Derived::generate_impl(args);
        if (get_env<int>("DG_JIT_DEBUG", 0))
            printf("Generated kernel code: %s\n", code.c_str());
        return code;
    }

    // Build the launch config from `args.launch_args` and launch on the current CUDA stream
    // via `Derived::launch_impl(kernel, config, args)`
    template <typename Args> requires HasLaunchArgs<Args>
    static void launch(const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) {
        const auto& kernel = kernel_runtime->kernel;
        const auto& stream = at::cuda::getCurrentCUDAStream();
        const LaunchArgs& launch_args = args.launch_args;

        // Set dynamic shared memory size
        if (launch_args.smem_size > 0)
            DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, launch_args.smem_size));

        // Launch config
        // NOTES: value-initialized so that no field (e.g. `attrs` when there is no cluster) is handed
        // to the runtime with an indeterminate value
        cudaLaunchConfig_t config = {};
        config.gridDim = {static_cast<unsigned>(launch_args.grid_dim.first),
                          static_cast<unsigned>(launch_args.grid_dim.second),
                          1};
        config.blockDim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
        config.dynamicSmemBytes = launch_args.smem_size;
        config.stream = stream;
        config.numAttrs = 0;

        // Clusters
        // NOTES: `attr` must outlive the launch below, as `config.attrs` points to it
        cudaLaunchAttribute attr = {};
        if (launch_args.cluster_dim > 1) {
            attr.id = cudaLaunchAttributeClusterDimension;
            attr.val.clusterDim = {static_cast<unsigned>(launch_args.cluster_dim), 1, 1};
            config.attrs = &attr;
            config.numAttrs = 1;
        }

        // Launch in the derived class
        if (get_env<int>("DG_JIT_DEBUG")) {
            printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n",
                   launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads,
                   launch_args.smem_size, launch_args.cluster_dim, stream.id());
        }
        Derived::launch_impl(kernel, config, args);
    }
};
} // namespace deep_gemm

View File

@@ -0,0 +1,298 @@
#pragma once
#include "../../utils/math.hpp"
namespace deep_gemm {
// TMA multicast selection: the number of multicasts and whether the multicast is on A
struct MulticastConfig {
// Only 1 or 2 is accepted (asserted in the constructor)
int num_multicast;
bool is_multicast_on_a;
MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a):
num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) {
DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2);
}
};
// Shared memory summary as filled by `get_smem_config`
struct SharedMemoryConfig {
// Sum of all shared memory parts (the config printing reports it in bytes)
int smem_size;
// Swizzle modes of A, B and C/D as returned by `get_swizzle_mode` (one of 128, 64, 32, 16)
int swizzle_a_mode;
int swizzle_b_mode;
int swizzle_cd_mode;
};
// Thread counts of a kernel; each factory fills `num_threads` plus the fields of its own architecture
// and leaves the fields of the other architecture zero
struct ThreadConfig {
    int num_threads;

    // SM90
    int num_tma_threads;
    int num_math_threads;

    // SM100
    int num_non_epilogue_threads;
    int num_epilogue_threads;

    // SM90: the total is the TMA threads plus the math threads
    static ThreadConfig sm90(const int& num_tma_threads,
                             const int& num_math_threads) {
        return ThreadConfig {
            .num_threads = num_tma_threads + num_math_threads,
            .num_tma_threads = num_tma_threads,
            .num_math_threads = num_math_threads,
        };
    }

    // SM100: the total is the non-epilogue threads plus the epilogue threads
    static ThreadConfig sm100(const int& num_non_epilogue_threads,
                              const int& num_epilogue_threads) {
        return ThreadConfig {
            .num_threads = num_non_epilogue_threads + num_epilogue_threads,
            .num_non_epilogue_threads = num_non_epilogue_threads,
            .num_epilogue_threads = num_epilogue_threads,
        };
    }
};
// Complete kernel configuration as assembled and returned by `get_best_config`
struct GemmConfig {
// Templated configs
GemmType gemm_type;
KernelType kernel_type;
at::ScalarType ab_dtype, cd_dtype;
cute::UMMA::Major major_a;
cute::UMMA::Major major_b;
bool with_accumulation;
int block_m, block_n, block_k;
// `num_last_stages` is filled with `ceil_div(k, block_k) % num_stages` by `get_best_config`
int num_stages, num_last_stages;
// Runtime configs
int num_sms;
// Structured configs
MulticastConfig multicast_config;
SharedMemoryConfig smem_config;
ThreadConfig thread_config;
};
// Multicast is legal when `num_sms` is a multiple of `num_multicast` and, if `require_divisible` is set,
// the number of blocks along the dimension is a multiple of `num_multicast` as well
static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
                               const int& num_multicast, const int& num_sms,
                               const bool& require_divisible) {
    const auto num_blocks = ceil_div(shape_dim, block_dim);
    const bool is_block_count_ok = num_blocks % num_multicast == 0 or not require_divisible;
    if (not is_block_count_ok)
        return false;
    return num_sms % num_multicast == 0;
}
// Return the largest of 128/64/32/16 that divides `block_size * elem_size`;
// hits the unreachable assertion when even 16 does not divide it
static int get_swizzle_mode(const int& block_size, const int& elem_size) {
    // `> 0` means interleaving
    // 16B actually means non-swizzling (but interleaving)
    const int num_bytes = block_size * elem_size;
    for (int mode = 128; mode >= 16; mode /= 2) {
        if (num_bytes % mode == 0)
            return mode;
    }
    DG_HOST_UNREACHABLE("Unreachable");
}
// Compute the swizzle modes and the total shared memory size of a candidate configuration;
// `ArchSpec` provides the architecture-specific load block shapes and buffer sizes
template <typename ArchSpec>
static SharedMemoryConfig get_smem_config(const KernelType& kernel_type,
                                          const int& m, const int& n, const int& k,
                                          const int& block_m, const int& block_n, const int& block_k,
                                          const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                          const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
                                          const int& num_stages, const MulticastConfig& multicast_config) {
    const auto ab_elem_size = static_cast<int>(c10::elementSize(ab_dtype));
    const auto cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));

    // Blocks actually loaded for A/B (may differ from `block_m`/`block_n`, decided by the arch spec)
    const int load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m);
    const int load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);

    // Swizzling is decided by the K block for K-major operands, by the loaded MN block otherwise
    const bool is_a_k_major = major_a == cute::UMMA::Major::K;
    const bool is_b_k_major = major_b == cute::UMMA::Major::K;
    SharedMemoryConfig smem_config;
    smem_config.swizzle_a_mode = get_swizzle_mode(is_a_k_major ? block_k : load_block_m, ab_elem_size);
    smem_config.swizzle_b_mode = get_swizzle_mode(is_b_k_major ? block_k : load_block_n, ab_elem_size);
    smem_config.swizzle_cd_mode = get_swizzle_mode(block_n, cd_elem_size);

    // Different archs have different epilogue pipelines
    const int smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, smem_config.swizzle_cd_mode, cd_dtype);

    // Everything replicated per pipeline stage: A/B tiles and their scaling factors (SF)
    const int smem_a_per_stage = load_block_m * block_k * ab_elem_size;
    const int smem_b_per_stage = load_block_n * block_k * ab_elem_size;
    const auto sf_per_stage =
        ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype);
    const int smem_per_stage = smem_a_per_stage + smem_b_per_stage + sf_per_stage.first + sf_per_stage.second;

    // Parts allocated once: extra SFB, M-barriers and tensor memory pointers
    const int smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k);
    const int smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
    const int smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();

    // Sum them up
    smem_config.smem_size = smem_cd + num_stages * smem_per_stage + smem_extra_sfb + smem_barrier + smem_tmem_ptr;
    return smem_config;
}
// Select the kernel configuration for one problem shape, in this order:
//   1. block M/N (fewest waves, then best last-wave utilization, then the tie-breaks below),
//   2. TMA multicast (2 only if `m >= 512` and the arch spec reports it legal),
//   3. the largest legal stage count whose shared memory fits `ArchSpec::smem_capacity`,
//   4. the number of SMs to launch on.
// `ArchSpec` supplies the architecture-specific legality checks and sizes
// Fails a host assertion for unsupported dtypes, or when no legal block size or stage count is found
template <typename ArchSpec>
static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k, const int& num_groups,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
// Select M/N block sizes
// TODO: support `% 16 == 8` block size on SM90
// NOTES: M-grouped contiguous GEMMs only try the single block M given by `get_mk_alignment_for_contiguous_layout()`
const auto& block_ms = gemm_type == GemmType::MGroupedContiguous ?
std::vector{get_mk_alignment_for_contiguous_layout()} : std::vector{64, 128, 256};
// Candidate block Ns: 16, 32, ..., 256
std::vector<int> block_ns;
for (int i = 16; i <= 256; i += 16)
block_ns.push_back(i);
// K block size is selected in a fixed manner
const auto& block_k = 128 / static_cast<int>(c10::elementSize(ab_dtype));
// Some util functions
const auto& get_num_blocks = [=](const int& block_m, const int& block_n) {
return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups;
};
const auto& get_num_waves = [=](const int& block_m, const int& block_n) {
return ceil_div(get_num_blocks(block_m, block_n), num_sms);
};
// Number of blocks in the last wave; an exactly full last wave counts as `num_sms`
const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) {
const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms;
return num_last_blocks == 0 ? num_sms : num_last_blocks;
};
// Decide block sizes by waves
int best_block_m = 0, best_block_n = 0;
int best_num_waves = 0, best_last_util = 0;
for (const auto& block_m: block_ms) {
for (const auto& block_n: block_ns) {
const int& num_waves = get_num_waves(block_m, block_n);
const auto& last_util = get_last_wave_util(block_m, block_n);
// Block sizes rejected by the arch spec are never candidates
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, block_m, block_n))
continue;
bool success = false;
// The first legal candidate, or any candidate with fewer waves, always wins
if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) {
success = true;
} else if (num_waves == best_num_waves) {
// Check last wave utilization
success = last_util > best_last_util;
if (last_util == best_last_util) {
// Case 1: same `block_m`, smaller `block_n` (wasted)
success |= block_m == best_block_m and block_n < best_block_n;
// Case 2: same `block_n`, smaller `block_m` (wasted)
success |= block_n == best_block_n and block_m < best_block_m;
// Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
success |= block_m != best_block_m and block_n > best_block_n;
}
}
// Replace with the new config if successful
if (success) {
best_block_m = block_m, best_block_n = block_n;
best_num_waves = num_waves, best_last_util = last_util;
}
}
}
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
// Decide the number of TMA multicasts and whether broadcast on A
MulticastConfig best_multicast_config = {1, true};
const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, m, n, best_block_m, best_block_n, num_sms);
// NOTES: `is_legal` is indexed by `is_multicast_on_a` converted to `int` and holds `{is_legal_on_a, is_legal_on_b}`
const bool is_legal[2] = {is_legal_on_a, is_legal_on_b};
// `order` lists the `is_multicast_on_a` values in the order they are tried: `false` first,
// unless `best_block_m > best_block_n`, in which case `true` is tried first
bool order[2] = {false, true};
if (best_block_m > best_block_n)
std::swap(order[0], order[1]);
for (const bool& is_multicast_on_a: order) {
if (m >= 512 and is_legal[static_cast<int>(is_multicast_on_a)]) {
best_multicast_config = {2, is_multicast_on_a};
break;
}
}
// Always pick the largest number of stage
constexpr int smem_capacity = ArchSpec::smem_capacity;
int best_num_stages = 0;
SharedMemoryConfig best_smem_config;
// At most 12 stages are tried, and never more than the number of K blocks
for (int num_stages = std::min(12, ceil_div(k, block_k)); num_stages > 0; -- num_stages) {
if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
continue;
best_smem_config = get_smem_config<ArchSpec>(kernel_type,
m, n, k,
best_block_m, best_block_n, block_k,
major_a, major_b,
ab_dtype, cd_dtype,
num_stages, best_multicast_config);
if (best_smem_config.smem_size <= smem_capacity) {
best_num_stages = num_stages;
break;
}
}
// NOTES: `best_smem_config` is only meaningful once this assertion has passed
DG_HOST_ASSERT(best_num_stages != 0);
// Recompute the minimal number of SMs required
// NOTES: less L2 cache usage and less GPU frequency drop
int num_min_sms = num_sms;
if (ArchSpec::should_minimize_num_sms()) {
// Fewest SMs that still finish all blocks in `best_num_waves` waves, aligned up to the multicast count
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
DG_HOST_ASSERT(num_min_sms <= num_sms);
}
const auto& config = GemmConfig {
.gemm_type = gemm_type,
.kernel_type = kernel_type,
.ab_dtype = ab_dtype,
.cd_dtype = cd_dtype,
.major_a = major_a,
.major_b = major_b,
.with_accumulation = with_accumulation,
.block_m = best_block_m,
.block_n = best_block_n,
.block_k = block_k,
.num_stages = best_num_stages,
// Remainder of the K block count modulo the stage count
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
.num_sms = num_min_sms,
.multicast_config = best_multicast_config,
// ReSharper disable once CppLocalVariableMightNotBeInitialized
.smem_config = best_smem_config,
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
};
// Print configs for the first time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b,
ab_dtype, cd_dtype, with_accumulation, num_sms);
// Function-local static: remembers the keys already printed so each distinct problem is reported once
static std::set<decltype(key)> printed;
if (not printed.contains(key)) {
printf("Gemm type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
"A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, "
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
"swizzle B: %d, swizzle CD: %d, threads: %d\n",
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
static_cast<int>(major_a), static_cast<int>(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype),
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,
static_cast<int>(best_multicast_config.is_multicast_on_a),
best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode,
best_smem_config.swizzle_cd_mode, config.thread_config.num_threads);
printed.insert(key);
}
}
return config;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,144 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.hpp>
#include "common.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
struct SM100ArchSpec {
static constexpr int smem_capacity = 232448;
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
}
static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) {
return block_n / (config.is_multicast_on_a ? 1 : config.num_multicast);
}
static int get_cd_store_block_m(const int& block_m) {
constexpr int layout_ad_m = 128;
return std::min(block_m, layout_ad_m);
}
static int get_cd_store_block_n(const int& block_n) {
return block_n;
}
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) {
constexpr int num_utccp_aligned_elems = 128;
DG_HOST_ASSERT(block_m % num_utccp_aligned_elems == 0);
switch (ab_dtype) {
case torch::kBFloat16: return {0, 0};
case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
default: DG_HOST_UNREACHABLE("Unknown dtype");
}
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& block_m, const int& block_n) {
// Layout A/D does not support `block_m == 64` and `block_n % 16 != 0`
if (block_m == 64 or block_n % 16 != 0)
return false;
// Performance is lower with 1D1D and `block_m == 256`
if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m != 128)
return false;
// 1D2D kernels' maximum block N is 128
// 1D2D kernels require more friendly block Ns
if (kernel_type == KernelType::Kernel1D2D and (block_n > 128 or 128 % block_n != 0))
return false;
// Check tensor memory validity
int sf_block_m = 0, sf_block_n = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
sf_block_m = sf_block_m_, sf_block_n = sf_block_n_;
}
if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512)
return false;
// NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
// since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA
return major_b == cute::UMMA::Major::K or block_n % 64 == 0;
}
static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
const int& num_stages,
const int& block_m, const int& block_n, const int& block_k) {
return true;
}
static bool should_minimize_num_sms() {
return false;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// TODO: support other layouts
return {
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous),
false,
};
}
static ThreadConfig get_thread_config(const KernelType& kernel_type,
const int& block_m, const int& block_n) {
return ThreadConfig::sm100(128, kernel_type == KernelType::Kernel1D1D ? 128 : block_m);
}
static int get_smem_cd_size(const KernelType& kernel_type,
const int& block_m, const int& block_n,
const int& swizzle_cd_mode,
const at::ScalarType& cd_dtype) {
constexpr static int layout_ad_m = 128;
return (kernel_type == KernelType::Kernel1D1D ? std::min(block_m, layout_ad_m) : block_m) * swizzle_cd_mode * 2;
}
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
const int& block_m, const int& block_n, const int& block_k,
const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
if (ab_dtype == torch::kBFloat16)
return {0, 0};
int smem_sfa_per_stage = 0;
int smem_sfb_per_stage = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype);
smem_sfa_per_stage = sf_block_m * 4;
smem_sfb_per_stage = sf_block_n * 4;
} else {
smem_sfa_per_stage = block_m * 4;
smem_sfb_per_stage = 0;
}
return {smem_sfa_per_stage, smem_sfb_per_stage};
}
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
return 0;
}
static int get_barrier_smem_size(const int& num_stages) {
// TODO: remove SF barriers for BF16 GEMMs
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
// NOTES: 1D2D kernel will not use the with-SF full barriers
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
return num_stages * 8 * 3 + 2 * 8 * 2;
}
static int get_tmem_ptr_smem_size() {
return 4;
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,115 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.hpp>
#include "common.hpp"
namespace deep_gemm {
// SM90 hooks for the generic heuristics (`get_smem_config` / `get_best_config`):
// legality checks, thread layout and shared memory sizes
struct SM90ArchSpec {
    static constexpr int smem_capacity = 232448;

    // A/B load blocks are not affected by the multicast config
    static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
        return block_m;
    }

    static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) {
        return block_n;
    }

    static int get_cd_store_block_m(const int& block_m) {
        return block_m;
    }

    static int get_cd_store_block_n(const int& block_n) {
        return block_n;
    }

    // Whether `(block_m, block_n)` may be used for the given kernel type and output dtype
    static bool is_block_size_legal(const KernelType& kernel_type,
                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                    const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
                                    const int& block_m, const int& block_n) {
        // FP32 output does not support `block_m == 256`
        if (cd_dtype == torch::kFloat and block_m == 256)
            return false;

        // Must be some fixed block N selections
        // NOTES: the two allowed values must be combined with `and`; `x != a or x != b` holds for every `x`
        // and would reject all block Ns larger than 128
        if (block_n > 128 and kernel_type == KernelType::Kernel1D1D and block_n != 136 and block_n != 152)
            return false;
        if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and block_n != 144 and block_n != 160)
            return false;

        // Avoid bank conflicts for FP32 output
        if (cd_dtype == torch::kFloat and block_n % 16 == 0)
            return false;

        // The block sizes cannot be too large (for enough registers), so at least one dim less than 128
        return block_m <= 128 or block_n <= 128;
    }

    // Whether `num_stages` may be used with the given block sizes
    static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype,
                                    const int& num_stages,
                                    const int& block_m, const int& block_n, const int& block_k) {
        // Unrolling both stages and `num_former_iters` will cause large code size
        if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
            return num_stages <= 4;
        return true;
    }

    static bool should_minimize_num_sms() {
        return true;
    }

    // Returns `{legal on A, legal on B}`
    // NOTES: the A entry is checked against N blocks and the B entry against M blocks;
    // divisibility of the block count is only required for masked M-grouped GEMMs, which never multicast on B
    static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type,
                                                        const int& m, const int& n, const int& block_m, const int& block_n,
                                                        const int& num_sms) {
        return {
            is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
            is_multicast_legal(m, block_m, 2, num_sms, false) and gemm_type != GemmType::MGroupedMasked,
        };
    }

    // 128 TMA threads; 128 math threads for `block_m == 64`, 256 otherwise
    static ThreadConfig get_thread_config(const KernelType& kernel_type,
                                          const int& block_m, const int& block_n) {
        return ThreadConfig::sm90(128, (block_m == 64 ? 1 : 2) * 128);
    }

    // One full `block_m x block_n` output tile
    static int get_smem_cd_size(const KernelType& kernel_type,
                                const int& block_m, const int& block_n,
                                const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) {
        return block_m * block_n * static_cast<int>(c10::elementSize(cd_dtype));
    }

    // Returns `{SFA, SFB}` per-stage sizes (nothing for BF16 inputs)
    static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
                                                          const int& block_m, const int& block_n, const int& block_k,
                                                          const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) {
        if (ab_dtype == torch::kBFloat16)
            return {0, 0};

        int smem_sfa_per_stage = block_m * static_cast<int>(sizeof(float));
        int smem_sfb_per_stage = 0;
        // TODO: figure out here
        if (kernel_type == KernelType::Kernel1D1D)
            smem_sfb_per_stage = align(block_n * 4, block_k);
        return {smem_sfa_per_stage, smem_sfb_per_stage};
    }

    // One `float` per K block, doubled when `block_k` is not a multiple of `block_n`, aligned to 8
    static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
                                       const int& block_m, const int& block_n, const int& block_k) {
        const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
        return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
    }

    static int get_barrier_smem_size(const int& num_stages) {
        // For 1D1D kernels, there is an extra barrier for accumulation
        return (num_stages + 1) * 8 * 2;
    }

    static int get_tmem_ptr_smem_size() {
        return 0;
    }
};
} // namespace deep_gemm

View File

@@ -0,0 +1,173 @@
#pragma once
#include <cuda.h>
#include <torch/python.h>
#include "../../utils/math.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
// Order the two extents as `{inner, outer}`: `{k, mn}` for K-major, `{mn, k}` otherwise
static std::pair<int, int> get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) {
    if (major == cute::UMMA::Major::K)
        return {k, mn};
    return {mn, k};
}
// Index (counted from the end) of the dimension whose stride is queried: -2 for K-major, -1 otherwise
static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
    if (major == cute::UMMA::Major::K)
        return -2;
    return -1;
}
// Return `dim` if the dimension letter `name` occurs in `compiled_dims`, otherwise 0
static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
    const bool is_compiled = compiled_dims.find(name) != std::string::npos;
    return is_compiled ? dim : 0;
}
// Spell the majorness as the C++ expression emitted into generated kernel code
static std::string to_string(const cute::UMMA::Major& major) {
    if (major == cute::UMMA::Major::K)
        return "cute::UMMA::Major::K";
    if (major == cute::UMMA::Major::MN)
        return "cute::UMMA::Major::MN";
    DG_HOST_UNREACHABLE("Unknown major");
}
// Spell the GEMM type as the C++ expression emitted into generated kernel code
static std::string to_string(const GemmType& type) {
    if (type == GemmType::Normal)
        return "GemmType::Normal";
    if (type == GemmType::MGroupedContiguous)
        return "GemmType::MGroupedContiguous";
    if (type == GemmType::MGroupedMasked)
        return "GemmType::MGroupedMasked";
    if (type == GemmType::KGroupedContiguous)
        return "GemmType::KGroupedContiguous";
    DG_HOST_UNREACHABLE("Unknown GEMM type");
}
// Spell the dtype as the C++ type name emitted into generated kernel code; only int, float and BF16 are supported
static std::string to_string(const at::ScalarType& dtype) {
    if (dtype == torch::kInt)
        return "int";
    if (dtype == torch::kFloat)
        return "float";
    if (dtype == torch::kBFloat16)
        return "cutlass::bfloat16_t";
    DG_HOST_UNREACHABLE("Unsupported dtype");
}
// Map a tensor dtype to the tensor map data type (FP8 E4M3 is described as `UINT8`)
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype) {
    if (dtype == torch::kInt)
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    if (dtype == torch::kFloat)
        return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    if (dtype == torch::kBFloat16)
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    if (dtype == torch::kFloat8_e4m3fn)
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
    DG_HOST_UNREACHABLE("Unsupported dtype");
}
// Map a swizzle mode (0, 16, 32, 64 or 128) to the tensor map swizzle enum
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode) {
    switch (mode) {
        // Neither 0 nor 16 swizzles
        case 0:
        case 16:
            return CU_TENSOR_MAP_SWIZZLE_NONE;
        case 32:
            return CU_TENSOR_MAP_SWIZZLE_32B;
        case 64:
            return CU_TENSOR_MAP_SWIZZLE_64B;
        case 128:
            return CU_TENSOR_MAP_SWIZZLE_128B;
        default:
            DG_HOST_UNREACHABLE("Unsupported swizzling mode");
    }
}
// Encode a 2D tiled tensor map over `t`
//   `gmem_inner_dim`/`gmem_outer_dim`: global extents in elements
//   `smem_inner_dim`/`smem_outer_dim`: box extents in elements
//   `gmem_outer_stride`: stride of the outer dimension in elements (converted to bytes here)
//   `swizzle_mode`: 0/16/32/64/128; any non-zero mode overrides the inner box extent
// CUDA driver errors are reported through `DG_CUDA_DRIVER_CHECK`
static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
                                    int gmem_inner_dim, int gmem_outer_dim,
                                    int smem_inner_dim, int smem_outer_dim,
                                    const int& gmem_outer_stride,
                                    const int& swizzle_mode) {
    const auto& elem_size = static_cast<int>(t.element_size());
    // With a non-zero mode, the inner box extent is `swizzle_mode` bytes worth of elements
    if (swizzle_mode != 0)
        smem_inner_dim = swizzle_mode / elem_size;

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
    const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
    // NOTES: widen before multiplying, so the byte stride cannot overflow `int`
    const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride) * static_cast<cuuint64_t>(elem_size), };
    const cuuint32_t elem_strides[2] = {1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d, elem size: %d\n",
               gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
               gmem_outer_stride, swizzle_mode, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type()),
        2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
// Tensor map for A: the global MN extent is `shape_m * num_groups`, the box is `block_m x block_k`,
// both ordered as inner/outer according to `major`
static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_m, const int& shape_k,
                                   const int& block_m, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode) {
    // Multiple groups are only supported for K-major A
    if (num_groups > 1)
        DG_HOST_ASSERT(major == cute::UMMA::Major::K);

    const auto gmem_dims = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
    const auto smem_dims = get_inner_outer_dims(major, block_k, block_m);
    return make_tma_2d_desc(t,
                            gmem_dims.first, gmem_dims.second,
                            smem_dims.first, smem_dims.second,
                            outer_stride,
                            swizzle_mode);
}
// Tensor map for B: extents and box are ordered as inner/outer according to `major`
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_n, const int& shape_k,
                                   const int& block_n, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode) {
    const auto gmem_dims = get_inner_outer_dims(major, shape_k, shape_n);
    const auto smem_dims = get_inner_outer_dims(major, block_k, block_n);

    // `num_groups` is always applied into the outer dimensions
    const int gmem_outer_dim = gmem_dims.second * num_groups;
    return make_tma_2d_desc(t,
                            gmem_dims.first, gmem_outer_dim,
                            smem_dims.first, smem_dims.second,
                            outer_stride,
                            swizzle_mode);
}
// Tensor map for C/D: N is the inner dimension, `shape_m * num_groups` the outer one
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
                                    const int& shape_m, const int& shape_n,
                                    const int& block_m, const int& block_n,
                                    const int& outer_stride,
                                    const int& num_groups,
                                    const int& swizzle_mode) {
    // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
    // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
    const int gmem_inner_dim = shape_n;
    const int gmem_outer_dim = shape_m * num_groups;
    return make_tma_2d_desc(t,
                            gmem_inner_dim, gmem_outer_dim,
                            block_n, block_m,
                            outer_stride,
                            swizzle_mode);
}
// Tensor map for scaling factors: only MN-major and unswizzled layouts are accepted
// The MN extent is replaced by `get_tma_aligned_size(...)` and also used as the outer stride
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
                                    const torch::Tensor& t,
                                    int shape_mn, int shape_k,
                                    const int& block_mn, const int& block_k,
                                    const int& num_groups,
                                    const int& swizzle_mode) {
    DG_HOST_ASSERT(major == cute::UMMA::Major::MN);

    // TODO: maybe swizzle SF as well
    DG_HOST_ASSERT(swizzle_mode == 0);

    shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));

    // One outer row covers `block_k` along K for float tensors, `block_k * 4` for any other dtype
    const int num_k_per_row = block_k * (t.scalar_type() == torch::kFloat ? 1 : 4);
    const int gmem_outer_dim = ceil_div(shape_k, num_k_per_row) * num_groups;
    return make_tma_2d_desc(t,
                            shape_mn, gmem_outer_dim,
                            block_mn, 1,
                            shape_mn,
                            swizzle_mode);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,351 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime of the SM100 FP8 1D1D GEMM kernel: generates the source instantiating
// `sm100_fp8_gemm_1d1d_impl` and passes the kernel arguments at launch
class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8Gemm1D1DRuntime> {
public:
// Everything needed to generate and launch one kernel
struct Args {
int m, n, k, num_groups;
// NOTES: stored by reference, so the referenced string must outlive this struct
const std::string& compiled_dims;
GemmConfig gemm_config;
LaunchArgs launch_args;
void* grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_c;
CUtensorMap tensor_map_d;
};
// Produce the kernel source: each `{}` placeholder below is filled, in order, by the arguments after the template
// NOTES: a shape dimension is passed as 0 unless its letter appears in `compiled_dims` (see `get_compiled_dim`)
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{}, {},
{}, {},
{}, {},
{},
{}, {}
>);
}};
)",
// Majorness of A and B
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
// Shape M/N/K
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
// Block sizes, number of groups and swizzle modes
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
// Pipeline stages, thread counts and multicast
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
// GEMM type, accumulation flag and output dtype
to_string(args.gemm_config.gemm_type),
args.gemm_config.with_accumulation,
to_string(args.gemm_config.cd_dtype));
}
// Launch `kernel` with `config`; the kernel parameters are passed in exactly the order listed below
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.grouped_layout, args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_c, args.tensor_map_d));
}
};
// Single-group SM100 FP8 GEMM through the 1D1D kernel.
// Steps: choose a config with `get_best_config`, build the TMA descriptors for A/B/D/C and
// both scaling-factor tensors, copy `c` into `d` when an accumulator is supplied in a
// different buffer, then generate, JIT-build and launch the kernel.
static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                const torch::Tensor& b, const torch::Tensor& sfb,
                                const std::optional<torch::Tensor>& c,
                                const torch::Tensor& d,
                                const int& m, const int& n, const int& k,
                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                const std::string& compiled_dims) {
    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D1D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Without an accumulator, the C descriptor is built on `d` itself
    const auto acc = c.value_or(d);

    // TMA descriptors for the operands, the output and the accumulator
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        cfg.smem_config.swizzle_cd_mode);
    const auto desc_c = make_tma_cd_desc(
        acc, m, n,
        SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(acc.stride(-2)), 1,
        cfg.smem_config.swizzle_cd_mode);

    // TMA descriptors for the scaling factors
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);
    const auto desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                           cfg.block_n, cfg.block_k, 1, 0);

    // Duplicate the accumulator if necessary: a distinct `c` is copied into `d`,
    // while an aliasing `c` only has its layout verified against `d`
    if (c.has_value() and c->data_ptr() != d.data_ptr()) {
        // ReSharper disable once CppExpressionWithoutSideEffects
        d.copy_(c.value());
    } else if (c.has_value()) {
        DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
    }

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM100FP8Gemm1D1DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .grouped_layout = nullptr,
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_c = desc_c,
        .tensor_map_d = desc_d
    };

    // Generate, build and launch
    const auto& source = SM100FP8Gemm1D1DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm100_fp8_gemm_1d1d", source);
    SM100FP8Gemm1D1DRuntime::launch(kernel_runtime, kernel_args);
}
// M-grouped (contiguous layout) SM100 FP8 GEMM through the 1D1D kernel.
// B and SFB descriptors are built with `num_groups` groups, while A, D and SFA use a single
// group; `m_indices` is handed to the kernel as the grouped layout pointer.
// No accumulator is supported here: the C descriptor slot is filled with the D descriptor.
static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                     const torch::Tensor& b, const torch::Tensor& sfb,
                                                     const torch::Tensor& d,
                                                     const torch::Tensor& m_indices,
                                                     const int& num_groups, const int& m, const int& n, const int& k,
                                                     const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                     const std::string& compiled_dims) {
    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedContiguous, KernelType::Kernel1D1D,
        m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // TMA descriptors for the operands and the output
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        cfg.smem_config.swizzle_cd_mode);

    // TMA descriptors for the scaling factors
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);
    const auto desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                           cfg.block_n, cfg.block_k, num_groups, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM100FP8Gemm1D1DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_c = desc_d,
        .tensor_map_d = desc_d
    };

    // Generate, build and launch
    const auto& source = SM100FP8Gemm1D1DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", source);
    SM100FP8Gemm1D1DRuntime::launch(kernel_runtime, kernel_args);
}
// M-grouped (masked layout) SM100 FP8 GEMM through the 1D1D kernel.
// The config is chosen from `expected_m` rather than `m`; every descriptor is built with
// `num_groups` groups and `masked_m` is handed to the kernel as the grouped layout pointer.
// No accumulator is supported here: the C descriptor slot is filled with the D descriptor.
static void sm100_fp8_m_grouped_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                 const std::string& compiled_dims) {
    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedMasked, KernelType::Kernel1D1D,
        expected_m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // TMA descriptors for the operands and the output
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), num_groups,
        cfg.smem_config.swizzle_cd_mode);

    // TMA descriptors for the scaling factors
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, num_groups, 0);
    const auto desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                           cfg.block_n, cfg.block_k, num_groups, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM100FP8Gemm1D1DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_c = desc_d,
        .tensor_map_d = desc_d
    };

    // Generate, build and launch
    const auto& source = SM100FP8Gemm1D1DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", source);
    SM100FP8Gemm1D1DRuntime::launch(kernel_runtime, kernel_args);
}
// K-grouped (contiguous layout) SM100 FP8 GEMM through the 1D1D kernel.
//
// `ks` lists the per-group K sizes on the host; `ks_tensor` is handed to the kernel as the
// grouped layout pointer (presumably the device-side copy of `ks` - confirm against the caller).
//
// Requirements enforced by host assertions:
//   - both `major_a` and `major_b` are MN-major;
//   - `ks` is non-empty and every entry is a multiple of 128;
//   - when `c` is given, it aliases `d` (same data pointer, sizes and strides).
static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                    const torch::Tensor& b, const torch::Tensor& sfb,
                                    const std::optional<torch::Tensor>& c,
                                    const torch::Tensor& d,
                                    const int& m, const int& n,
                                    const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                    const std::string& compiled_dims) {
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);

    // `std::ranges::max_element` returns the end iterator for an empty range, and dereferencing
    // it is undefined behavior; reject an empty group list up front
    DG_HOST_ASSERT(not ks.empty());

    // Accumulate the total K and the total scaling-factor extent (counted in chunks of 512 along K)
    int sum_k = 0, sum_sf_k = 0;
    for (const auto& k: ks) {
        sum_k += k, sum_sf_k += ceil_div(k, 512);
        DG_HOST_ASSERT(k % 128 == 0);
    }
    const auto& num_groups = static_cast<int>(ks.size());

    // Get config using max K for better performance
    const auto& max_k = *std::ranges::max_element(ks);
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
        m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Create tensor descriptors; without an accumulator the C descriptor is built on `d`
    const auto& cd = c.value_or(d);
    const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
                                               SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                               config.block_k,
                                               static_cast<int>(a.stride(0)), 1,
                                               config.smem_config.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
                                               SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                               config.block_k,
                                               static_cast<int>(b.stride(0)), 1,
                                               config.smem_config.swizzle_b_mode);
    const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
                                                SM100ArchSpec::get_cd_store_block_m(config.block_m),
                                                SM100ArchSpec::get_cd_store_block_n(config.block_n),
                                                static_cast<int>(d.stride(1)), num_groups,
                                                config.smem_config.swizzle_cd_mode);
    const auto& tensor_map_c = make_tma_cd_desc(cd, m, n,
                                                SM100ArchSpec::get_cd_store_block_m(config.block_m),
                                                SM100ArchSpec::get_cd_store_block_n(config.block_n),
                                                static_cast<int>(cd.stride(1)), num_groups,
                                                config.smem_config.swizzle_cd_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512,
                                                  config.block_m, config.block_k, num_groups, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512,
                                                  config.block_n, config.block_k, num_groups, 0);

    // Unlike the plain GEMM, a separate accumulator buffer is not copied: `c` must alias `d`
    if (c.has_value()) {
        DG_HOST_ASSERT(c->data_ptr() == d.data_ptr());
        DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
    }

    // Launch kernel; note that `.k` carries the sum of all group K sizes
    const SM100FP8Gemm1D1DRuntime::Args& args = {
        .m = m, .n = n, .k = sum_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_c = tensor_map_c,
        .tensor_map_d = tensor_map_d
    };
    const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,242 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the SM100 FP8 1D2D GEMM kernel (`sm100_fp8_gemm_1d2d_impl`).
// `generate_impl` produces the source text that instantiates the kernel template and
// `launch_impl` launches a built kernel; both are static hooks alongside the
// `LaunchRuntime<...>` CRTP base.
class SM100FP8Gemm1D2DRuntime final: public LaunchRuntime<SM100FP8Gemm1D2DRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
// Callers fill this with designated initializers, so the field order is part of the contract.
struct Args {
// Problem shape and group count
int m, n, k, num_groups;
// NOTE: stored as a reference - the referenced string must outlive this struct (and its copies)
const std::string& compiled_dims;
GemmConfig gemm_config;
LaunchArgs launch_args;
// Raw pointers passed straight to the kernel (`grouped_layout` is `nullptr` for the plain GEMM caller)
void *sfb, *grouped_layout;
// TMA descriptors passed by value to the kernel
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
CUtensorMap tensor_map_sfa;
};
// Returns the source code that includes the kernel header and takes the address of one
// `sm100_fp8_gemm_1d2d_impl<...>` instantiation.
// The `{}` placeholders in the template are filled positionally by the arguments after the
// raw string (20 placeholders, 20 arguments); `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm100_fp8_gemm_1d2d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d2d_impl<
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{}, {},
{}, {},
{}, {},
{}, {}
>);
}};
)",
// Operand majors
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
// Shape template arguments, as produced by `get_compiled_dim` from the value, the dimension tag and `compiled_dims`
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
// Block sizes
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
// Swizzle modes for A, B and CD
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
// Stage counts
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
// Thread counts
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
// Multicast settings
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
// GEMM type and output dtype
to_string(args.gemm_config.gemm_type),
to_string(args.gemm_config.cd_dtype));
}
// Launches the built kernel with `cudaLaunchKernelEx`.
// Kernel arguments, in order: `sfb`, `grouped_layout`, the shape (m, n, k), then the tensor
// maps for A, B, D and SFA. The returned CUDA status is checked by `DG_CUDA_RUNTIME_CHECK`.
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.sfb, args.grouped_layout,
args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
}
};
// Single-group SM100 FP8 GEMM through the 1D2D kernel.
// An accumulator is not supported (`c` must be empty). SFA goes through a TMA descriptor,
// whereas SFB is passed to the kernel as a raw data pointer.
static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                const torch::Tensor& b, const torch::Tensor& sfb,
                                const std::optional<torch::Tensor>& c,
                                const torch::Tensor& d,
                                const int& m, const int& n, const int& k,
                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                const std::string& compiled_dims) {
    DG_HOST_ASSERT(not c.has_value());

    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D2D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // TMA descriptors
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        cfg.smem_config.swizzle_cd_mode);
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM100FP8Gemm1D2DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .sfb = sfb.data_ptr(),
        .grouped_layout = nullptr,
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };

    // Generate, build and launch
    const auto& source = SM100FP8Gemm1D2DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm100_fp8_gemm_1d2d", source);
    SM100FP8Gemm1D2DRuntime::launch(kernel_runtime, kernel_args);
}
// M-grouped (contiguous layout) SM100 FP8 GEMM through the 1D2D kernel.
// Only the B descriptor is built with `num_groups` groups; A, D and SFA use a single group.
// `m_indices` is handed to the kernel as the grouped layout pointer and SFB as a raw pointer.
static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                     const torch::Tensor& b, const torch::Tensor& sfb,
                                                     const torch::Tensor& d,
                                                     const torch::Tensor& m_indices,
                                                     const int& num_groups, const int& m, const int& n, const int& k,
                                                     const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                     const std::string& compiled_dims) {
    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
        m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // TMA descriptors
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        cfg.smem_config.swizzle_cd_mode);
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM100FP8Gemm1D2DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .sfb = sfb.data_ptr(),
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };

    // Generate, build and launch
    const auto& source = SM100FP8Gemm1D2DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", source);
    SM100FP8Gemm1D2DRuntime::launch(kernel_runtime, kernel_args);
}
// M-grouped (masked layout) SM100 FP8 GEMM through the 1D2D kernel.
// The config is chosen from `expected_m` rather than `m`; all descriptors are built with
// `num_groups` groups. `masked_m` is handed to the kernel as the grouped layout pointer
// and SFB as a raw pointer.
static void sm100_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                 const torch::Tensor& b, const torch::Tensor& sfb,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& masked_m,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const int& expected_m,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                 const std::string& compiled_dims) {
    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedMasked, KernelType::Kernel1D2D,
        expected_m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // TMA descriptors
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM100ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM100ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), num_groups,
        cfg.smem_config.swizzle_cd_mode);
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, num_groups, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM100FP8Gemm1D2DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .sfb = sfb.data_ptr(),
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };

    // Generate, build and launch
    const auto& source = SM100FP8Gemm1D2DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", source);
    SM100FP8Gemm1D2DRuntime::launch(kernel_runtime, kernel_args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,255 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the SM90 FP8 1D2D GEMM kernel (`sm90_fp8_gemm_1d2d_impl`).
// `generate_impl` produces the source text that instantiates the kernel template and
// `launch_impl` launches a built kernel; both are static hooks alongside the
// `LaunchRuntime<...>` CRTP base.
class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
// Callers fill this with designated initializers, so the field order is part of the contract.
struct Args {
// Problem shape and group count
int m, n, k, num_groups;
// NOTE: stored as a reference - the referenced string must outlive this struct (and its copies)
const std::string& compiled_dims;
GemmConfig gemm_config;
LaunchArgs launch_args;
// Raw pointers passed straight to the kernel (`grouped_layout` is `nullptr` for the plain GEMM caller)
void *sfb, *grouped_layout;
// TMA descriptors passed by value to the kernel
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
CUtensorMap tensor_map_sfa;
};
// Returns the source code that includes the kernel header and takes the address of one
// `sm90_fp8_gemm_1d2d_impl<...>` instantiation.
// The `{}` placeholders in the template are filled positionally by the arguments after the
// raw string (15 placeholders, 15 arguments); `{{` / `}}` are fmt escapes for literal braces.
// Unlike the SM100 variant, neither the operand majors, the A/B swizzle modes nor the CD dtype
// are template arguments here.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{}, {},
{}
>);
}};
)",
// TODO: add CD dtype
// Shape template arguments, as produced by `get_compiled_dim` from the value, the dimension tag and `compiled_dims`
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
// Block sizes
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
// Swizzle mode for CD only
args.gemm_config.smem_config.swizzle_cd_mode,
// Stage counts
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
// Thread counts
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
// Multicast settings
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
to_string(args.gemm_config.gemm_type));
}
// Launches the built kernel with `cudaLaunchKernelEx`.
// Kernel arguments, in order: `sfb`, `grouped_layout`, the shape (m, n, k), then the tensor
// maps for A, B, D and SFA. The returned CUDA status is checked by `DG_CUDA_RUNTIME_CHECK`.
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.sfb, args.grouped_layout,
args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
}
};
// Single-group SM90 FP8 GEMM through the 1D2D kernel.
// Restrictions checked on the host: no accumulator, BF16 output, both operands K-major,
// and a config whose A/B swizzle modes equal `block_k` (no TMA splits).
// SFA goes through a TMA descriptor, whereas SFB is passed as a raw data pointer.
static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                               const torch::Tensor& b, const torch::Tensor& sfb,
                               const std::optional<torch::Tensor>& c,
                               const torch::Tensor& d,
                               const int& m, const int& n, const int& k,
                               const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                               const std::string& compiled_dims) {
    DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM90ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D2D,
        m, n, k, 1, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(cfg.smem_config.swizzle_a_mode == cfg.block_k);
    DG_HOST_ASSERT(cfg.smem_config.swizzle_b_mode == cfg.block_k);

    // TMA descriptors
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM90ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        cfg.smem_config.swizzle_cd_mode);
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM90FP8Gemm1D2DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = 1,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .sfb = sfb.data_ptr(),
        .grouped_layout = nullptr,
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };

    // Generate, build and launch
    const auto& source = SM90FP8Gemm1D2DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm90_fp8_gemm_1d2d", source);
    SM90FP8Gemm1D2DRuntime::launch(kernel_runtime, kernel_args);
}
// M-grouped (contiguous layout) SM90 FP8 GEMM through the 1D2D kernel.
// Restrictions checked on the host: BF16 output, both operands K-major, and a config whose
// A/B swizzle modes equal `block_k` (no TMA splits).
// Only the B descriptor is built with `num_groups` groups; `m_indices` is handed to the kernel
// as the grouped layout pointer and SFB as a raw pointer.
static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                    const torch::Tensor& b, const torch::Tensor& sfb,
                                                    const torch::Tensor& d,
                                                    const torch::Tensor& m_indices,
                                                    const int& num_groups, const int& m, const int& n, const int& k,
                                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                    const std::string& compiled_dims) {
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);
    const auto cfg = get_best_config<SM90ArchSpec>(
        GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
        m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(cfg.smem_config.swizzle_a_mode == cfg.block_k);
    DG_HOST_ASSERT(cfg.smem_config.swizzle_b_mode == cfg.block_k);

    // TMA descriptors
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM90ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), 1,
        cfg.smem_config.swizzle_cd_mode);
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, 1, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM90FP8Gemm1D2DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .sfb = sfb.data_ptr(),
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };

    // Generate, build and launch
    const auto& source = SM90FP8Gemm1D2DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", source);
    SM90FP8Gemm1D2DRuntime::launch(kernel_runtime, kernel_args);
}
// M-grouped (masked layout) SM90 FP8 GEMM through the 1D2D kernel.
// Restrictions checked on the host: BF16 output, both operands K-major, and a config whose
// A/B swizzle modes equal `block_k` (no TMA splits).
// The config is chosen from `expected_m` rather than `m`; all descriptors are built with
// `num_groups` groups. `masked_m` is handed to the kernel as the grouped layout pointer
// and SFB as a raw pointer.
static void sm90_fp8_m_grouped_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                const torch::Tensor& b, const torch::Tensor& sfb,
                                                const torch::Tensor& d,
                                                const torch::Tensor& masked_m,
                                                const int& num_groups, const int& m, const int& n, const int& k,
                                                const int& expected_m,
                                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                const std::string& compiled_dims) {
    // Config selection and descriptors use the raw `k`; only the launch arguments get `align(k, 128)`
    const auto launch_k = align(k, 128);

    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto cfg = get_best_config<SM90ArchSpec>(
        GemmType::MGroupedMasked, KernelType::Kernel1D2D,
        expected_m, n, k, num_groups, major_a, major_b,
        torch::kFloat8_e4m3fn, d.scalar_type(), false,
        device_runtime->get_num_sms());

    // Requires no TMA splits
    DG_HOST_ASSERT(cfg.smem_config.swizzle_a_mode == cfg.block_k);
    DG_HOST_ASSERT(cfg.smem_config.swizzle_b_mode == cfg.block_k);

    // TMA descriptors
    const auto desc_a = make_tma_a_desc(
        major_a, a, m, k,
        SM90ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m),
        cfg.block_k,
        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(
        major_b, b, n, k,
        SM90ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n),
        cfg.block_k,
        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
        cfg.smem_config.swizzle_b_mode);
    const auto desc_d = make_tma_cd_desc(
        d, m, n,
        SM90ArchSpec::get_cd_store_block_m(cfg.block_m),
        SM90ArchSpec::get_cd_store_block_n(cfg.block_n),
        static_cast<int>(d.stride(-2)), num_groups,
        cfg.smem_config.swizzle_cd_mode);
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, num_groups, 0);

    // Assemble the launch arguments
    const auto launch_cfg = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                       cfg.smem_config.smem_size,
                                       cfg.multicast_config.num_multicast);
    const SM90FP8Gemm1D2DRuntime::Args kernel_args = {
        .m = m, .n = n, .k = launch_k,
        .num_groups = num_groups,
        .compiled_dims = compiled_dims,
        .gemm_config = cfg,
        .launch_args = launch_cfg,
        .sfb = sfb.data_ptr(),
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_d = desc_d,
        .tensor_map_sfa = desc_sfa,
    };

    // Generate, build and launch
    const auto& source = SM90FP8Gemm1D2DRuntime::generate(kernel_args);
    const auto& kernel_runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", source);
    SM90FP8Gemm1D2DRuntime::launch(kernel_runtime, kernel_args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,199 @@
#pragma once
#include <torch/python.h>
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../../utils/layout.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `transpose_and_pack_fp32_into_ue8m0` kernel from
// `deep_gemm/impls/smxx_layout.cuh`.
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
public:
// Inputs for code generation and launch.
// Callers fill this with designated initializers, so the field order is part of the contract.
struct Args {
// `sf_k` becomes a template argument; `mn` is passed at launch time
int mn, sf_k;
// Template argument
int block_mn;
// Input and output buffers, passed to the kernel as raw pointers
void *sf, *out;
// `launch_args.num_threads` is also baked in as a template argument
LaunchArgs launch_args;
};
// Returns the source that instantiates the kernel template with, in order:
// the thread count, `block_mn` and `sf_k`. `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&transpose_and_pack_fp32_into_ue8m0<
{}, {}, {}
>);
}};
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
}
// Launches the kernel with arguments (sf, out, mn); `mn` is cast to `uint32_t` before being passed.
// The returned CUDA status is checked by `DG_CUDA_RUNTIME_CHECK`.
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel, args.sf, args.out, static_cast<uint32_t>(args.mn)));
}
};
// JIT runtime wrapper for the `pack_fp32_into_ue8m0` kernel from
// `deep_gemm/impls/smxx_layout.cuh`.
class PackFP32IntoUE8M0Runtime final: public LaunchRuntime<PackFP32IntoUE8M0Runtime> {
public:
// Inputs for code generation and launch.
// Callers fill this with designated initializers, so the field order is part of the contract.
struct Args {
// `num_groups` becomes a template argument; `mn`, `sf_k` and `packed_sf_k` are passed at launch time
int num_groups, mn, sf_k, packed_sf_k;
// Template arguments
int block_mn, block_packed_sf_k;
// Buffers passed to the kernel as raw pointers
void *sf, *out, *ks;
// `launch_args.num_threads` is also baked in as a template argument
LaunchArgs launch_args;
};
// Returns the source that instantiates the kernel template with, in order:
// `num_groups`, the thread count, `block_mn` and `block_packed_sf_k`.
// `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&pack_fp32_into_ue8m0<
{}, {}, {}, {}
>);
}};
)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k);
}
// Launches the kernel with arguments (sf, out, ks, mn, sf_k, packed_sf_k).
// The integers are passed as plain `int` without a cast.
// NOTE(review): confirm the kernel's parameter types match these argument types.
// The returned CUDA status is checked by `DG_CUDA_RUNTIME_CHECK`.
static void launch_impl(const cudaKernel_t& kernel, const cudaLaunchConfig_t& config, Args args) {
DG_CUDA_RUNTIME_CHECK(cudaLaunchKernelEx(&config, kernel,
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k));
}
};
// Validates a scaling-factor tensor and normalizes it to a batched (3D) view.
// Accepts a 2D or 3D FP32 tensor; a 2D input gets a leading dimension of size 1.
// Returns (original rank, num_groups, mn, sf_k, TMA-aligned mn, batched tensor).
static std::tuple<int, int, int, int, int, torch::Tensor> preprocess_sf(const torch::Tensor& sf) {
    // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
    const auto rank = sf.dim();
    DG_HOST_ASSERT(rank == 2 or rank == 3);
    DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat);

    // Promote the 2D case so the rest of the pipeline only deals with 3D tensors
    const torch::Tensor batched = (rank == 2) ? sf.unsqueeze(0) : sf;
    const auto [groups, rows, cols] = get_shape<3>(batched);

    // The aligned extent is derived from `mn` and the element size in bytes
    const auto element_bytes = static_cast<int>(sf.element_size());
    const auto aligned_rows = get_tma_aligned_size(rows, element_bytes);
    return {rank, groups, rows, cols, aligned_rows, batched};
}
// Returns `sf` in an MN-major layout whose MN extent is padded to the TMA-aligned size.
// When the input strides already match that layout it is returned as is (no copy);
// otherwise a new strided buffer is allocated and the data copied into its first `mn` rows.
// The result has the same rank as the input.
static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
    const auto [rank, groups, rows, cols, aligned_rows, batched] = preprocess_sf(sf);

    // The last kernel already gives a column-major TMA aligned layout
    // (the group stride is irrelevant for a 2D input)
    const bool group_stride_matches = batched.stride(0) == aligned_rows * cols or rank == 2;
    const bool inner_strides_match = batched.stride(1) == 1 and batched.stride(2) == aligned_rows;
    if (group_stride_matches and inner_strides_match) {
        if (rank == 2)
            return batched.squeeze(0);
        return batched;
    }

    // Normal layout requires transposing: allocate the padded MN-major buffer,
    // then copy into the slice covering the real `mn` rows
    auto padded = torch::empty_strided({groups, aligned_rows, cols},
                                       {aligned_rows * cols, 1, aligned_rows},
                                       batched.options());
    padded = padded.slice(1, 0, rows).copy_(batched);
    if (rank == 2)
        return padded.squeeze(0);
    return padded;
}
// Packs FP32 scaling factors into an INT32 tensor (4 values along K per element,
// hence `packed_sf_k = ceil_div(sf_k, 4)`) stored MN-major with a TMA-aligned K stride.
// Two JIT kernels are used depending on the input layout:
//  - contiguous input: `transpose_and_pack_fp32_into_ue8m0`;
//  - otherwise: `pack_fp32_into_ue8m0`, which requires a single group and an MN-major input.
// Returns a 2D tensor for 2D inputs, a 3D tensor otherwise.
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto& packed_sf_k = ceil_div(sf_k, 4);
// Logical shape is `[num_groups, mn, packed_sf_k]`, while strides use the padded `tma_aligned_mn`
const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
DG_HOST_ASSERT(num_groups == 1 or (mn * sf_k) % 4 == 0);
// Launch the kernel
if (batched_sf.is_contiguous()) {
constexpr int block_mn = 48;
constexpr int num_threads = 512;
const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = {
.mn = mn,
.sf_k = sf_k,
.block_mn = block_mn,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
// Grid is `(ceil_div(mn, block_mn), num_groups)`
// NOTE(review): the third `LaunchArgs` argument is presumably the dynamic shared memory
// size in bytes (`block_mn * sf_k` FP32 values) — confirm against `LaunchArgs`
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4)
};
const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
} else {
// Non-contiguous inputs must be a single group and already MN-major with a K stride of exactly `mn`
DG_HOST_ASSERT(mn % 4 == 0 and num_groups == 1);
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = 1,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
// No per-group K sizes in this path
.ks = nullptr,
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
}
return (dim == 2) ? out.squeeze(0) : out;
}
// K-grouped variant of the UE8M0 packing: `sf` is a contiguous `[sf_k, mn]` tensor holding the
// scaling factors of all groups concatenated along K, `ks` lists the K size of every group on the
// host and `ks_tensor` is passed to the kernel as its `ks` pointer.
// Returns a new contiguous INT32 tensor of shape `[packed_sf_k, mn]`.
static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf,
const torch::Tensor& ks_tensor,
const std::vector<int>& ks) {
const auto& [sf_k, mn] = get_shape<2>(sf);
const auto& num_groups = static_cast<int>(ks.size());
// Every group contributes `ceil_div(k, 128)` input rows and `ceil_div(k, 512)` packed output rows,
// i.e. each group is padded separately
int ref_sf_k = 0, packed_sf_k = 0;
for (const auto& k: ks)
ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512);
DG_HOST_ASSERT(sf.is_contiguous());
// The input row count must agree with the host-side `ks`
DG_HOST_ASSERT(ref_sf_k == sf_k);
// NOTE(review): the `num_groups <= 128` limit presumably comes from the kernel implementation — confirm
DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0);
const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt));
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = num_groups,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = sf.data_ptr(),
.out = out.data_ptr(),
.ks = ks_tensor.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
return out;
}
} // namespace deep_gemm

402
csrc/python_api.cpp Normal file
View File

@@ -0,0 +1,402 @@
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "jit/compiler.hpp"
#include "jit/device_runtime.hpp"
#include "utils/layout.hpp"
#include "jit_kernels/impls/smxx_layout.hpp"
#include "jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
#include "jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME deep_gemm_cpp
#endif
namespace deep_gemm {
// Converts (or only validates) a scaling-factor tensor into the layout required by the kernels of
// the current architecture.
//  - `mn`, `k`: extents of the value matrix that `sf` scales;
//  - `num_groups`: set for grouped (3D) scaling factors;
//  - `recipe`: `(gran_m, gran_n, gran_k)`; SFA uses the first element, SFB the second;
//  - `disable_ue8m0_cast`: keep FP32 scaling factors instead of packing them into UE8M0.
// The branches below are tried in order; throws `DGException` if none matches.
torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
const int& mn, const int& k,
const std::optional<int>& num_groups,
const std::tuple<int, int, int>& recipe,
const bool& is_sfa,
const bool& disable_ue8m0_cast) {
const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe);
const auto& gran_k = std::get<2>(recipe);
const auto& arch_major = device_runtime->get_arch_major();
// Pre-transform checks
check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);
// (FP32, 1, 128) on SM90, or on any arch with the UE8M0 cast disabled: transform to TMA-aligned and MN-major
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
return get_mn_major_tma_aligned_tensor(sf);
// (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) {
// Always holds here, as the disabled case has been handled by the branch above
DG_HOST_ASSERT(not disable_ue8m0_cast);
return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
}
// (FP32, 128, 128) on SM90, or with the UE8M0 cast disabled: no need to transform, check type and contiguous
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast))
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);
// (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) {
DG_HOST_ASSERT(not disable_ue8m0_cast);
// Expand to per-row granularity first: row `i` takes the scaling factor of block `i / 128`
const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128));
return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
}
// (INT, 1, 128) on SM100: already packed, only check that it is TMA-aligned and MN-major (no data movement)
if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10)
return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);
DG_HOST_UNREACHABLE("Unknown SF transformation");
}
// Converts the scaling factors of a K-grouped GEMM into the layout required by the kernels.
//  - `sf`: 2D scaling-factor tensor;
//  - `ks`: per-group K sizes on the host, `ks_tensor`: the tensor handed to the packing kernel;
//  - `recipe`: must be `(1, 1, 128)`.
// Only FP32 input on SM100 is implemented; every other combination throws `DGException`.
torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
                                                          const std::vector<int>& ks,
                                                          const torch::Tensor& ks_tensor,
                                                          const std::tuple<int, int, int>& recipe) {
    DG_HOST_ASSERT(sf.dim() == 2);
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
    const auto& arch_major = device_runtime->get_arch_major();

    // FP32 on SM90
    if (sf.scalar_type() == torch::kFloat and arch_major == 9)
        DG_HOST_UNREACHABLE("Unimplemented");

    // FP32 on SM100
    if (sf.scalar_type() == torch::kFloat and arch_major == 10)
        return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);

    // INT on SM100
    // NOTES: this used to test `torch::kFloat` again, which made the branch dead code
    if (sf.scalar_type() == torch::kInt and arch_major == 10)
        DG_HOST_UNREACHABLE("Unimplemented");

    DG_HOST_UNREACHABLE("Unknown cases");
}
// FP8 GEMM writing into `d`, with `a` as `[M, K]` and `b` as `[N, K]`.
//  - `a`, `b`: pairs of `(FP8 E4M3 values, scaling factors)`;
//  - `d`: BF16 or FP32 output, written in place;
//  - `c`: optional FP32 tensor, which also requires an FP32 `d`
//    (NOTE(review): presumably the accumulation input — confirm in the kernel implementations);
//  - `recipe`: scaling granularity, deduced from the scaling-factor dtypes when absent;
//  - `compiled_dims`, `disable_ue8m0_cast`: forwarded to the kernels / the SF transformation.
// Throws `DGException` on any failed check or unsupported arch / SF type combination.
void fp8_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::optional<torch::Tensor>& c,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
if (fp8_requires_k_major()) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
}
// C/D must be N-major
check_major_type_cd(d);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a.first);
const auto& [n , k_] = get_shape<2>(b.first);
const auto& [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
// Check C as well
if (c.has_value()) {
check_major_type_cd(c.value());
DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
}
// Do nothing if the problem is empty
// NOTES: this happens after all checks above, so invalid inputs still throw with `m == 0`
if (m == 0)
return;
// Transform SFA and SFB into compute-required layout
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, std::nullopt, recipe.value(), false, disable_ue8m0_cast);
// Dispatch into different implements
// NOTES: the dispatch is keyed on the dtype of the transformed SFA (INT for 1D1D, FP32 for 1D2D)
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
sm100_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
}
}
// NN variant of `fp8_gemm_nt`: `b` (values and scaling factors) is given with its two axes
// swapped relative to what `fp8_gemm_nt` expects, so it is passed on as a transposed view.
void fp8_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    // Transposing only creates views, no data is moved here
    const std::pair<torch::Tensor, torch::Tensor> b_nt{
        b.first.transpose(0, 1),
        b.second.transpose(0, 1)
    };
    fp8_gemm_nt(a, b_nt, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
// TN variant of `fp8_gemm_nt`: both `a` and `b` (values and scaling factors) are passed on
// as transposed views of the given tensors.
void fp8_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    // Transposing only creates views, no data is moved here
    const std::pair<torch::Tensor, torch::Tensor> a_nt{
        a.first.transpose(0, 1),
        a.second.transpose(0, 1)
    };
    const std::pair<torch::Tensor, torch::Tensor> b_nt{
        b.first.transpose(0, 1),
        b.second.transpose(0, 1)
    };
    fp8_gemm_nt(a_nt, b_nt, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
// TT variant of `fp8_gemm_nt`: only `a` (values and scaling factors) is passed on as a
// transposed view, `b` is forwarded untouched.
void fp8_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
                 const std::pair<torch::Tensor, torch::Tensor>& b,
                 const torch::Tensor& d,
                 const std::optional<torch::Tensor>& c,
                 const std::optional<std::tuple<int, int, int>>& recipe,
                 const std::string& compiled_dims,
                 const bool& disable_ue8m0_cast) {
    // Transposing only creates views, no data is moved here
    const std::pair<torch::Tensor, torch::Tensor> a_nt{
        a.first.transpose(0, 1),
        a.second.transpose(0, 1)
    };
    fp8_gemm_nt(a_nt, b, d, c, recipe, compiled_dims, disable_ue8m0_cast);
}
// M-grouped FP8 GEMM with a contiguous layout: `a` is `[M, K]`, `b` is `[G, N, K]` and `d` is a
// BF16 `[M, N]` output written in place.
//  - `a`, `b`: pairs of `(FP8 E4M3 values, scaling factors)`;
//  - `m_indices`: contiguous INT32 tensor with `M` elements
//    (NOTE(review): presumably the group index of every row — confirm in the kernels);
//  - `recipe`: scaling granularity, deduced from the scaling-factor dtypes when absent.
// Throws `DGException` on any failed check or unsupported arch / SF type combination.
void m_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& m_indices,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
// A is always K-major, B only when the architecture requires it
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
if (fp8_requires_k_major())
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(m_indices.is_contiguous());
// Type and shape checks
const auto& [m, k] = get_shape<2>(a.first);
const auto& [num_groups, n, k_] = get_shape<3>(b.first);
const auto& [m_, n_] = get_shape<2>(d);
const auto& m__ = static_cast<int>(m_indices.numel());
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt);
// D must be N-major
check_major_type_cd(d);
// Do nothing if empty
if (m == 0)
return;
// Transform SFA and SFB into compute-required layout
// NOTES: SFA is not grouped (2D), only SFB carries the group dimension
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, std::nullopt, recipe.value(), true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
num_groups, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices,
num_groups, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
sm100_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices,
num_groups, m, n, k, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unknown kernel or scaling factor types");
}
}
// NN variant of `m_grouped_fp8_gemm_nt_contiguous`: the last two axes of `b` (values and
// scaling factors) are swapped through views before forwarding.
void m_grouped_fp8_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const torch::Tensor& m_indices,
                                      const std::optional<std::tuple<int, int, int>>& recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // The group axis (0) stays in place, only axes 1 and 2 are exchanged
    const std::pair<torch::Tensor, torch::Tensor> b_nt{
        b.first.transpose(1, 2),
        b.second.transpose(1, 2)
    };
    m_grouped_fp8_gemm_nt_contiguous(a, b_nt, d, m_indices, recipe, compiled_dims, disable_ue8m0_cast);
}
// M-grouped FP8 GEMM with a masked layout: `a` is `[G, M, K]`, `b` is `[G, N, K]` and `d` is a
// BF16 `[G, M, N]` output written in place.
//  - `a`, `b`: pairs of `(FP8 E4M3 values, scaling factors)`, both must be K-major;
//  - `masked_m`: contiguous INT32 tensor with `G` elements
//    (NOTE(review): presumably the number of valid rows of every group — confirm in the kernels);
//  - `expected_m`: positive hint forwarded to the kernels
//    (NOTE(review): presumably used for config selection only — confirm);
//  - `recipe`: scaling granularity, deduced from the scaling-factor dtypes when absent.
// Unlike the other entries, there is no empty-problem early exit: all sizes must be positive.
void fp8_m_grouped_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& expected_m,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(masked_m.is_contiguous());
// Type and shape checks
const auto& [num_groups, m, k] = get_shape<3>(a.first);
const auto& [num_groups_, n, k_] = get_shape<3>(b.first);
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
const auto& num_groups___ = static_cast<int>(masked_m.numel());
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);
// D must be N-major
check_major_type_cd(d);
// Transform scaling factors
// NOTES: both SFA and SFB carry the group dimension here
if (not recipe.has_value())
recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type());
const auto& sfa = transform_sf_into_required_layout(a.second, m, k, num_groups, recipe.value(), true, disable_ue8m0_cast);
const auto& sfb = transform_sf_into_required_layout(b.second, n, k, num_groups, recipe.value(), false, disable_ue8m0_cast);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
sm90_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_fp8_m_grouped_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kFloat) {
sm100_fp8_m_grouped_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported kernel or scaling factor types");
}
}
// K-grouped FP8 GEMM (1D1D kernel only, SM100 only).
//  - `a`, `b`: pairs of `(values, scaling factors)`; the values must be contiguous 2D tensors
//    whose second dimension is `m` / `n` respectively (both are passed on as MN-major);
//  - `d`: contiguous output, `c`: optional contiguous FP32 tensor;
//  - `ks`: per-group K sizes on the host, `ks_tensor`: the tensor forwarded to the kernels
//    (NOTE(review): presumably the device copy of `ks` — confirm at the call sites);
//  - `recipe`: must be `(1, 1, 128)`.
// Returns without touching `d` when all K sizes sum to zero.
void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::pair<torch::Tensor, torch::Tensor>& b,
const torch::Tensor& d,
const std::vector<int>& ks,
const torch::Tensor& ks_tensor,
const std::optional<torch::Tensor>& c,
const std::tuple<int, int, int>& recipe,
const std::string& compiled_dims) {
// Must be 1D1D kernel
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
// Contiguity checks
DG_HOST_ASSERT(a.first.is_contiguous());
DG_HOST_ASSERT(b.first.is_contiguous());
DG_HOST_ASSERT(d.is_contiguous());
if (c.has_value()) {
DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
DG_HOST_ASSERT(c.value().is_contiguous());
}
// Do nothing if empty
if (std::accumulate(ks.begin(), ks.end(), 0) == 0)
return;
// Transform SF with padding
// NOTES: the first dimension of A/B (the concatenated K) is not needed here, only `m` and `n`
const auto& [_, m] = get_shape<2>(a.first);
const auto& [__, n] = get_shape<2>(b.first);
const auto& sfa = transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
const auto& sfb = transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 10) {
fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
}
} // namespace deep_gemm
// ReSharper disable once CppParameterMayBeConstPtrOrRef
// Python bindings: the module is named by `TORCH_EXTENSION_NAME` (`deep_gemm_cpp` by default).
// `init` must be called from Python before any kernel API, as it creates the JIT compiler.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
using namespace deep_gemm;
m.doc() = "DeepGEMM C++ library";
// Runtime
m.def("get_num_sms", [&]() {
return device_runtime->get_num_sms();
});
m.def("set_num_sms", [&](const int& new_num_sms) {
device_runtime->set_num_sms(new_num_sms);
});
// JIT
// NOTES: only the NVCC compiler is created here, `DG_JIT_USE_NVRTC` must be unset or 0
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_torch) {
DG_HOST_ASSERT(get_env("DG_JIT_USE_NVRTC", 0) == 0 and "Currently only support NVCC");
compiler = std::make_shared<NVCCCompiler>(library_root_path, cuda_home_path_by_torch);
KernelRuntime::set_cuda_home(cuda_home_path_by_torch);
});
// Stable kernel APIs with automatic arch/layout dispatch
// NOTES: the default `compiled_dims` is "nk" for the NT/NN entries and "mn" for the TN/TT/K-grouped ones
m.def("fp8_gemm_nt", &fp8_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_nn", &fp8_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_tn", &fp8_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_gemm_tt", &fp8_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_m_grouped_gemm_nt_masked", &fp8_m_grouped_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
// NOTES: bound without argument names or defaults, i.e. positional-only with all arguments required
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout);
// Raw kernels or functions
m.def("get_tma_aligned_size", &get_tma_aligned_size);
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
}

58
csrc/utils/exception.hpp Normal file
View File

@@ -0,0 +1,58 @@
#pragma once
#include <exception>
#include <string>
namespace deep_gemm {
// Exception type thrown by all `DG_*` check macros.
// `what()` returns `Failed: <name> error <file>:<line> '<error>'`.
class DGException final : public std::exception {
    // Fully formatted description, built once in the constructor
    std::string message = {};

public:
    // `name` is the failure category, `file`/`line` the location and `error` the detail text
    explicit DGException(const char *name, const char* file, const int line, const std::string& error) {
        message.append("Failed: ").append(name);
        message.append(" error ").append(file);
        message.append(":").append(std::to_string(line));
        message.append(" '").append(error).append("'");
    }

    // The returned pointer is valid as long as this exception object is alive
    const char *what() const noexcept override {
        return message.c_str();
    }
};
// Compile-time assertion, simply forwards to `static_assert`
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
// Runtime assertion on the host: throws `DGException` whose detail text is the stringified
// condition, so the spelling of `cond` is part of the error message.
// NOTES: always enabled, there is no `NDEBUG`-style switch in this macro
#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
throw DGException("Assertion", __FILE__, __LINE__, #cond); \
} \
} while (0)
#endif
// Unconditionally throws `DGException` with `reason` as the detail text;
// it is an expression (not a statement block), reported under the "Assertion" category as well
#ifndef DG_HOST_UNREACHABLE
#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason))
#endif
// Evaluates a CUDA driver API call exactly once and throws `DGException` when the result is
// not `CUDA_SUCCESS`; the numeric result code is reported as the detail text (same format
// as `DG_CUDA_RUNTIME_CHECK`), so failures can be looked up in the `CUresult` documentation.
#ifndef DG_CUDA_DRIVER_CHECK
#define DG_CUDA_DRIVER_CHECK(cmd) \
do { \
    const auto& e = (cmd); \
    if (e != CUDA_SUCCESS) { \
        throw DGException("CUDA driver", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
    } \
} while (0)
#endif
// Evaluates a CUDA runtime API call exactly once and throws `DGException` when the result is
// not `cudaSuccess`; the detail text is the numeric error code (not the error string)
#ifndef DG_CUDA_RUNTIME_CHECK
#define DG_CUDA_RUNTIME_CHECK(cmd) \
do { \
const auto& e = (cmd); \
if (e != cudaSuccess) { \
throw DGException("CUDA runtime", __FILE__, __LINE__, std::to_string(static_cast<int>(e))); \
} \
} while (0)
#endif
} // namespace deep_gemm

6
csrc/utils/format.hpp Normal file
View File

@@ -0,0 +1,6 @@
#pragma once
// Just a wrapper for the `fmt` headers
#define FMT_HEADER_ONLY
#include <fmt/base.h>
#include <fmt/format.h>

35
csrc/utils/hash.hpp Normal file
View File

@@ -0,0 +1,35 @@
#pragma once
#include <string>
namespace deep_gemm {
// 64-bit FNV-1a over the bytes of `data`, starting from `seed` instead of the standard
// offset basis: every byte is XOR-ed into the state, which is then multiplied by the FNV prime.
static uint64_t fnv1a(const std::string& data, const uint64_t& seed) {
    constexpr uint64_t kPrime = 0x100000001b3ull;
    uint64_t state = seed;
    for (size_t i = 0; i < data.size(); ++ i) {
        // Go through `uint8_t` so that bytes >= 0x80 are not sign-extended
        state = (state ^ static_cast<uint8_t>(data[i])) * kPrime;
    }
    return state;
}
// Returns a 32-character lowercase hex digest of `data`: two FNV-1a hashes with different
// seeds, each passed through the SplitMix64 finalizer and printed as 16 zero-padded hex digits.
// NOTES: this is not a cryptographic hash
static std::string get_hex_digest(const std::string& data) {
    // Split-mix 64
    const auto& finalize = [](uint64_t x) {
        x ^= x >> 30;
        x *= 0xbf58476d1ce4e5b9ull;
        x ^= x >> 27;
        x *= 0x94d049bb133111ebull;
        return x ^ (x >> 31);
    };
    const uint64_t digest_hi = finalize(fnv1a(data, 0xc6a4a7935bd1e995ull));
    const uint64_t digest_lo = finalize(fnv1a(data, 0x9e3779b97f4a7c15ull));

    // `std::hex` and the fill character are sticky, while the width must be set per value
    std::ostringstream oss;
    oss << std::hex << std::setfill('0');
    oss << std::setw(16) << digest_hi;
    oss << std::setw(16) << digest_lo;
    return oss.str();
}
} // namespace deep_gemm

100
csrc/utils/layout.hpp Normal file
View File

@@ -0,0 +1,100 @@
#pragma once
#include <cute/arch/mma_sm100_umma.hpp>
#include <torch/python.h>
#include "math.hpp"
#include "exception.hpp"
#include "../jit/device_runtime.hpp"
namespace deep_gemm {
// Major-ness stuffs
// Checks that `t` is a 2D or 3D tensor with a recognizable major-ness:
// one of the last two strides must be 1, and for 3D tensors the leading stride must equal the
// product of the last two sizes (i.e. no padding between the matrices of the batch).
// Throws `DGException` otherwise.
static void major_check(const torch::Tensor& t) {
const auto dim = t.dim();
DG_HOST_ASSERT(dim == 2 or dim == 3);
if (dim == 3)
DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1));
DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1);
}
// Returns the major-ness of an A/B operand after validating it with `major_check`:
// K-major when the last dimension has stride 1, MN-major otherwise.
static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) {
    major_check(t);
    // NOTES: a tensor whose last two strides are both 1 is reported as K-major
    if (t.stride(-1) == 1)
        return cute::UMMA::Major::K;
    return cute::UMMA::Major::MN;
}
// Checks that a C/D tensor passes `major_check` and that its last dimension has stride 1;
// throws `DGException` otherwise
static void check_major_type_cd(const torch::Tensor& t) {
// NOTES: the library only supports row-major output layouts
major_check(t);
DG_HOST_ASSERT(t.stride(-1) == 1);
}
// Whether FP8 operands must be K-major on the current device: true only for arch major 9
static bool fp8_requires_k_major() {
    const auto& arch_major = device_runtime->get_arch_major();
    return arch_major == 9;
}
// Tensor utils
// Returns the first `N` sizes of `t` as a tuple of `N` ints (to be used with structured bindings).
// NOTES: sizes are narrowed from 64-bit to `int`, and `t` must have at least `N` dimensions
template <int N>
static auto get_shape(const torch::Tensor& t) {
    const auto& expand = [&t] <size_t... kDims> (std::index_sequence<kDims...>) {
        return std::make_tuple(static_cast<int>(t.sizes()[kDims])...);
    };
    return expand(std::make_index_sequence<N>());
}
// Recipe
// Picks the default scaling recipe from the current architecture and the scaling-factor dtypes:
//  - arch major 9: both SFs must be FP32, returns `(1, 128, 128)`;
//  - arch major 10: FP32 SFB gives `(1, 128, 128)`, INT SFB gives `(1, 1, 128)`
//    (`sfa_dtype` is not inspected in this case);
//  - anything else throws `DGException`.
static std::tuple<int, int, int>
get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) {
    const auto& arch_major = device_runtime->get_arch_major();
    if (arch_major != 9 and arch_major != 10)
        DG_HOST_UNREACHABLE("Unknown recipe");

    if (arch_major == 9) {
        DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
        return std::make_tuple(1, 128, 128);
    }

    // From here on, arch major is 10
    DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt);
    if (sfb_dtype == torch::kFloat)
        return std::make_tuple(1, 128, 128);    // Legacy format or 1D2D kernels
    return std::make_tuple(1, 1, 128);          // 1D1D kernels
}
// SF layouts
// Validates a scaling-factor tensor and returns it unchanged (no copy, no transformation).
//  - `mn`, `k`: extents of the value matrix that `sf` scales;
//  - `gran_mn`, `gran_k`: scaling granularities; an INT tensor packs 4 factors per element along K;
//  - `num_groups`: when set, `sf` must be 3D with that leading size, otherwise 2D;
//  - `tma_stride_check`: additionally require an MN-major layout with a TMA-aligned K stride;
//  - `contiguous_check`: additionally require a contiguous tensor;
//  - `type_check`: additionally require this exact dtype.
// Throws `DGException` on the first failed check.
static torch::Tensor check_sf_layout(const torch::Tensor& sf,
const int& mn, const int& k,
const int& gran_mn, const int& gran_k,
const std::optional<int>& num_groups,
const bool& tma_stride_check = false,
const bool& contiguous_check = false,
const std::optional<torch::ScalarType>& type_check = std::nullopt) {
// Type check
if (type_check.has_value())
DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
// Always do shape checks
const auto& sf_dtype = sf.scalar_type();
DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
if (num_groups.has_value())
DG_HOST_ASSERT(sf.size(-3) == num_groups.value());
DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn));
DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)));
// TMA stride checks: TMA aligned and MN-major
if (tma_stride_check) {
// Groups must be densely packed one after another
if (num_groups.has_value())
DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1));
DG_HOST_ASSERT(sf.stride(-2) == 1);
DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()));
}
// Hopper SFB must be contiguous
if (contiguous_check)
DG_HOST_ASSERT(sf.is_contiguous());
return sf;
}
// Value matrix layout
// Returns the alignment value used for the contiguous (grouped) layout; currently a fixed constant
static int get_mk_alignment_for_contiguous_layout() {
    constexpr int kAlignment = 128;
    return kAlignment;
}
} // namespace deep_gemm

25
csrc/utils/math.hpp Normal file
View File

@@ -0,0 +1,25 @@
#pragma once
#include <torch/python.h>
#include "exception.hpp"
namespace deep_gemm {
// Ceiling division `ceil(a / b)`, meant for non-negative `a` and positive `b`.
// NOTES: `constexpr`, so that `align` below can really be evaluated at compile time
template <typename T>
static constexpr T ceil_div(const T& a, const T& b) {
    return (a + b - 1) / b;
}

// Rounds `a` up to the nearest multiple of `b` (same preconditions as `ceil_div`)
template <typename T>
static constexpr T align(const T& a, const T& b) {
    return ceil_div(a, b) * b;
}
// Rounds an element count `x` up, such that `x * element_size` bytes becomes a multiple of the
// 16-byte TMA alignment. `element_size` must be a positive divisor of 16, otherwise
// `DGException` is thrown.
static int get_tma_aligned_size(const int& x, const int& element_size) {
    constexpr int kNumTMAAlignmentBytes = 16;
    // Reject non-positive sizes first: zero would be an integer division by zero below
    // (this function is callable from Python), negative values would give a meaningless result
    DG_HOST_ASSERT(element_size > 0);
    DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0);
    return align(x, kNumTMAAlignmentBytes / element_size);
}
} // namespace deep_gemm

70
csrc/utils/system.hpp Normal file
View File

@@ -0,0 +1,70 @@
#pragma once
#include <array>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <iomanip>
#include <memory>
#include <random>
#include <sstream>
#include <string>
#include <system_error>
#include <tuple>
#include <type_traits>
#include <sys/wait.h>
#include <unistd.h>
#include "exception.hpp"
namespace deep_gemm {
// ReSharper disable once CppNotAllPathsReturnValue
// Reads the environment variable `name` and converts it into `dtype_t` (`std::string` or `int`).
// Returns `default_value` when the variable is unset, or when an `int` is requested but the
// value does not start with a parsable integer. Any other `dtype_t` throws `DGException`
// at runtime, but only when the variable is set.
template <typename dtype_t>
static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) {
    const auto& c_str = std::getenv(name.c_str());
    if (c_str == nullptr)
        return default_value;

    // Read the env and convert to the desired type
    if constexpr (std::is_same_v<dtype_t, std::string>) {
        return std::string(c_str);
    } else if constexpr (std::is_same_v<dtype_t, int>) {
        // NOTES: `sscanf` leaves `value` untouched on a matching failure (e.g. an empty or
        // non-numeric string), so its return value must be checked before using the result
        int value = 0;
        if (std::sscanf(c_str, "%d", &value) != 1)
            return default_value;
        return value;
    } else {
        DG_HOST_ASSERT(false and "Unexpected type");
    }
}
// Runs `command` through `popen` with stderr merged into stdout, and returns
// `(exit code, captured output)`. Throws `DGException` if the pipe cannot be opened.
static std::tuple<int, std::string> call_external_command(std::string command) {
    command += " 2>&1";

    // The deleter only matters on early exits (e.g. exceptions), the normal path closes manually
    const auto& deleter = [](FILE* f) { if (f) pclose(f); };
    std::unique_ptr<FILE, decltype(deleter)> pipe(popen(command.c_str(), "r"), deleter);
    DG_HOST_ASSERT(pipe != nullptr);

    // Drain the pipe chunk by chunk
    std::string output;
    std::array<char, 512> chunk;
    while (fgets(chunk.data(), chunk.size(), pipe.get()) != nullptr)
        output.append(chunk.data());

    // Release the ownership first, so that the stream is closed exactly once
    const int status = pclose(pipe.release());
    const auto& exit_code = WEXITSTATUS(status);
    return {exit_code, output};
}
// Creates `path` with all missing parents and returns it; an already existing directory is fine.
// Throws `DGException` when the creation fails with an error code.
static std::filesystem::path make_dirs(const std::filesystem::path& path) {
    // OK if existed: nothing is created and no error is reported in that case
    std::error_code capture;
    const bool& created = std::filesystem::create_directories(path, capture);
    DG_HOST_ASSERT(created or capture.value() == 0);

    // Only log directories which are newly created, and only in JIT debug mode
    if (created) {
        const bool debug_enabled = get_env<int>("DG_JIT_DEBUG");
        if (debug_enabled)
            printf("Create directory: %s\n", path.c_str());
    }
    return path;
}
// Returns a process-unique identifier `<pid>-<8 hex>-<8 hex>-<8 hex>`: the decimal PID followed
// by three random 32-bit values printed as zero-padded lowercase hex.
static std::string get_uuid() {
    // The generator is seeded once (first call) from the random device mixed with a clock reading
    static std::random_device rd;
    static std::mt19937 gen([]() {
        const auto ticks = std::chrono::steady_clock::now().time_since_epoch().count();
        return rd() ^ ticks;
    }());
    static std::uniform_int_distribution<uint32_t> dist;

    std::stringstream ss;
    // NOTES: the PID must be written before switching the stream into hex mode
    ss << getpid();
    ss << std::hex << std::setfill('0');
    for (int i = 0; i < 3; ++ i)
        ss << "-" << std::setw(8) << dist(gen);
    return ss.str();
}
} // deep_gemm

View File

@@ -1,15 +1,41 @@
import os
import torch
import torch.utils.cpp_extension
from . import jit
from .jit_kernels import (
gemm_fp8_fp8_bf16_nt,
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked,
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt,
ceil_div,
set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,
get_m_alignment_for_contiguous_layout
# Apply the default environment variables recorded at setup time
# NOTES: variables which are already present in `os.environ` are never overridden
try:
# noinspection PyUnresolvedReferences
from .envs import persistent_envs
for key, value in persistent_envs.items():
if key not in os.environ:
os.environ[key] = value
except ImportError:
# The `envs` module is optional, simply skip the defaults if it can not be imported
pass
# Import functions from the CPP module
import deep_gemm_cpp
# Initialize the C++ side (JIT compiler) before any other function of the module is used
deep_gemm_cpp.init(
os.path.dirname(os.path.abspath(__file__)), # Library root directory path
torch.utils.cpp_extension.CUDA_HOME # CUDA home
)
from .utils import bench, bench_kineto, calc_diff
# Configs
from deep_gemm_cpp import (
set_num_sms,
get_num_sms
)
# Kernels
from deep_gemm_cpp import (
fp8_gemm_nt, fp8_gemm_nn,
fp8_gemm_tn, fp8_gemm_tt,
m_grouped_fp8_gemm_nt_contiguous,
m_grouped_fp8_gemm_nn_contiguous,
fp8_m_grouped_gemm_nt_masked,
k_grouped_fp8_gemm_tn_contiguous
)
# Some utils
from . import testing
from . import utils
from .utils import *

View File

@@ -0,0 +1,213 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm {
// Index kind selector, used as a template parameter of `Scheduler::get_global_idx`
// NOTE(review): judging by the names, the enumerators select an index along MN, along K, or
// along the scaling-factor K dimension — confirm against `get_global_idx`
enum class KGroupedIndexType {
MN,
K,
SF_K,
};
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups,
uint32_t kNumMulticast, bool kIsMulticastOnA,
// TODO: refactor this by other values
uint32_t kNum1DBlocksPerGroup = 16>
struct Scheduler {
int current_iter = -1;
// Block configs
uint32_t num_blocks;
uint32_t num_m_blocks;
uint32_t num_n_blocks;
// For SM90 multicast checks
uint32_t num_blocks_in_group;
bool is_peer_cta_alive = true;
// For grouped GEMM
int* grouped_layout;
uint32_t current_group_idx;
// Only used for masked layout
uint32_t current_m_cumsum;
// Only used for k-grouped layout
uint32_t current_shape_k, current_num_valid_groups, current_k_cumsum, current_sf_k_cumsum;
// ReSharper disable once CppPossiblyUninitializedMember
// Initializes the block counts and the per-GEMM-type state.
//  - `shape_m`, `shape_n`: problem extents, split into `BLOCK_M` x `BLOCK_N` blocks;
//  - `grouped_layout`: device pointer for the grouped GEMM types, unused for normal GEMMs.
// NOTES: only the members needed by `kGemmType` are initialized (e.g. `num_blocks` is not set
// for the masked and k-grouped types), all the others are intentionally left untouched
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n,
                                              int* grouped_layout = nullptr) {
    num_m_blocks = ceil_div(shape_m, BLOCK_M);
    num_n_blocks = ceil_div(shape_n, BLOCK_N);

    // NOTES: `kGemmType` is a template parameter, so every branch is resolved at compile time
    if constexpr (kGemmType == GemmType::Normal) {
        num_blocks = num_m_blocks * num_n_blocks;
    } else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
        num_blocks = num_m_blocks * num_n_blocks;
        this->grouped_layout = grouped_layout;
    } else if constexpr (kGemmType == GemmType::MGroupedMasked) {
        current_group_idx = current_m_cumsum = 0;
        this->grouped_layout = grouped_layout;
    } else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
        current_group_idx = current_num_valid_groups = 0;
        current_k_cumsum = current_sf_k_cumsum = 0;
        // Preload the K size of the first group (the parameter is read, not the member)
        current_shape_k = __ldg(grouped_layout + current_group_idx);
        this->grouped_layout = grouped_layout;
    }
}
// Maps a linear `block_idx` into `(m_block_idx, n_block_idx)`.
// Blocks are visited group by group: a group spans up to `kNum1DBlocksPerGroup` blocks along the
// primary dimension (N if multicasting on A, M otherwise) and all blocks along the secondary one.
// Also updates the member `num_blocks_in_group` as a side effect.
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
const auto& group_idx = block_idx / num_blocks_per_group;
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
// The last group may hold fewer blocks than `kNum1DBlocksPerGroup`
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
// Fix unaligned TMA multicast
// NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
// while SM100 uses 2-CTA, which can not be dynamically disabled
#if __CUDA_ARCH__ < 1000
if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
// `num_blocks_in_group` is odd here, so `^ 1` clears the lowest bit (i.e. subtracts 1):
// the group is split into an even-sized part and a trailing part with a single block
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
num_blocks_in_group = num_blocks_in_group ^ 1;
} else {
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
first_block_idx += num_blocks_in_group ^ 1;
num_blocks_in_group = 1;
}
}
#endif
// Convert to final M/N block indices
// NOTES: inside a group, the index along the primary dimension changes fastest
if constexpr (kIsMulticastOnA) {
m_block_idx = in_group_idx / num_blocks_in_group;
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
} else {
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
n_block_idx = in_group_idx / num_blocks_in_group;
}
}
template <bool kWithGroupOffset, KGroupedIndexType kIndexType = KGroupedIndexType::MN>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
const auto offset = kWithGroupOffset ? std::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
const auto offset = kWithGroupOffset ? current_group_idx : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
auto offset = 0;
if constexpr (kWithGroupOffset) {
if constexpr (kIndexType == KGroupedIndexType::MN)
offset = current_group_idx * shape_dim;
else if constexpr (kIndexType == KGroupedIndexType::K)
offset = current_k_cumsum;
else if constexpr (kIndexType == KGroupedIndexType::SF_K)
offset = current_sf_k_cumsum;
}
return offset + block_idx * block_size;
}
}
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
if constexpr (kGemmType == GemmType::MGroupedMasked) {
while (true) {
// End of the task
if (current_group_idx == kNumGroups)
return false;
// Within current group
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + current_group_idx)), BLOCK_M);
const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * num_n_blocks)
break;
// Move to check the next group
current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
}
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
} else if (kGemmType == GemmType::KGroupedContiguous) {
while (true) {
// End of the task
if (current_group_idx == kNumGroups)
return false;
// Within current group
if (current_shape_k > 0 and next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
break;
// Move to check the next group
if (current_shape_k > 0) {
current_k_cumsum += current_shape_k;
current_sf_k_cumsum += ceil_div(current_shape_k, 512u);
current_num_valid_groups ++;
}
if ((++ current_group_idx) != kNumGroups)
current_shape_k = __ldg(grouped_layout + current_group_idx);
}
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
} else {
if (next_block_idx >= num_blocks)
return false;
// For SM90 only
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
is_peer_cta_alive = kNum1DBlocksPerGroup % kNumMulticast == 0 or // Always aligned on N (constant bypass)
num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass)
(next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
// For SM90 only
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked) {
return true;
} else {
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
if constexpr (kIsMulticastOnA) {
return true;
} else {
const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
return group_idx == peer_group_idx;
}
}
}
// For SM90 only
// ReSharper disable once CppNotAllPathsReturnValue
__device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
if constexpr (kGemmType == GemmType::Normal) {
return true;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
}
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@@ -0,0 +1,169 @@
#pragma once
#include <cute/atom/mma_traits_sm100.hpp>
#include <cute/arch/mma_sm100_umma.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm::sm100 {
// Number of `dtype_t` elements covered by one atom along the inner dimension:
// the whole `BLOCK_INNER` without swizzling (`kSwizzleMode == 0`),
// otherwise as many elements as fit into `kSwizzleMode` bytes
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if constexpr (kSwizzleMode == 0) {
        return BLOCK_INNER;
    } else {
        return kSwizzleMode / sizeof(dtype_t);
    }
}
// Loads a `BLOCK_INNER` x `BLOCK_OUTER` tile with one 2D TMA copy per inner atom
// (`BLOCK_INNER / BLOCK_INNER_ATOM` copies in total), all signalling `barrier_ptr`.
//  - `desc_ptr`: TMA descriptor passed through to the CuTe copy function
//  - `smem_ptr`: destination, atom `i` is written at `smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM`
//  - `inner_idx`, `outer_idx`: tile coordinates, atom `i` reads at `inner_idx + i * BLOCK_INNER_ATOM`
//  - `kNumMulticast`: 1 selects the plain load, 2 selects the 2-SM load
template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
uint32_t kSwizzleMode, uint32_t kNumMulticast,
typename dtype_t>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
dtype_t* smem_ptr, const uint32_t& inner_idx, const int32_t& outer_idx) {
DG_STATIC_ASSERT(1 <= kNumMulticast and kNumMulticast <= 2, "Invalid multicast config");
// The SM100 cache hint value is passed to both copy functions, so both enums must agree
DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
// 2-CTA function will send signals to the leader CTA only
const auto copy_func = kNumMulticast == 1 ? cute::SM90_TMA_LOAD_2D::copy : cute::SM100_TMA_2SM_LOAD_2D::copy;
// Issue multiple TMAs
constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
copy_func(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
}
// Builds a UMMA shared memory descriptor for `smem_ptr` with the given layout type.
// The start address and both byte offsets are stored shifted right by 4 bits.
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr,
                                          uint32_t stride_byte_offset, uint32_t leading_byte_offset) {
    const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);

    cute::UMMA::SmemDescriptor smem_desc;

    // Fixed fields: SM100 descriptor version, legacy LBO mode, no base offset
    smem_desc.version_ = 1;
    smem_desc.lbo_mode_ = 0;
    smem_desc.base_offset_ = 0;

    // Layout and start address
    smem_desc.layout_type_ = static_cast<uint8_t>(layout);
    smem_desc.start_address_ = static_cast<uint16_t>(smem_addr >> 4);

    // SBO and LBO
    smem_desc.stride_byte_offset_ = stride_byte_offset >> 4;
    smem_desc.leading_byte_offset_ = leading_byte_offset >> 4;
    return smem_desc;
}
// Builds the non-swizzled shared memory descriptor used for the scaling factors at `smem_ptr`
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
    // NOTES: the UTCCP layout is K-major by default
    // Atom size: 8 x 128 bits
    // {SBO, LBO} means the byte stride between atoms on {MN, K}
    constexpr uint32_t kStrideByteOffset = 8 * 16;
    // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero
    constexpr uint32_t kLeadingByteOffset = 0;
    return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr,
                          kStrideByteOffset, kLeadingByteOffset);
}
// Overwrites only the start address field of `desc` so that it points to `smem_ptr`
__device__ __forceinline__
void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) {
    const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    // Same encoding as in `make_smem_desc`: the address is stored shifted right by 4 bits
    const auto encoded_addr = static_cast<uint16_t>(smem_addr >> 4);
    desc.start_address_ = encoded_addr;
}
// ReSharper disable once CppNotAllPathsReturnValue
// Maps a swizzle mode (in bytes) to the UMMA layout type.
// Both 0 and 16 map to `SWIZZLE_NONE`; any other value than 32/64/128 is rejected at compile time.
template <uint32_t kSwizzleMode>
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
    DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
                     kSwizzleMode == 32 or kSwizzleMode == 64 or
                     kSwizzleMode == 128, "Invalid swizzling mode");
    if constexpr (kSwizzleMode == 0 or kSwizzleMode == 16) {
        return cute::UMMA::LayoutType::SWIZZLE_NONE;
    } else if constexpr (kSwizzleMode == 32) {
        return cute::UMMA::LayoutType::SWIZZLE_32B;
    } else if constexpr (kSwizzleMode == 64) {
        return cute::UMMA::LayoutType::SWIZZLE_64B;
    } else {
        // Only 128 is left, as guaranteed by the assertion above
        return cute::UMMA::LayoutType::SWIZZLE_128B;
    }
}
// Element stride of one step along K: 1 for K-major, otherwise the inner atom size on MN
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
constexpr uint32_t get_umma_desc_stride_k() {
    if constexpr (kMajorMode == cute::UMMA::Major::K) {
        return 1;
    } else {
        return get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
    }
}
// Returns `base` advanced by `offset` plus `k_idx` K-steps, with the sum shifted right by
// 4 bits before it is added.
// NOTE(review): `base` presumably is the low 32 bits of a shared memory descriptor - confirm with callers
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) {
    const uint32_t k_stride = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
    const uint32_t advanced_offset = offset + k_idx * k_stride;
    return base + (advanced_offset >> 4u);
}
// Builds the UMMA shared memory descriptor of the operand tile stored at `base_smem_ptr`,
// starting at element `mn_idx * BLOCK_K + k_idx * stride_k`.
// The stride/leading byte offsets depend on the major mode and on `kSwizzleMode` (see below).
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
// {SBO, LBO} means the byte stride between atoms on {MN, K}
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
const uint32_t stride_byte_offset = 8 * BLOCK_K * sizeof(dtype_t);
const uint32_t leading_byte_offset = 0;
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
} else {
constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
// Must have no in-atom MN-idx
// NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
// Atom size: `kSwizzleMode` (in bytes, on MN) x 8
// NOTES: `kSwizzleMode == 16` means non-swizzling but interleaving
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
uint32_t stride_byte_offset = 8 * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
// The roles of SBO and LBO are exchanged for the non-swizzling case (see the notes above)
if constexpr (kSwizzleMode == 16)
swap(stride_byte_offset, leading_byte_offset);
return make_smem_desc(to_umma_layout_type<kSwizzleMode>(),
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
}
}
// Sets both scaling factor IDs of `desc` (taken by value) to `sf_id` and returns the
// 32-bit descriptor placed in the upper half of a 64-bit value
__device__ __forceinline__
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) {
    desc.a_sf_id_ = sf_id;
    desc.b_sf_id_ = sf_id;
    const auto desc_bits = static_cast<uint32_t>(desc);
    return static_cast<uint64_t>(desc_bits) << 32;
}
// Rounds `kNumCols` up to the next value of {32, 64, 128, 256, 512};
// more than 512 columns are rejected at compile time
template <uint32_t kNumCols>
__device__ constexpr uint32_t get_num_aligned_tmem_cols() {
    DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
    if constexpr (kNumCols <= 32) {
        return 32;
    } else if constexpr (kNumCols <= 64) {
        return 64;
    } else if constexpr (kNumCols <= 128) {
        return 128;
    } else if constexpr (kNumCols <= 256) {
        return 256;
    } else {
        return 512;
    }
}
// Emits the PTX instruction `tcgen05.fence::before_thread_sync`
// NOTE(review): presumably to be issued before a thread synchronization that follows tcgen05 operations - see the PTX ISA
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
// Emits the PTX instruction `tcgen05.fence::after_thread_sync`
// NOTE(review): presumably to be issued after a thread synchronization and before further tcgen05 operations - see the PTX ISA
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
} // namespace `deep_gemm::sm100`

View File

@@ -1,149 +1,14 @@
#pragma once
#ifndef __CUDACC_RTC__
#include <cuda.h>
#endif
#include <cstdint>
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include "utils.cuh"
namespace deep_gemm {
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
}
};
template <typename dtype_t>
struct SM90_U32x4_STSM_N {
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n"
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
}
};
__forceinline__ __device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
__forceinline__ __device__ void warpgroup_commit_batch() {
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
asm volatile("" : "+f"(reg) :: "memory");
}
// Returns the value of the PTX special register `%laneid` for the calling thread
__forceinline__ __device__ uint32_t get_lane_id() {
    uint32_t lane_id;
    // NOTES: this asm statement has operands, so a literal `%` must be escaped as `%%`
    asm("mov.u32 %0, %%laneid;" : "=r"(lane_id));
    return lane_id;
}
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* __restrict__ ptr) {
uint32_t ret;
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ int4 ld_shared(const int4* __restrict__ ptr) {
int4 ret;
asm volatile("ld.shared.v4.s32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float ld_shared(const float* __restrict__ ptr) {
float ret;
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(ptr));
return ret;
}
__device__ __forceinline__ float2 ld_shared(const float2* __restrict__ ptr) {
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(ptr), "f"(val.x), "f"(val.y));
}
template <int N>
__device__ void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
union GmmaDescriptor {
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
uint64_t desc_;
uint32_t reg32_[2];
uint16_t reg16_[4];
struct {
uint16_t start_address_: 14, : 2;
uint16_t leading_byte_offset_: 14, : 2;
uint16_t stride_byte_offset_: 14, : 2;
uint8_t : 1, base_offset_: 3, : 4;
uint8_t : 6, layout_type_: 2;
} bitfield;
// Decay to an `uint64_t`
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, int layout_type,
int leading_byte_offset = 0,
int stride_byte_offset = 1024) {
GmmaDescriptor desc;
auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
desc.bitfield.start_address_ = uint_ptr >> 4;
desc.bitfield.layout_type_ = layout_type;
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
desc.bitfield.base_offset_ = 0;
return desc;
}
namespace deep_gemm::sm90 {
template <int N_, typename MMA>
struct FP8MMA {
template <size_t ...Idx>
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, std::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
@@ -194,19 +59,93 @@ struct FP8MMASelector {
using type = decltype(select_type());
};
enum class Layout {
RowMajor,
ColMajor
// Wrapper of the PTX `stmatrix.sync.aligned.x2.m8n8.shared.b16` instruction
// NOTE(review): `dtype_t` is reinterpreted as `uint32_t`, so it presumably must be a 32-bit type - confirm
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
// Stores the raw 32 bits of `src_0` and `src_1` through the address `smem_dst`
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "l"(smem_dst), "r"(src[0]), "r"(src[1]));
}
};
__device__ __host__ constexpr int get_num_math_warpgroups(int block_m) {
return block_m == 64 ? 1 : 2;
// Emits the PTX instruction `wgmma.fence.sync.aligned`; the "memory" clobber also keeps the
// compiler from reordering memory accesses across it
__forceinline__ __device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
template <uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup>
__device__ __host__ constexpr int get_num_threads_per_sm(int block_m) {
DG_STATIC_ASSERT(kNumMathThreadsPerGroup == 128, "Only support 128 threads per math group");
return get_num_math_warpgroups(block_m) * kNumMathThreadsPerGroup + kNumTMAThreads;
// Emits the PTX instruction `wgmma.commit_group.sync.aligned` (with a "memory" clobber)
// NOTE(review): presumably closes the group of WGMMA operations later awaited by `warpgroup_wait` - see the PTX ISA
__forceinline__ __device__ void warpgroup_commit_batch() {
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
} // namespace deep_gemm
// Compiler-only fence: the empty asm statement takes `reg` as a read-write operand and
// clobbers memory, so no instruction is emitted but the compiler must treat `reg` as modified here
__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
asm volatile("" : "+f"(reg) :: "memory");
}
// Emits the PTX instruction `wgmma.wait_group.sync.aligned N`, where `N` is an immediate in [0, 7]
// NOTE(review): presumably waits until at most `N` committed WGMMA groups are pending - see the PTX ISA
template <int N>
__forceinline__ __device__ void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
// TODO: replace with CUTLASS solution
// 64-bit shared memory descriptor for WGMMA, accessible as a whole (`desc_`), as 32/16-bit
// words or through the named bit fields
union GmmaDescriptor {
// Constructors and assignments only copy the raw 64-bit value
__host__ __device__ constexpr GmmaDescriptor() noexcept: desc_(0) {}
__host__ __device__ constexpr GmmaDescriptor(uint64_t desc) noexcept: desc_(desc) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor const &t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor(GmmaDescriptor &&t) noexcept: desc_(t.desc_) {}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor const &t) noexcept {
desc_ = t.desc_;
return *this;
}
__host__ __device__ constexpr GmmaDescriptor &operator=(GmmaDescriptor &&t) noexcept {
desc_ = t.desc_;
return *this;
}
// Raw views of the descriptor
uint64_t desc_;
uint32_t reg32_[2];
uint16_t reg16_[4];
// Field view: three 14-bit fields (each followed by 2 unnamed bits), then a 3-bit base offset
// and a 2-bit layout type surrounded by unnamed bits
struct {
uint16_t start_address_: 14, : 2;
uint16_t leading_byte_offset_: 14, : 2;
uint16_t stride_byte_offset_: 14, : 2;
uint8_t : 1, base_offset_: 3, : 4;
uint8_t : 6, layout_type_: 2;
} bitfield;
// Decay to an `uint64_t`
__host__ __device__ constexpr operator uint64_t() const noexcept { return desc_; }
};
// Builds a `GmmaDescriptor` for the shared memory at `smem_ptr`.
// The address and both byte offsets are stored shifted right by 4 bits; the base offset is 0.
template <class PointerType>
__device__ GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type,
                                         const int& leading_byte_offset = 0,
                                         const int& stride_byte_offset = 1024) {
    // Convert the generic pointer into a shared memory address
    const uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));

    GmmaDescriptor smem_desc;
    smem_desc.bitfield.base_offset_ = 0;
    smem_desc.bitfield.layout_type_ = layout_type;
    smem_desc.bitfield.start_address_ = smem_addr >> 4;
    smem_desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
    smem_desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
    return smem_desc;
}
// Issues a 2D TMA load at coordinates (`crd_0`, `crd_1`) into `smem_ptr`, signalling `barrier_ptr`.
// With `num_tma_multicast == 1` every caller issues a plain load; otherwise only the CTA with
// rank 0 in the cluster issues a multicast load and the other CTAs do nothing.
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
         const uint32_t& crd_0, const uint32_t& crd_1, const uint32_t& num_tma_multicast) {
    constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);

    // No multicast
    if (num_tma_multicast == 1) {
        cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
        return;
    }

    // Multicast is issued by the first CTA of the cluster only
    if (cute::block_rank_in_cluster() != 0)
        return;

    // One mask bit for each of the lowest `num_tma_multicast` CTA ranks
    const auto cta_mask = (1 << num_tma_multicast) - 1;
    cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, cta_mask, cache_hint, smem_ptr, crd_0, crd_1);
}
} // namespace `deep_gemm::sm90`

View File

@@ -0,0 +1,17 @@
#pragma once
namespace deep_gemm {
// GEMM layout kinds; the tile scheduler selects its strategy by this value
enum class GemmType {
// Single (non-grouped) GEMM
Normal = 0,
// Grouped on M, with the groups laid out contiguously
MGroupedContiguous = 1,
// Grouped on M, with a per-group count of valid rows
MGroupedMasked = 2,
// Grouped on K, with the groups laid out contiguously
KGroupedContiguous = 3,
};
// Kernel variants
// NOTE(review): "1D1D" / "1D2D" presumably name the scaling factor granularity of the two operands - confirm
enum class KernelType {
Kernel1D1D = 0,
Kernel1D2D = 1,
};
} // namespace deep_gemm

View File

@@ -0,0 +1,138 @@
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#ifdef __CLION_IDE__
// IDE-only stand-in for `printf` (compiled only if `__CLION_IDE__` is defined):
// it ignores its arguments and only contains a `trap` instruction
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
asm volatile("trap;");
}
#define printf host_device_printf
#endif
// Device-side assertion: if `cond` is false, prints the file, line and condition text,
// then executes the PTX `trap` instruction. Can be overridden by defining the macro beforehand.
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
// Device-side assertion without any message: executes the PTX `trap` instruction if `cond` is false.
// Can be overridden by defining the macro beforehand.
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
// Compile-time assertion, forwards to `static_assert`. Can be overridden by defining the macro beforehand.
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
namespace deep_gemm {
// Array-like view of a callable: `visitor[i]` returns `func(i)`
template <typename FuncT>
struct PatternVisitor {
    FuncT func;

    // Stores the (perfectly forwarded) callable
    __device__ __host__
    explicit PatternVisitor(FuncT&& f): func(std::forward<FuncT>(f)) {}

    // Calls the stored callable with `index`
    __device__ __host__
    auto operator [](const uint32_t& index) {
        return func(index);
    }
};
// Computes `(a + b - 1) / b`, i.e. `a / b` rounded up for non-negative `a` and positive `b`
template <typename T>
__device__ __host__ T ceil_div(T a, T b) {
    const auto biased = a + b - 1;
    return biased / b;
}
// `constexpr` variant of `ceil_div`: computes `(a + b - 1) / b`
template <typename T>
__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
    const auto biased = a + b - 1;
    return biased / b;
}
// Rounds `a` up to a multiple of `b` (`ceil_div(a, b) * b`)
template <typename T>
__device__ __host__ T align(T a, T b) {
    const T num_chunks = ceil_div(a, b);
    return num_chunks * b;
}
// `constexpr` variant of `align`: `constexpr_ceil_div(a, b) * b`
template <typename T>
__device__ __host__ constexpr T constexpr_align(T a, T b) {
    const T num_chunks = constexpr_ceil_div(a, b);
    return num_chunks * b;
}
// Greatest common divisor by the Euclidean algorithm; returns `a` if `b` is 0
template <typename T>
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
    // Iterative form of `gcd(a, b) = gcd(b, a % b)`
    while (not (b == 0)) {
        const T remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
// Exchanges the values of `a` and `b` by copying through a temporary
template<typename T>
__forceinline__ __device__ void swap(T& a, T& b) {
    T saved_a = a;
    a = b;
    b = saved_a;
}
// Returns the value of the PTX special register `%smid` for the calling thread
__forceinline__ __device__ uint32_t get_sm_idx() {
    uint32_t sm_id;
    // NOTES: `%%` is the escaped form of a literal `%` in an asm statement with operands
    asm ("mov.u32 %0, %%smid;" : "=r"(sm_id));
    return sm_id;
}
// Returns the value of the PTX special register `%laneid` for the calling thread
__forceinline__ __device__ uint32_t get_lane_idx() {
    uint32_t lane_id;
    // NOTES: this asm statement has operands, so a literal `%` must be escaped as `%%`,
    // exactly as done for `%%smid` in `get_sm_idx`
    asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id));
    return lane_id;
}
// Loads one `uint32_t` from `ptr` with the PTX `ld.shared.u32` instruction
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
    uint32_t value;
    asm volatile("ld.shared.u32 %0, [%1];" : "=r"(value) : "l"(ptr));
    return value;
}
// Loads one `float4` from `ptr` with a single PTX `ld.shared.v4.f32` instruction
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
    float4 value;
    asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(value.x), "=f"(value.y), "=f"(value.z), "=f"(value.w) : "l"(ptr));
    return value;
}
// Loads one `uint4` from `ptr` with a single PTX `ld.shared.v4.u32` instruction
__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) {
    uint4 value;
    asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(value.x), "=r"(value.y), "=r"(value.z), "=r"(value.w) : "l"(ptr));
    return value;
}
// Loads one `float` from `ptr` with the PTX `ld.shared.f32` instruction
__device__ __forceinline__ float ld_shared(const float* ptr) {
    float value;
    asm volatile("ld.shared.f32 %0, [%1];" : "=f"(value) : "l"(ptr));
    return value;
}
// Stores `val` to `ptr` with the PTX `st.shared.f32` instruction
// NOTES: the pointee is written although `ptr` is declared as pointer-to-const
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(ptr), "f"(val));
}
// Stores `val` to `ptr` with the PTX `st.shared.u32` instruction
// NOTES: the pointee is written although `ptr` is declared as pointer-to-const
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(ptr), "r"(val));
}
// Stores the four words `x`, `y`, `z`, `w` to `ptr` with a single PTX `st.shared.v4.u32` instruction
// NOTES: the pointee is written although `ptr` is declared as pointer-to-const
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(ptr), "r"(x), "r"(y), "r"(z), "r"(w));
}
// Reinterprets the bits of `x` and `y` as `float`, converts the pair into two BF16 values
// with `__float22bfloat162_rn` and returns the packed 32-bit result as an `int`
template <typename old_t>
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
    const float2 float_pair = {*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)};
    auto packed = __float22bfloat162_rn(float_pair);
    return *reinterpret_cast<int*>(&packed);
}
} // namespace `deep_gemm`

View File

@@ -1,363 +0,0 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"
namespace deep_gemm {
template <uint32_t SHAPE_M, uint32_t SHAPE_N,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_wgrad_gemm_kernel(uint32_t shape_k,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_scales_b,
const __grid_constant__ CUtensorMap tensor_map_d) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Shared memory
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SMEM_SCALES_B_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE = ceil_div(SMEM_SCALES_B_SIZE_PER_STAGE, 128U) * 128U;
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
const uint32_t shape_k_scales = ceil_div(shape_k, BLOCK_K);
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_b));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<float*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_scales_a[kNumStages];
float* smem_scales_b[kNumStages];
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages + 1];
Barrier* empty_barriers[kNumStages + 1];
// Fill shared memory pointers
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)
+ i * SMEM_SCALES_A_SIZE_PER_STAGE);
smem_scales_b[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE)
+ i * ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE);
}
// Fill barriers
DG_STATIC_ASSERT(sizeof(Barrier) % sizeof(float) == 0, "Misaligned barriers");
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages
* (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + ALIGNED_SMEM_SCALES_B_SIZE_PER_STAGE));
#pragma unroll
for (int i = 0; i < kNumStages + 1; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
empty_barriers[i] = barrier_start_ptr + kNumStages + 1 + i;
}
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "To many TMA multicast");
if (threadIdx.x == kNumMathThreads) {
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
#pragma unroll
for (int i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
full_barriers[kNumStages]->init(1);
empty_barriers[kNumStages]->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
struct DivisibleK {};
struct NotDivisibleK {};
auto launch_k_iterations = [&](const auto& func) {
if constexpr (kNumLastStages == 0) {
for (int k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (int k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(num_iterations - 1, NotDivisibleK{});
}
};
// Register reconfigurations
constexpr int kNumTMARegisters = 40;
constexpr int kNumMathRegisters = 232;
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<GemmType::Normal, SHAPE_N, BLOCK_M, BLOCK_N, 1, kNumTMAMulticast, kIsTMAMulticastOnA>(SHAPE_M);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (threadIdx.x == kNumMathThreads) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
// Issue TMA A
auto& full_barrier = *full_barriers[s];
int k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, m_block_idx * BLOCK_M, num_tma_multicast_a);
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
k_idx / BLOCK_K, num_tma_multicast_a);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, n_block_idx * BLOCK_N, num_tma_multicast_b);
tma_copy(&tensor_map_scales_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_b[s], n_block_idx * BLOCK_N, k_idx / BLOCK_K, num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE + SMEM_SCALES_B_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
});
// Issue TMA D
empty_barriers[kNumStages]->wait((scheduler.current_iter + 1) & 1);
auto& full_barrier = *full_barriers[kNumStages];
tma_copy(&tensor_map_d, reinterpret_cast<uint64_t*>(&full_barrier),
smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M, 1);
full_barrier.arrive_and_expect_tx(SMEM_D_SIZE);
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
// Empty barrier arrival
auto empty_barrier_arrive = [&](int s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
}
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
constexpr int WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
float2 scales_b[WGMMA::kNumAccum / 4];
// Launch MMAs
launch_k_iterations([&](int k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr int kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (int s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Read A scales
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
// Commit WGMMA instructions
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
#pragma unroll
for (int k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[s] + (math_wg_idx * WGMMA::M + m_offset) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[s] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
// Read B scales at the first warpgroup wave
if (local_idx == 0) {
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_scales_b[s] + i * 8 + col_idx * 2));
__syncwarp();
}
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive(s);
// Promote with scales
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
const float &scale_b_0 = scales_b[i].x;
const float &scale_b_1 = scales_b[i].y;
shifted_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
}
}
}
// Wait last TMA store to be finished
if (k_iter == 0 and scheduler.current_iter > 0) {
if (threadIdx.x == 0) {
cute::tma_store_wait<0>();
empty_barriers[kNumStages]->arrive();
}
__syncwarp();
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
empty_barrier_arrive(s);
}
});
// Wait TMA D arrivals
full_barriers[kNumStages]->wait(scheduler.current_iter & 1);
// Accumulate to D shared memory
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + r_0) * BLOCK_N + col_idx * 2);
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + r_1) * BLOCK_N + col_idx * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
float2 d_0 = ld_shared(smem_d_0 + i * 4);
st_shared(smem_d_0 + i * 4, {d_0.x + shifted_accum[i * 4 + 0], d_0.y + shifted_accum[i * 4 + 1]});
float2 d_1 = ld_shared(smem_d_1 + i * 4);
st_shared(smem_d_1 + i * 4, {d_1.x + shifted_accum[i * 4 + 2], d_1.y + shifted_accum[i * 4 + 3]});
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Use TMA store to write back to global memory
if (threadIdx.x == 0) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_d, n_block_idx * BLOCK_N, m_block_idx * BLOCK_M);
cute::tma_store_arrive();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false && "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,3 @@
#pragma once
// TODO: add implementation

View File

@@ -0,0 +1,601 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
// SM100 block-scaled FP8 GEMM kernel (the `1d1d` variant): E4M3 inputs A/B with UE8M0 scaling
// factors SFA/SFB, FP32 accumulation in tensor memory, and FP32 or BF16 output written by TMA.
//
// The kernel is persistent (each role loops on `scheduler.get_next_block`) and warp-specialized:
//   - warp 0: issues TMA loads of A, B, SFA and SFB into shared memory;
//   - warp 1 (leader CTA only): copies scaling factors into tensor memory (UTCCP) and issues UMMAs;
//   - warp 2: transposes the scaling factors in shared memory into the layout UTCCP needs;
//   - warps >= `kNumNonEpilogueThreads / 32`: epilogue (tensor memory -> shared memory -> TMA store).
//
// Template parameters:
//   - `kMajorA`, `kMajorB`: major-ness (K or MN) of A and B; selects the TMA box orientation and UMMA descriptors.
//   - `SHAPE_M/N/K`: compile-time shapes; a value of 0 means "use the runtime `shape_m/n/k` argument".
//   - `BLOCK_M/N/K`: tile sizes; `BLOCK_K` must be 128 and `BLOCK_M` must be 128 or 256 (see the static asserts).
//   - `kNumGroups`, `kGemmType`: forwarded to `Scheduler`.
//   - `kSwizzleAMode/BMode/CDMode`: swizzle modes for A, B and C/D; `kSwizzleCDMode` is also used as the
//     byte width of one row of the C/D staging buffer.
//   - `kNumStages`: number of shared-memory pipeline stages for A/B/SF.
//   - `kNumLastStages`: not referenced in this kernel body (the remainder is computed at runtime
//     inside `launch_k_iterations`).
//   - `kNumNonEpilogueThreads`, `kNumEpilogueThreads`: thread counts of the two role groups
//     (the epilogue count must be 128).
//   - `kNumMulticast`: 1 or 2; a value above 1 enables cluster synchronization and the 2-CTA instruction variants.
//   - `kIsMulticastOnA`: must be false whenever `kNumMulticast > 1` (see the static assert).
//   - `kWithAccumulation`: use TMA reduce-add instead of a plain TMA store; requires `cd_dtype_t == float`.
//   - `cd_dtype_t`: output element type (`float` or `cutlass::bfloat16_t`).
//
// Arguments:
//   - `grouped_layout`: forwarded to `Scheduler` (its meaning is defined there, not in this kernel).
//   - `shape_m/n/k`: runtime shapes, overridden by the non-zero `SHAPE_*` constants.
//   - `tensor_map_*`: TMA descriptors; `tensor_map_c` is only prefetched here (when `kWithAccumulation`),
//     all output traffic goes through `tensor_map_d`.
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_sfa,
const __grid_constant__ CUtensorMap tensor_map_sfb,
const __grid_constant__ CUtensorMap tensor_map_c,
const __grid_constant__ CUtensorMap tensor_map_d) {
// The real body is only compiled for SM100+ device passes (or for CLion indexing);
// any other target gets the asserting stub in the `#else` branch at the end
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
DG_STATIC_ASSERT(std::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
// Configs
// `LAYOUT_AD_M`: rows handled per M wave; `kNumMWaves`: number of such waves per block
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M;
constexpr uint32_t kNumTMAStoreStages = 2;
// Number of `float_ue8m0_t` scales packed into one 32-bit SF word; SF loads, transposes and
// UTCCP copies are only done once every `kNumSFStagesPerLoad` K blocks (see `sf_stage_in_group_idx`)
constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t);
constexpr uint32_t kNumUTCCPAlignedElems = 128;
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
// NOTES: `2 % kNumMWaves == 0` restricts the number of M waves to 1 or 2
DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
// Overwrite the runtime shapes with the compile-time constants whenever those are given (non-zero)
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Number of packed SF words along K (each word covers `kNumSFStagesPerLoad` K blocks)
const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad);
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// 2-CTA MMA
// With multicast, each CTA loads only its share of the multicast operand (`LOAD_BLOCK_*`)
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
// One C/D staging tile is `STORE_BLOCK_M` rows of `kSwizzleCDMode` bytes, i.e. `STORE_BLOCK_N` elements per row
constexpr uint32_t STORE_BLOCK_M = std::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
// SF buffers are padded up to a multiple of `kNumUTCCPAlignedElems` 32-bit words per stage
constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
// NOTE(review): the check is on the C/D staging size, which is the offset of the A/B buffers from
// the 1024-byte aligned start of `smem_buffer`; the message wording mentions A/B for that reason
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
// i.e. use 2 accumulator stages only if two accumulators plus the SF columns fit into 512 columns
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2;
// Real tensor memory size and offsets
// Column layout: [accumulators | SFA | SFB]
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_sfb);
cute::prefetch_tma_descriptor(&tensor_map_d);
if constexpr (kWithAccumulation)
cute::prefetch_tma_descriptor(&tensor_map_c);
}
// Data on shared memory (layout as ordered below)
// i.e. [C/D staging | A stages | B stages | SFA stages | SFB stages | barriers | tensor memory pointer]
cd_dtype_t* smem_cd[kNumTMAStoreStages];
cutlass::float_e4m3_t* smem_a[kNumStages];
cutlass::float_e4m3_t* smem_b[kNumStages];
uint32_t* smem_sfa[kNumStages];
uint32_t* smem_sfb[kNumStages];
// Fill D/A/B pointers
#pragma unroll
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
}
// Fill SFA/SFB
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_sfa[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
smem_sfb[i] = reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
}
// Fill barriers
// Barrier array layout (in units of `Barrier`):
//   [0, kNumStages): full; [kNumStages, 2 * kNumStages): empty; [2 * kNumStages, 3 * kNumStages): with-SF full;
//   then `kNumEpilogueStages` tensor-memory-full and `kNumEpilogueStages` tensor-memory-empty barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) +
kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
// Fill the tensor memory pointer
// The slot right after the last barrier receives the base address written by the tensor memory allocator
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
full_barriers[i]->init(1);
empty_barriers[i]->init(1);
// Arrive only at the leader CTA
// NOTES: the count matches the 32 lanes of the transposer warp (warp 2) of each CTA, all of which arrive
with_sf_full_barriers[i]->init(kNumMulticast * 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
// Arrive at all CTAs
tmem_full_barriers[i]->init(1);
// Arrive only at the leader CTA
// NOTES: every epilogue thread of each CTA arrives once per accumulator stage
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
// Allocate tensor memory
// NOTES: done by the threads of warp 1, the resulting base address is written into `tmem_ptr_in_smem`
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
// Make barrier initialization and the tensor memory pointer visible to all threads (cluster-wide with multicast)
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA>(shape_m, shape_n, grouped_layout);
// For pipeline unrolling
// Tag types telling `func` whether all `kNumStages` stages of a K iteration carry data
struct DivisibleK {};
struct NotDivisibleK {};
// Barrier phase bit of this thread, flipped once per K iteration inside `launch_k_iterations`
uint32_t phase = 0;
// Runs `func(k_iter, tag, is_last_iter, num_last_stages)` for every K iteration of the current block.
// One K iteration covers `kNumStages` K blocks; if the number of K blocks is not a multiple of
// `kNumStages`, the last iteration is tagged `NotDivisibleK` and only uses `num_last_stages` stages.
auto launch_k_iterations = [&](const auto& func) {
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t num_iterations = ceil_div(current_shape_k, kNumStages * BLOCK_K);
const uint32_t num_last_stages = ceil_div(current_shape_k, BLOCK_K) % kNumStages;
// TODO: refactor here
if (num_last_stages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, k_iter == num_iterations - 1, num_last_stages);
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter, phase ^= 1)
func(k_iter, DivisibleK{}, false, num_last_stages);
func(num_iterations - 1, NotDivisibleK{}, true, num_last_stages), phase ^= 1;
}
};
// Calls `func` with a literal 0 or 1 according to the runtime accumulator stage index
// (presumably so that each call site can be specialized by the compiler; TODO confirm)
auto dispatch_accum_stage_idx = [&](uint32_t accum_stage_idx, const auto& func) {
DG_STATIC_ASSERT(1 <= kNumEpilogueStages and kNumEpilogueStages <= 2,
"Too many epilogue stages, please modify the Python heuristic as well");
accum_stage_idx == 0 ? func(0) : func(1);
};
// Dispatch warps into different roles
if (warp_idx == 0) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
// NOTES: despite the `k` prefix, this is a runtime value in the `NotDivisibleK` case
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
// NOTES: the producer waits on the opposite parity of the consumers' `phase`
empty_barriers[s]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), KGroupedIndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), KGroupedIndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-major
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_block_idx = k_iter * kNumStages + s;
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), KGroupedIndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
// i.e. each CTA of the cluster loads its own `LOAD_BLOCK_*` slice of the split operand
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
// NOTES: the coordinate order follows the operand's major-ness (the inner dimension comes first)
if (cute::elect_one_sync()) {
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], k_a_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, 1>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_a_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, 1>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
}
// Bytes expected on the full barrier of this stage (A and B always, plus SFA/SFB when loaded below)
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
tma_copy<BLOCK_M, 1, 0, 1>(&tensor_map_sfa, full_barriers[s], smem_sfa[s], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad)));
tma_copy<BLOCK_N, 1, 0, 1>(&tensor_map_sfb, full_barriers[s], smem_sfb[s], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, KGroupedIndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx));
num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t);
}
// Arrive at full barriers
if (cute::elect_one_sync())
full_barriers[s]->arrive_and_expect_tx(num_arrival_bytes);
}
// Wait unaligned cases
// NOTES: stages without data in the last iteration still go through the barrier handshake,
// so every stage advances once per K iteration, matching the single `phase` bit
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait(phase ^ 1);
if (cute::elect_one_sync())
full_barriers[s]->arrive();
}
});
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
// TODO: refactor `UMMA_M` calculation
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e4m3_t, cutlass::float_e4m3_t,
float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
auto sf_desc = make_sf_desc(nullptr);
// Lane `i` keeps the low word of the A/B descriptors of stage `i`; the word of the stage in use
// is broadcast with `__shfl_sync` inside the K loop, hence at most 32 stages
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = make_umma_desc<kMajorA, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
// Wait tensor memory empty barrier arrival
// NOTES: accumulator stages are used round-robin, so the phase flips once every `kNumEpilogueStages` blocks
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
tcgen05_after_thread_sync();
// Empty barrier arrival
// Signals the empty barrier of stage `s` and, if `do_tmem_full_arrive` is set,
// also the full barrier of the current accumulator stage
auto empty_barrier_arrive = [&](uint32_t s, bool do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[s]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[s]->wait(phase);
tcgen05_after_thread_sync();
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) {
using cute_utccp_t = std::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// SFA and SFB copy
// NOTES: each copy moves `kNumUTCCPAlignedElems` 32-bit words into 4 tensor memory columns
// TODO: process shared memory descriptor by addition
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[s] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[s] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
__syncwarp();
// Issue UMMA in the leader CTA
using cute_mma_t = std::conditional_t<kNumMulticast == 1,
cute::SM100_MMA_MXF8F6F4_SS <cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>,
cute::SM100_MMA_MXF8F6F4_2x1SM_SS<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
cutlass::float_ue8m0_t, UMMA_M, UMMA_N, kMajorA, kMajorB>>;
// `sf_stage_in_group_idx` is passed as the SF ID, i.e. which scale within the packed SF word this K block uses
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx);
// Fetch the descriptor low words of stage `s` from lane `s`
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, s);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, s);
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, cutlass::float_e4m3_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
a_desc.lo = advance_umma_desc_lo<kMajorA, BLOCK_M, kSwizzleAMode, cutlass::float_e4m3_t>(a_desc_base_lo, w * LAYOUT_AD_M * BLOCK_K, k * UMMA_K);
// NOTES: the fourth argument is false only for the very first UMMA of a block
// (`k_iter == 0`, `s == 0`, `k == 0`), presumably the accumulate flag, so that the
// accumulator is overwritten rather than accumulated into at that point; TODO confirm
cute_mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_iter > 0 or s > 0 or k > 0,
runtime_instr_desc,
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
kTmemStartColOfSFB);
}
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
// NOTES: the accumulator is reported full only after the last stage of the last K iteration
empty_barrier_arrive(s, is_last_iter and s == kNumInnerStages - 1);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
with_sf_full_barriers[s]->wait(phase);
empty_barrier_arrive(s, false);
}
});
});
}
} else if (warp_idx == 2) {
// UTCCP transposer
// In-place rearrangement of 128 32-bit words by one warp: the word at index `j * 32 + lane`
// is moved to index `lane * 4 + j` (a 4x32 -> 32x4 transpose). The `i ^ (lane_idx >> 3)` term
// only permutes the order in which each lane visits its 4 words (presumably to reduce
// shared-memory bank conflicts; TODO confirm).
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
// All loads must complete before any lane overwrites the buffer
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type, bool is_last_iter, uint32_t num_last_stages) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
const uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : num_last_stages;
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait TMA arrival
full_barriers[s]->wait(phase);
// Transpose for UTCCP at certain stages
// NOTES: only the stages for which the TMA warp actually loaded new SF words
const uint32_t sf_stage_in_group_idx = (k_iter * kNumStages + s) % kNumSFStagesPerLoad;
if (sf_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[s] + i * kNumUTCCPAlignedElems);
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[s] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Arrive
// NOTES: all 32 lanes arrive, with `0u` as the target (the leader CTA, per the notes at barrier initialization)
with_sf_full_barriers[s]->arrive(0u);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait(phase);
with_sf_full_barriers[s]->arrive(0u);
}
});
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
// Epilogue warp groups
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
// NOTES: the kernel uses raw column offsets as tensor memory addresses (and frees from address 0),
// so the allocated base address must be 0
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
dispatch_accum_stage_idx(scheduler.current_iter % kNumEpilogueStages, [&](uint32_t accum_stage_idx) {
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Flush TMA stores
// NOTES: for the first store, we have to flush all previous TMA,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Iterate over M waves
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
// NOTES: once all staging buffers have been used, wait until at most
// `kNumTMAStoreStages - 1` TMA stores are still pending before reusing one
const uint32_t iter_idx = w * kNumStores + s;
if (iter_idx >= kNumTMAStoreStages) {
if (epilogue_thread_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
}
// The pipeline stage
const auto tma_stage_idx = iter_idx % kNumTMAStoreStages;
const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), KGroupedIndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * LAYOUT_AD_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory
// NOTES: each lane handles one row of the tile; `i` walks the 16-byte bank groups of that row
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
// NOTES: with exactly 8 bank groups per row the division/modulo reduce to `lane_idx` and `i`
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (std::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumEpilogueThreads).sync();
if (epilogue_thread_idx == 0) {
// With accumulation, the tile is reduce-added into D instead of overwriting it
using cute_tma_t = std::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
}
});
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
// TODO: do we actually need this?
if (epilogue_thread_idx == 0)
cute::tma_store_wait<0>();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
// NOTES: "warp 1" and "warp 0" above refer to `epilogue_warp_idx`, not to the CTA-wide warp index
// TODO: do we need 2 SM allocation?
if (epilogue_warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
// To safely deconstruct all barriers, we need a cluster sync
// TODO: optimize it by another round of barrier waits
if constexpr (kNumMulticast > 1)
cute::cluster_sync();
#else
// Fallback for unsupported architectures: a single thread raises a device-side assertion
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,532 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
// SM100 FP8 GEMM kernel that promotes (rescales and accumulates) per-K-block partial products in registers.
//
// The thread block is split into warp roles (see the dispatch on `warp_idx` below):
// - warp 0: issues TMA loads of A, B and SFA (scaling factors of A) into a `kNumStages`-deep shared memory pipeline;
// - warp 1 (leader CTA only): issues UMMA instructions, writing the partial product of every K block into one of
//   `kNumEpilogueStages` accumulator stages in tensor memory;
// - remaining non-epilogue warps: only release registers;
// - epilogue warps (`warp_idx >= kNumNonEpilogueThreads / 32`): load every partial product from tensor memory,
//   multiply it by `SFA * SFB` and add it to a register accumulator; after the last K block, the result is
//   written (cast to BF16 if needed) into swizzled shared memory and stored into D via TMA.
//
// Template parameters:
// - `kMajorA`, `kMajorB`: major-ness (K or MN) of A and B, selects the TMA box orientation and the UMMA descriptors
// - `SHAPE_M`, `SHAPE_N`, `SHAPE_K`: compile-time problem shape, `0` means "use the runtime argument instead"
// - `BLOCK_M`, `BLOCK_N`, `BLOCK_K`: tile shape of one block (`BLOCK_K` must be 128, `BLOCK_M` must equal `kNumEpilogueThreads`)
// - `kNumGroups`: forwarded to the block scheduler
// - `kSwizzleAMode`, `kSwizzleBMode`, `kSwizzleCDMode`: swizzle modes (in bytes) of the A/B loads and of the D store atom
// - `kNumStages`: number of shared memory pipeline stages (one stage holds one K block of A, B and SFA)
// - `kNumLastStages`: number of stages really used by the last K iteration, `0` means all K iterations use every stage
// - `kNumNonEpilogueThreads`, `kNumEpilogueThreads`: thread counts of the two roles, their sum is the block size
// - `kNumMulticast`: 1 or 2, number of CTAs cooperating on one tile; `kIsMulticastOnA` must be false if it is 2
// - `kGemmType`: GEMM variant, forwarded to the scheduler and used to pick how global indices are computed
// - `cd_dtype_t`: output element type, `float` or `cutlass::bfloat16_t`
//
// Runtime parameters:
// - `sfb`: global memory pointer to the scaling factors of B, read with `__ldg` by the epilogue warps
// - `grouped_layout`: forwarded to the block scheduler
// - `shape_m`, `shape_n`, `shape_k`: runtime problem shape (ignored where the matching `SHAPE_*` is non-zero)
// - `tensor_map_a`, `tensor_map_b`, `tensor_map_d`, `tensor_map_sfa`: TMA descriptors of A, B, D and SFA
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
GemmType kGemmType, typename cd_dtype_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_d,
const __grid_constant__ CUtensorMap tensor_map_sfa) {
// The real body is only compiled for `__CUDA_ARCH__ >= 1000` (or for IDE indexing); see the `#else` branch at the end
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Configs
// NOTES: `LAYOUT_AD_M` rows form one "M wave"; every wave owns `BLOCK_N` tensor memory columns per accumulator stage
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t kNumMWaves = BLOCK_M / LAYOUT_AD_M;
// Number of shared memory staging buffers used by the TMA stores of D
constexpr uint32_t kNumTMAStoreStages = 2;
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
DG_STATIC_ASSERT(BLOCK_M % LAYOUT_AD_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
// Every epilogue thread owns exactly one row of the block
DG_STATIC_ASSERT(BLOCK_M == kNumEpilogueThreads, "Invalid block M");
// Overwrite the runtime shapes with the compile-time constants whenever they are given (non-zero)
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Number of scaling factors along K (one per `BLOCK_K` channels)
const auto shape_k_scales = ceil_div(shape_k, BLOCK_K);
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// 2-CTA MMA
// NOTES: `LOAD_BLOCK_*` is the part of the tile loaded by a single CTA, `STORE_BLOCK_*` is the shape of one TMA store
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t STORE_BLOCK_M = std::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M and BLOCK_M % LAYOUT_AD_M == 0, "Only support tensor memory layout A/D");
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
// Shared memory sizes
// NOTES: do not use `LOAD_BLOCK_M` for SFA, as we need full SFA for promotion
// `kMustUseUniformedSFB` is true if a single SFB value always covers all columns of a block
constexpr bool kMustUseUniformedSFB = (BLOCK_K % BLOCK_N == 0);
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
// NOTE(review): the condition checks the size of the C/D staging buffers (which precede A/B in shared memory),
// while the message talks about A/B - presumably it guarantees that A/B start 1024-byte aligned, please confirm
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
// Must have 2 epilogue stages
constexpr uint32_t kNumEpilogueStages = 2;
// Real tensor memory size and offsets
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == 0) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
}
// Data on shared memory (layout as ordered below)
cd_dtype_t* smem_cd[kNumTMAStoreStages];
cutlass::float_e4m3_t* smem_a[kNumStages];
cutlass::float_e4m3_t* smem_b[kNumStages];
float* smem_sfa[kNumStages];
// Fill D/A/B pointers
#pragma unroll
for (uint32_t i = 0; i < kNumTMAStoreStages; ++ i)
smem_cd[i] = reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<cutlass::float_e4m3_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
}
// Fill SFA pointers (SFB is not staged in shared memory, it is read from global memory by the epilogue)
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i)
smem_sfa[i] = reinterpret_cast<float*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
// Fill barriers
// Layout after SFA: `kNumStages` full, `kNumStages` empty, `kNumEpilogueStages` tensor memory full,
// `kNumEpilogueStages` tensor memory empty barriers, and finally the tensor memory pointer
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) +
kNumStages * SMEM_SFA_SIZE_PER_STAGE);
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + kNumEpilogueStages * 2);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (threadIdx.x == 0) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
full_barriers[i]->init(1);
// One arrival per epilogue warp of every CTA in the cluster (see the epilogue's `empty_barriers[s]->arrive`)
empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads / 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
// Arrive at all CTAs
tmem_full_barriers[i]->init(1);
// Arrive only at the leader CTA
// NOTES: every epilogue thread of every CTA arrives once per accumulator stage use
tmem_empty_barriers[i]->init(kNumMulticast * kNumEpilogueThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
cutlass::arch::fence_barrier_init();
} else if (threadIdx.x >= 32 and threadIdx.x < 64) {
// Allocate tensor memory
// NOTES: done by the second warp (threads 32..63), the address is written into `tmem_ptr_in_smem`
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
// Make the barrier initialization and the tensor memory allocation visible to all threads (of the whole cluster if multicast)
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// For pipeline unrolling
// NOTES: these tag types let the callee compile separate code for the last, partially filled K iteration
struct DivisibleK {};
struct NotDivisibleK {};
// One K iteration covers `kNumStages` K blocks
const uint32_t num_iterations = ceil_div(shape_k, kNumStages * BLOCK_K);
// Calls `func(k_iter, tag)` for every K iteration; if `kNumLastStages != 0`, the last iteration is tagged `NotDivisibleK`
auto launch_k_iterations = [=](const auto& func) {
if constexpr (kNumLastStages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{});
} else {
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
func(k_iter, DivisibleK{});
func(num_iterations - 1, NotDivisibleK{});
}
};
// Block scheduler
// NOTES: every role constructs the same scheduler and iterates over the same sequence of blocks
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA>(shape_m, shape_n, grouped_layout);
// Register configurations
constexpr uint32_t kNumNonEpilogueRegisters = 64;
constexpr uint32_t kNumEpilogueRegisters = 216;
// NOTE(review): the bound is 65535, presumably the per-SM register file size (65536?) - please confirm
DG_STATIC_ASSERT(kNumNonEpilogueRegisters * kNumNonEpilogueThreads + kNumEpilogueRegisters * kNumEpilogueThreads <= 65535, "Too many registers");
// Dispatch warps into different roles
if (warp_idx == 0) {
// Adjust registers
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
// Number of stages that really carry data in this K iteration
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
// NOTES: the waited parity flips once per K iteration, counted across all blocks processed so far
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::K)>(
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_block_idx = k_iter * kNumStages + s;
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_b_idx = scheduler.get_global_idx<(kMajorB == cute::UMMA::Major::MN)>(
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
// NOTES: every CTA of the cluster loads its own `LOAD_BLOCK_M` / `LOAD_BLOCK_N` slice of the tile
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
// NOTES: only one elected lane issues the copies, the coordinate order follows the major-ness of the operand
if (cute::elect_one_sync()) {
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], k_idx, m_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, kNumMulticast>(&tensor_map_a, full_barriers[s], smem_a[s], m_idx, k_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], k_b_idx, n_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, kNumMulticast>(&tensor_map_b, full_barriers[s], smem_b[s], n_idx, k_b_idx);
// Issue SFA TMA
// NOTES: loads `BLOCK_M` scaling factors (one per row) of the current K block, without swizzling
tma_copy<BLOCK_M, 1, 0, kNumMulticast>(
&tensor_map_sfa, full_barriers[s],
smem_sfa[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_k_scales, 1, k_block_idx));
}
// Arrive at full barriers
// NOTES: only the leader CTA arrives, expecting the bytes of all `kNumMulticast` CTAs
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE;
if (is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
}
// Wait unaligned cases
// NOTES: stages without data in the last K iteration are still waited and signaled (without any transaction bytes),
// as the MMA and epilogue warps always loop over all `kNumStages` stages
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
if (is_leader_cta and cute::elect_one_sync())
full_barriers[s]->arrive();
}
});
}
} else if (warp_idx == 1 and is_leader_cta) {
// Adjust registers
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
// TODO: refactor `UMMA_M` calculation
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
// `UMMA_K` is the number of FP8 elements in 32 bytes
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Launch MMAs
launch_k_iterations([&](uint32_t k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
// NOTES: loops over all stages (not only the inner ones) to keep the barrier phases in step with the TMA warp
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s) {
// Wait TMA full
auto iter_idx = scheduler.current_iter * num_iterations + k_iter;
full_barriers[s]->wait(iter_idx & 1);
// Wait tensor memory empty
// NOTES: every (K iteration, stage) pair uses the next accumulator stage in a round-robin manner,
// the epilogue warps compute exactly the same index and phase
auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages;
auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_stage_phase ^ 1);
// Issue UMMA in the leader CTA
if (s < kNumInnerStages) {
using cute_mma_t = std::conditional_t<kNumMulticast == 1,
cute::SM100_MMA_F8F6F4_SS, cute::SM100_MMA_F8F6F4_2x1SM_SS>;
tcgen05_after_thread_sync();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[s], 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
auto a_desc = make_umma_desc<kMajorA, BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[s], w * LAYOUT_AD_M, k * UMMA_K);
// NOTES: the third argument is the tensor memory column of this accumulator stage and M wave;
// `k > 0` is presumably the accumulate flag, i.e. the first UMMA of every K block overwrites
// the stage, as the sum over K blocks is done by the epilogue warps - TODO confirm
cute_mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k > 0,
runtime_instr_desc);
}
}
tcgen05_before_thread_sync();
}
// Commit to the tensor memory full barrier
// NOTES: unlike what one may expect, the TMA empty barriers are not committed here,
// they are released by the epilogue warps after loading SFA
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
// One bit per CTA in the cluster
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
}
});
}
} else if (warp_idx < kNumNonEpilogueThreads / 32) {
// Adjust registers
// NOTES: the remaining non-epilogue warps (including warp 1 of non-leader CTAs) do no other work
cutlass::arch::warpgroup_reg_dealloc<kNumNonEpilogueRegisters>();
} else if (warp_idx >= kNumNonEpilogueThreads / 32) {
// Adjust registers
cutlass::arch::warpgroup_reg_alloc<kNumEpilogueRegisters>();
// Epilogue warp groups
// NOTES: threads are indexed relative to the first epilogue thread, a warpgroup is 128 threads
const auto epilogue_thread_idx = threadIdx.x - kNumNonEpilogueThreads;
const auto epilogue_thread_idx_in_warpgroup = epilogue_thread_idx % 128;
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
const auto epilogue_warpgroup_idx = epilogue_thread_idx / 128;
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Number of accumulator columns fetched by one tensor memory load
constexpr uint32_t kNumElemsPerLDTM = 16;
DG_STATIC_ASSERT(kNumElemsPerLDTM == 16 and BLOCK_N % kNumElemsPerLDTM == 0 and BLOCK_K % kNumElemsPerLDTM == 0, "Invalid LDTM width");
// SFB stuffs
// NOTES: a block may span two SFB values along N: the first `num_former_iters` tensor memory loads use the
// first one (`sf_0`), the loads up to `num_full_iters` use the second one (`sf_1`); both are counted in
// units of `kNumElemsPerLDTM` columns after the division below
uint32_t num_former_iters = BLOCK_N, num_full_iters = BLOCK_N;
if constexpr (not kMustUseUniformedSFB) {
num_former_iters = min(BLOCK_N, BLOCK_K - ((n_block_idx * BLOCK_N) % BLOCK_K));
num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N);
}
num_former_iters /= kNumElemsPerLDTM, num_full_iters /= kNumElemsPerLDTM;
// Pointer to the first SFB value of this block: SFB is indexed as `[n / BLOCK_K][k / BLOCK_K]` with
// `shape_k_scales` values per row, plus a scheduler-provided offset
const auto sfb_offset = scheduler.get_global_idx<true>(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx);
const auto sfb_ptr = sfb + (sfb_offset + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales;
// Launch promotion
// NOTES: `accum` holds the final results of this thread's row (all `BLOCK_N` columns) in registers
float accum[BLOCK_N] = {0};
launch_k_iterations([&](uint32_t k_iter, auto type) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
DG_STATIC_ASSERT(kNumInnerStages != 0, "Invalid number of inner stages");
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s) {
// Load SFB
// NOTES: `sf_1` is only loaded (from the next SFB row) if the block really spans two SFB values
float sf_0 = 0, sf_1 = 0;
if (s < kNumInnerStages) {
const auto k_block_idx = k_iter * kNumStages + s;
sf_0 = __ldg(sfb_ptr + k_block_idx);
sf_1 = num_former_iters < num_full_iters ? __ldg(sfb_ptr + k_block_idx + shape_k_scales) : 0;
}
// Wait UMMA arrival
// NOTES: same accumulator stage index and phase as computed by the MMA warp
auto iter_idx = scheduler.current_iter * num_iterations + k_iter;
auto accum_stage_idx = (iter_idx * kNumStages + s) % kNumEpilogueStages;
auto accum_stage_phase = ((iter_idx * kNumStages + s) / kNumEpilogueStages) & 1;
tmem_full_barriers[accum_stage_idx]->wait(accum_stage_phase);
tcgen05_after_thread_sync();
// Commit to the TMA empty barrier for all CTAs after loading SFA
// NOTES: every thread reads the SFA of its own row and folds it into the SFB values;
// then lane `i` (`i < kNumMulticast`) of every warp arrives at the barrier of CTA `i`
float sfa = s < kNumInnerStages ? ld_shared(smem_sfa[s] + epilogue_thread_idx) : 0;
sf_0 *= sfa, sf_1 *= sfa;
__syncwarp();
if (lane_idx < kNumMulticast)
empty_barriers[s]->arrive(lane_idx);
__syncwarp();
// Do promotion like the SM90 kernel
if (s < kNumInnerStages) {
uint32_t values[kNumElemsPerLDTM];
#pragma unroll
for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerLDTM; ++ i) {
// Load from tensor memory
// NOTES: column address = accumulator stage offset + M wave offset (one warpgroup per wave) + chunk offset
cute::SM100_TMEM_LOAD_32dp32b16x::copy(
accum_stage_idx * kNumMWaves * BLOCK_N + epilogue_warpgroup_idx * BLOCK_N + i * kNumElemsPerLDTM,
values[ 0], values[ 1], values[ 2], values[ 3],
values[ 4], values[ 5], values[ 6], values[ 7],
values[ 8], values[ 9], values[10], values[11],
values[12], values[13], values[14], values[15]);
cutlass::arch::fence_view_async_tmem_load();
// Promote
// NOTES: the loaded 32-bit values are reinterpreted as FP32, scaled and added to the register accumulator
const auto sf = (kMustUseUniformedSFB or i < num_former_iters) ? sf_0 : sf_1;
#pragma unroll
for (uint32_t j = 0; j < kNumElemsPerLDTM; ++ j)
accum[i * kNumElemsPerLDTM + j] += *reinterpret_cast<float*>(&values[j]) * sf;
}
}
// Commit to the tensor memory empty barrier (only at the leader CTA)
// NOTES: done for every stage (even those without data), matching the waits of the MMA warp
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
});
// Flush TMA stores
// NOTES: for the first store, we have to flush all previous TMA,
// as we don't share pipeline stages between two blocks
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<0>();
// Synchronize the threads of this warpgroup only (one named barrier per warpgroup)
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
// Write shared memory
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Epilogue store and addition
// Issue every swizzled atom and pipeline: store shared, add C, and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
// NOTES: a staging buffer can only be reused once at most `kNumTMAStoreStages - 1` TMA stores are still in flight
if (s >= kNumTMAStoreStages) {
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
}
// The pipeline stage
const auto tma_stage_idx = s % kNumTMAStoreStages;
const auto m_idx = scheduler.get_global_idx<(kGemmType != GemmType::MGroupedContiguous)>(shape_m, BLOCK_M, m_block_idx);
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Part of the staging buffer owned by this warpgroup (`STORE_BLOCK_M` rows of `STORE_BLOCK_N` elements)
const auto local_smem_cd = smem_cd[tma_stage_idx] + epilogue_warpgroup_idx * STORE_BLOCK_M * STORE_BLOCK_N;
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
// NOTES: with exactly 8 bank groups per row, the row/column split is known without a division by the lane index
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Destination shared memory address
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Store the promoted results from the register accumulator into shared memory
// NOTES: if you want to do accumulation, please notice that you need two accumulation barriers
const auto offset = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup;
if constexpr (std::is_same_v<cd_dtype_t, float>) {
// For FP32 output, store the raw bits of 4 values (16 bytes)
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
st_shared(smem_ptr,
*reinterpret_cast<uint32_t*>(&accum[offset + 0]),
*reinterpret_cast<uint32_t*>(&accum[offset + 1]),
*reinterpret_cast<uint32_t*>(&accum[offset + 2]),
*reinterpret_cast<uint32_t*>(&accum[offset + 3]));
} else {
// For BF16 output, cast 8 values, pack them in pairs and store (16 bytes)
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and std::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
st_shared(smem_ptr,
cast_into_bf16_and_pack(accum[offset + 0], accum[offset + 1]),
cast_into_bf16_and_pack(accum[offset + 2], accum[offset + 3]),
cast_into_bf16_and_pack(accum[offset + 4], accum[offset + 5]),
cast_into_bf16_and_pack(accum[offset + 6], accum[offset + 7]));
}
}
// Synchronize all threads of the warpgroup and issue TMA
// NOTES: the fence is issued before the barrier, so that the shared memory writes of all threads
// are covered before the single issuing thread launches the TMA store
cute::tma_store_fence();
cutlass::arch::NamedBarrier(STORE_BLOCK_M, epilogue_warpgroup_idx).sync();
if (epilogue_thread_idx_in_warpgroup == 0) {
cute::SM90_TMA_STORE_2D::copy(
&tensor_map_d, local_smem_cd,
n_idx, m_idx + epilogue_warpgroup_idx * STORE_BLOCK_M);
cute::tma_store_arrive();
}
}
}
// Flush all stages in the pipeline to make TMA stores visible to the next kernel
// TODO: do we actually need this?
if (epilogue_thread_idx_in_warpgroup == 0)
cute::tma_store_wait<0>();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
// TODO: do we need 2 SM allocation?
if (epilogue_warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
// To safely deconstruct all barriers, we need a cluster sync
// TODO: optimize it by another round of barrier waits
if constexpr (kNumMulticast > 1)
cute::cluster_sync();
#else
// Compiled for an unsupported architecture: a single thread of the grid reports the error
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100a/sm_101a");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,3 @@
#pragma once
// TODO: add implementation

View File

@@ -0,0 +1,3 @@
#pragma once
// TODO: add implementation

View File

@@ -10,13 +10,14 @@
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include "mma_utils.cuh"
#include "scheduler.cuh"
#include "tma_utils.cuh"
#include "utils.cuh"
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd>
__device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_iterations, const auto& func, uint32_t num_former_iters) {
if (num_former_iters == kNumFormerIters) {
@@ -28,59 +29,58 @@ __device__ __host__ void outer_launch_k_iterations(const auto& inner_launch_k_it
outer_launch_k_iterations<kNumFormerIters + kGap, kGap, kEnd>(inner_launch_k_iterations, func, num_former_iters);
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumGroups,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t BLOCK_N_PADDING,
uint32_t kSwizzleDMode,
uint32_t kNumGroups, uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreadsPerGroup,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
GemmType kGemmType>
__global__ void __launch_bounds__(get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M), 1)
fp8_gemm_kernel(float* scales_b, int* grouped_layout,
uint32_t shape_m,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_scales_a,
const __grid_constant__ CUtensorMap tensor_map_d) {
__global__ void __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ CUtensorMap tensor_map_a,
const __grid_constant__ CUtensorMap tensor_map_b,
const __grid_constant__ CUtensorMap tensor_map_d,
const __grid_constant__ CUtensorMap tensor_map_sfa) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * (BLOCK_N + BLOCK_N_PADDING) * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SCALES_A_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SHAPE_K_SCALES = ceil_div(SHAPE_K, BLOCK_K);
static constexpr uint32_t SMEM_SCALES_B_SIZE = ceil_div<uint32_t>(SHAPE_K_SCALES * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)) * sizeof(Barrier);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K);
const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
// Configs
constexpr uint32_t kFullKOfAllStages = kNumStages * BLOCK_K;
constexpr uint32_t kNumThreads = get_num_threads_per_sm<kNumTMAThreads, kNumMathThreadsPerGroup>(BLOCK_M);
constexpr uint32_t kNumMathThreads = kNumThreads - kNumTMAThreads;
constexpr uint32_t kNumIterations = ceil_div(SHAPE_K, kFullKOfAllStages);
const uint32_t num_iterations = ceil_div(shape_k, kFullKOfAllStages);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_id();
const uint32_t lane_idx = get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (threadIdx.x == kNumMathThreads) {
// NOTES: `reinterpret_cast` must be here, or NVRTC will fail
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_a));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_b));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_scales_a));
// `tensor_map_d` is only used in swizzling mode
// For the `kSwizzleDMode == 0 and BLOCK_N_PADDING == 0` case, it will be treated as padding mode
if constexpr (kSwizzleDMode > 0)
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_sfa));
cute::prefetch_tma_descriptor(reinterpret_cast<const cute::TmaDescriptor*>(&tensor_map_d));
}
__syncwarp();
@@ -92,8 +92,8 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
__nv_fp8_e4m3* smem_a[kNumStages];
__nv_fp8_e4m3* smem_b[kNumStages];
float* smem_scales_a[kNumStages];
float* smem_scales_b;
float* smem_sfa[kNumStages];
float* smem_sfb;
// TMA Barrier for both divisible and non-divisible cases
Barrier* full_barriers[kNumStages];
@@ -104,12 +104,12 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
for (uint32_t i = 0; i < kNumStages; ++ i) {
smem_a[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
smem_b[i] = reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
smem_scales_a[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SCALES_A_SIZE_PER_STAGE);
smem_sfa[i] = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + i * SMEM_SFA_SIZE_PER_STAGE);
}
smem_scales_b = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE));
smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE));
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_scales_b) + SMEM_SCALES_B_SIZE);
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i] = barrier_start_ptr + i;
@@ -129,7 +129,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
// Make initialized barrier visible in async proxy
cutlass::arch::fence_view_async_shared();
(kNumTMAMulticast > 1) ? cutlass::arch::fence_barrier_init() : void();
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
@@ -140,7 +140,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
struct NotDivisibleK {};
struct SkipComputation {};
struct NotSkipComputation {};
auto launch_k_iterations = [](const auto& func, bool skip_computation, uint32_t num_former_iters) {
auto launch_k_iterations = [=](const auto& func, bool skip_computation, uint32_t num_former_iters) {
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
@@ -149,15 +149,15 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
// Otherwise, the compiler must know the dynamic variable `num_former_iters`'s real value
outer_launch_k_iterations<0, kGap, kEnd>([=](const auto& func, auto num_former_iters_type) {
if (skip_computation) {
for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{}, SkipComputation{}, num_former_iters_type);
} else if (SHAPE_K % kFullKOfAllStages == 0) {
for (uint32_t k_iter = 0; k_iter < kNumIterations; ++ k_iter)
} else if (shape_k % kFullKOfAllStages == 0) {
for (uint32_t k_iter = 0; k_iter < num_iterations; ++ k_iter)
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
} else {
for (uint32_t k_iter = 0; k_iter < kNumIterations - 1; ++ k_iter)
for (uint32_t k_iter = 0; k_iter < num_iterations - 1; ++ k_iter)
func(k_iter, DivisibleK{}, NotSkipComputation{}, num_former_iters_type);
func(kNumIterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
func(num_iterations - 1, NotDivisibleK{}, NotSkipComputation{}, num_former_iters_type);
}
}, func, kShouldOptimize ? num_former_iters : 0);
};
@@ -168,7 +168,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, SHAPE_N, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, grouped_layout);
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA>(shape_m, shape_n, grouped_layout);
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
@@ -180,7 +180,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto _, auto __) {
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K;
constexpr uint32_t kNumInnerStages = kHasDivisibleStages ? kNumStages : kNumLastStages;
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
@@ -194,30 +194,31 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Wait consumer release
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
// Issue TMA A
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[s];
uint32_t k_idx = k_iter * kFullKOfAllStages + s * BLOCK_K;
tma_copy(&tensor_map_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_a[s], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx),
smem_a[s], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a);
tma_copy(&tensor_map_scales_a, reinterpret_cast<uint64_t*>(&full_barrier),
smem_scales_a[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx(SHAPE_K_SCALES, 1, k_idx / BLOCK_K),
tma_copy(&tensor_map_sfa, reinterpret_cast<uint64_t*>(&full_barrier),
smem_sfa[s], m_block_idx * BLOCK_M,
scheduler.get_global_idx<kWithGroupOffsetA>(shape_k_scales, 1, k_idx / BLOCK_K),
num_tma_multicast_a);
// Issue TMA B
tma_copy(&tensor_map_b, reinterpret_cast<uint64_t*>(&full_barrier),
smem_b[s], k_idx, scheduler.get_global_idx<false>(SHAPE_N, BLOCK_N, n_block_idx, m_block_idx),
smem_b[s], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SCALES_A_SIZE_PER_STAGE);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter + 1) & 1);
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter + 1) & 1);
full_barriers[s]->arrive();
}
}, false, 0);
@@ -227,7 +228,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s)
empty_barriers[s]->wait((scheduler.current_iter * kNumIterations + 1) & 1);
empty_barriers[s]->wait((scheduler.current_iter * num_iterations + 1) & 1);
}
}
} else {
@@ -235,33 +236,33 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / kNumMathThreadsPerGroup, 0);
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_STATIC_ASSERT(SHAPE_N % 8 == 0, "Invalid shape N");
DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0);
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
if constexpr (not kMustUseUniformedScaleB) {
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
num_full_iters = min(SHAPE_N - n_block_idx * BLOCK_N, BLOCK_N) / 8;
num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8;
}
uint32_t num_scales_b = SHAPE_K_SCALES * (num_former_iters >= num_full_iters ? 1 : 2);
uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2);
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto num_previous_lines = scheduler.get_global_idx<false>(ceil_div(SHAPE_N, BLOCK_K), 0, 0, m_block_idx);
auto local_scales_b = scales_b + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * SHAPE_K_SCALES;
auto num_previous_lines = scheduler.get_global_idx<true>(ceil_div(shape_n, BLOCK_K), 0, 0, m_block_idx);
auto local_sfb = sfb + (num_previous_lines + ((n_block_idx * BLOCK_N) / BLOCK_K)) * shape_k_scales;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_scales_b; i += kNumMathThreads - 32)
st_shared(smem_scales_b + i, __ldg(local_scales_b + i));
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
st_shared(smem_sfb + i, __ldg(local_sfb + i));
}
cutlass::arch::NamedBarrier(kNumMathThreads).sync();
// Accumulation for WGMMA or CUDA promotion
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * get_num_math_warpgroups(BLOCK_M);
constexpr uint32_t WAVE_BLOCK_M = WGMMA::M * (BLOCK_M <= 64 ? 1 : 2);
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
@@ -279,19 +280,18 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
launch_k_iterations([&](uint32_t k_iter, auto divisible_type, auto skip_type, auto _) {
constexpr bool kSkipComputation = std::is_same_v<decltype(skip_type), SkipComputation>;
constexpr bool kHasDivisibleStages = std::is_same_v<decltype(divisible_type), DivisibleK>;
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 :
(kHasDivisibleStages ? kNumStages : (SHAPE_K % kFullKOfAllStages) / BLOCK_K);
constexpr uint32_t kNumInnerStages = kSkipComputation ? 0 : (kHasDivisibleStages ? kNumStages : kNumLastStages);
#pragma unroll
for (uint32_t s = 0; s < kNumInnerStages; ++ s) {
// Read B scales
float scale_b_0 = ld_shared(smem_scales_b + k_iter * kNumStages + s), scale_b_1;
float scale_b_0 = ld_shared(smem_sfb + k_iter * kNumStages + s), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_scales_b + k_iter * kNumStages + s + SHAPE_K_SCALES);
scale_b_1 = ld_shared(smem_sfb + k_iter * kNumStages + s + shape_k_scales);
// Wait TMA arrivals
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
@@ -300,8 +300,8 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_scales_a[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_scales_a[s] + r_1 + m_offset);
auto scale_a_0 = ld_shared(smem_sfa[s] + r_0 + m_offset);
auto scale_a_1 = ld_shared(smem_sfa[s] + r_1 + m_offset);
// Commit WGMMA instructions
#pragma unroll
@@ -347,7 +347,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
// Wait unaligned cases
#pragma unroll
for (uint32_t s = kNumInnerStages; s < kNumStages; ++ s) {
full_barriers[s]->wait((scheduler.current_iter * kNumIterations + k_iter) & 1);
full_barriers[s]->wait((scheduler.current_iter * num_iterations + k_iter) & 1);
empty_barrier_arrive(s);
}
}, not scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M), num_former_iters);
@@ -360,8 +360,6 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
"Unaligned TMA store or too many TMA store instructions");
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
DG_STATIC_ASSERT(static_cast<uint32_t>(kSwizzleDMode > 0) + static_cast<uint32_t>(BLOCK_N_PADDING > 0) <= 1,
"Swizzling and padding are not compatible");
// Wait last TMA store to be finished
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
@@ -403,9 +401,7 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else {
// No swizzling, just padding
// NOTES: padding must be zero for BF16 output
DG_STATIC_ASSERT(BLOCK_N_PADDING == 0, "Padding must be zero for BF16 output");
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * (BLOCK_N + BLOCK_N_PADDING) + i * 8);
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
}
// NOTES: only 16 lanes' addresses are used
@@ -421,13 +417,14 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
// Use TMA store to write back to global memory
// TODO: compatible with FP32 output
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
DG_STATIC_ASSERT(kNumMathThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
n_block_idx * BLOCK_N + in_block_n_offset,
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx));
cute::tma_store_arrive();
}
__syncwarp();
@@ -441,4 +438,4 @@ fp8_gemm_kernel(float* scales_b, int* grouped_layout,
}; // namespace deep_gemm
#pragma clang diagnostic pop
#pragma clang diagnostic pop

View File

@@ -0,0 +1,139 @@
#pragma once
#include <cstdint>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm {
// NOTES: the two kernels below always pack the K dimension
// Packs FP32 scale factors into 8-bit exponent fields (four per `uint32_t`) while transposing:
// the input is indexed as `[mn, SF_K]` (K contiguous), the output as `[packed K, tma_aligned_mn]`.
//   - `blockIdx.x` selects a tile of `BLOCK_MN` rows, `blockIdx.y` selects a group
//   - `sf`: input scale factors, `mn * SF_K` values per group
//   - `out`: packed output, `tma_aligned_mn * ceil(SF_K / 4)` words per group
//   - `mn`: number of valid rows per group
// The tile is first staged in dynamic shared memory, then each thread packs whole output words.
// NOTE(review): the shifts below leave the sign and mantissa bits of the inputs in place, so this
// presumably expects non-negative power-of-two scale factors (zero mantissa) - confirm with callers.
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
    extern __shared__ uint32_t smem_buffer[];

    // Compile-time shapes
    DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
    constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u);
    constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));

    // Runtime shapes: the last tile along MN may be partial
    const auto block_mn_start = blockIdx.x * BLOCK_MN;
    const auto in_block_mn = min(BLOCK_MN, mn - block_mn_start);
    const auto tma_aligned_mn = align<uint64_t>(mn, kNumTMAAlignedElems);

    // Move both pointers to the group handled by this block
    const auto group_idx = static_cast<uint64_t>(blockIdx.y);
    sf += group_idx * mn * SF_K;
    out += group_idx * tma_aligned_mn * kNumPackedSFK;

    // Stage the tile's raw FP32 bits in shared memory, 16 bytes at a time
    const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
    const auto num_values = in_block_mn * SF_K;
    const auto num_uint4 = num_values / 4;
    const auto gmem_vec = reinterpret_cast<uint4*>(local_sf);
    const auto smem_vec = reinterpret_cast<uint4*>(smem_buffer);
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
        const auto& [x, y, z, w] = __ldg(gmem_vec + i);
        st_shared(smem_vec + i, x, y, z, w);
    }

    // Copy the (at most 3) trailing values that do not fill a whole `uint4`
    const auto tail_idx = num_uint4 * 4 + threadIdx.x;
    if (tail_idx < num_values)
        st_shared(smem_buffer + tail_idx, __ldg(local_sf + tail_idx));
    __syncthreads();

    // Each iteration builds one output word from four consecutive K entries of one row
    #pragma unroll
    for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) {
        const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN;

        // Gather the four inputs; entries past `SF_K` contribute zero
        uint32_t values[4];
        #pragma unroll
        for (uint32_t j = 0; j < 4; ++ j) {
            const auto sf_k_idx = sf_k_pack_idx * 4 + j;
            values[j] = 0;
            if (sf_k_idx < SF_K)
                values[j] = ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx);
        }

        // Bits 23..30 of input `j` land in byte `j` of the packed word
        const uint32_t packed = (values[0] >> 23u) | (values[1] >> 15u) | (values[2] >> 7u) | (values[3] << 1u);

        // Rows beyond `mn` were never loaded, so they are not written either
        const auto global_mn_idx = block_mn_start + mn_idx;
        if (global_mn_idx < mn)
            out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed;
    }
}
// Packs FP32 scale factors stored MN-contiguous (`sf[k * mn + mn_idx]`) into 8-bit exponent
// fields, four K entries per `uint32_t`, keeping the MN-contiguous layout in `out`.
//   - `blockIdx.x` selects a tile of `BLOCK_MN` columns, `blockIdx.y` a tile of packed rows
//   - every warp of the block produces one packed row
//   - `ks`: only read when `kNumGroups > 1`; `ks[g] / 128` is used as group `g`'s number of
//     scale factors along K, and each group occupies `ceil(count / 4)` packed rows
//   - `sf_k`: number of input rows; overwritten from `ks` when `kNumGroups > 1`
//   - `packed_sf_k`: number of packed output rows
// NOTE(review): as in the sibling kernel, the shifts presumably rely on the inputs being
// non-negative powers of two (zero mantissa) - confirm with callers.
template <uint32_t kNumGroups, uint32_t kNumThreads,
          uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
                                     const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) {
    // NOTES: should also assert `mn % 4 == 0` at launch
    DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
    DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes");
    DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes");

    // Tile extents; the last tile along either axis may be partial
    const auto block_mn_start = blockIdx.x * BLOCK_MN;
    const auto in_block_mn = min(BLOCK_MN, mn - block_mn_start);
    const auto in_block_mn_uint4 = in_block_mn / 4;
    const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K);

    // Move both pointers to this block's first column
    sf += block_mn_start;
    out += block_mn_start;

    // One warp per packed row; warps past the last row have nothing to do
    const auto warp_idx = threadIdx.x / 32;
    const auto lane_idx = get_lane_idx();
    const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
    if (warp_idx >= in_block_packed_sf_k)
        return;

    // Number of padding slots inserted by the groups that precede this row's group
    uint32_t input_offset = 0;
    if constexpr (kNumGroups > 1) {
        // Lane `l` caches the sizes of groups `4l .. 4l + 3`, hence the 128-group limit
        DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups");
        uint32_t group_ks[4];
        #pragma unroll
        for (uint32_t i = 0; i < 4; ++ i) {
            const auto group_idx = lane_idx * 4 + i;
            group_ks[i] = 0;
            if (group_idx < kNumGroups)
                group_ks[i] = __ldg(ks + group_idx);
        }
        __syncwarp();

        // Walk the groups until the one owning `packed_sf_k_idx` is found.
        // On exit `sf_k` is the cumulative input row count up to and including that group.
        sf_k = 0;
        auto sum_packed_sf_k = 0;
        #pragma unroll
        for (uint32_t i = 0; i < kNumGroups; ++ i) {
            const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4);
            sf_k += sf_k_in_group;
            sum_packed_sf_k += ceil_div(sf_k_in_group, 4u);
            if (packed_sf_k_idx < sum_packed_sf_k)
                break;

            // A group that is not a multiple of 4 leaves padding in its last packed row
            const auto remainder = sf_k_in_group % 4;
            if (remainder > 0)
                input_offset += 4 - remainder;
        }
    }

    // Combines the exponent fields of four FP32 bit patterns into one word (input `j` -> byte `j`)
    const auto pack = [](const uint32_t v0, const uint32_t v1, const uint32_t v2, const uint32_t v3) {
        return (v0 >> 23u) | (v1 >> 15u) | (v2 >> 7u) | (v3 << 1u);
    };

    const auto out_row = reinterpret_cast<uint4*>(out + packed_sf_k_idx * mn);
    for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
        // Load four input rows, 16 bytes each; rows past `sf_k` contribute zero
        uint4 values[4];
        #pragma unroll
        for (uint32_t j = 0; j < 4; ++ j) {
            values[j] = make_uint4(0, 0, 0, 0);
            const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset;
            if (sf_k_idx < sf_k)
                values[j] = __ldg(reinterpret_cast<uint4*>(sf + sf_k_idx * mn) + mn_idx);
        }

        // Pack the four columns independently and store them as one 16-byte vector
        uint4 packed;
        packed.x = pack(values[0].x, values[1].x, values[2].x, values[3].x);
        packed.y = pack(values[0].y, values[1].y, values[2].y, values[3].y);
        packed.z = pack(values[0].z, values[1].z, values[2].z, values[3].z);
        packed.w = pack(values[0].w, values[1].w, values[2].w, values[3].w);
        out_row[mn_idx] = packed;
    }
}
} // namespace deep_gemm

View File

@@ -1,163 +0,0 @@
#pragma once
#include "utils.cuh"
namespace deep_gemm {
// Selects how `Scheduler` interprets the problem and its `grouped_layout` array
enum class GemmType {
// `grouped_layout` is not read by the scheduler
Normal,
// `grouped_layout` is indexed per row of M; negative entries mark rows to skip
GroupedContiguous,
// `grouped_layout` is indexed per group and compared against row indices within the group
GroupedMasked
};
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
// Persistent block scheduler: every CTA repeatedly calls `get_next_block` to obtain the
// `(m_block_idx, n_block_idx)` tile it should compute next, until the call returns `false`.
// Template parameters:
//   - `kGemmType`: which layout `grouped_layout` describes (see `GemmType`)
//   - `SHAPE_N`, `BLOCK_M`, `BLOCK_N`: N extent and tile sizes; `kNumNBlocks` is derived from them
//   - `kNumGroups`: number of groups walked in the masked layout
//   - `kNumTMAMulticast`, `kIsTMAMulticastOnA`: multicast width and which side it applies to
//   - `kNum1DBlocksPerGroup`: width of a swizzle group, must be divisible by `kNumTMAMulticast`
template <GemmType kGemmType,
uint32_t SHAPE_N, uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumNBlocks = ceil_div(SHAPE_N, BLOCK_N),
uint32_t kNum1DBlocksPerGroup = 16>
struct Scheduler {
// Number of `get_next_block` calls made so far, minus one (pre-incremented on every call)
int current_iter = -1;
// `ceil_div(shape_m, BLOCK_M)`, set in the constructor
uint32_t num_aligned_m_blocks;
// For normal GEMM
// Maybe not used in the masked grouped GEMM
// Total number of tiles; only assigned for `Normal` and `GroupedContiguous`
uint32_t num_blocks;
// Size of the swizzle group of the most recently scheduled tile (written by `get_swizzled_block_idx`)
uint32_t num_blocks_in_group;
// Updated by `get_next_block` in the non-masked path; not read inside this struct
bool is_peer_cta_alive = true;
// For grouped GEMM
int* grouped_layout;
// Only used for masked layout
// Group currently being scheduled, and the number of M blocks in all groups before it
uint32_t curr_group_idx, curr_cumsum;
// Initializes the counters for the selected GEMM type.
// `shape_m` is the M extent used to derive `num_aligned_m_blocks`; `grouped_layout` is stored
// as-is for the two grouped types.
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m,
int* grouped_layout = nullptr) {
num_aligned_m_blocks = ceil_div(shape_m, BLOCK_M);
if constexpr (kGemmType == GemmType::Normal) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
} else if (kGemmType == GemmType::GroupedContiguous) {
num_blocks = num_aligned_m_blocks * kNumNBlocks;
this->grouped_layout = grouped_layout;
} else if (kGemmType == GemmType::GroupedMasked) {
curr_group_idx = curr_cumsum = 0;
this->grouped_layout = grouped_layout;
}
}
// Returns whether row `m_offset` of M block `m_block_idx` holds data that must be computed:
// always for `Normal`, a non-negative layout entry for `GroupedContiguous`, and a row index
// below the current group's layout entry for `GroupedMasked`.
// ReSharper disable once CppNotAllPathsReturnValue
__device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
if constexpr (kGemmType == GemmType::Normal) {
return true;
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + curr_group_idx);
}
}
// Returns whether the tile last produced by `get_swizzled_block_idx` may use TMA multicast.
// A swizzle group of size 1 never can. For `GroupedContiguous` without multicast on A, the
// layout entries of M blocks `m_block_idx` and `m_block_idx ^ 1` must also be equal.
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::GroupedMasked) {
return true;
} else {
DG_STATIC_ASSERT(kGemmType == GemmType::GroupedContiguous, "Invalid Gemm type");
if constexpr (kIsTMAMulticastOnA) {
return true;
} else {
auto group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
auto peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
return group_idx == peer_group_idx;
}
}
}
// Converts the linear tile index `block_idx` into `(m_block_idx, n_block_idx)`.
// Tiles are enumerated in groups of up to `kNum1DBlocksPerGroup` blocks along the primary
// dimension (N when multicasting on A, otherwise M). Also updates `num_blocks_in_group`.
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& num_m_blocks, const uint32_t& block_idx,
uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumTMAMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
auto primary_num_blocks = kIsTMAMulticastOnA ? kNumNBlocks : num_m_blocks;
auto secondary_num_blocks = kIsTMAMulticastOnA ? num_m_blocks : kNumNBlocks;
auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
auto group_idx = block_idx / num_blocks_per_group;
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
// The last group may be narrower than `kNum1DBlocksPerGroup`
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
// Fix unaligned TMA multicast
// An odd-sized group is split into an even-sized part (`num_blocks_in_group ^ 1`) followed by a
// single leftover block, which is scheduled as its own group of size 1
if (kNumTMAMulticast > 1 and num_blocks_in_group % 2 != 0) {
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
num_blocks_in_group = num_blocks_in_group ^ 1;
} else {
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
first_block_idx += num_blocks_in_group ^ 1;
num_blocks_in_group = 1;
}
}
// Convert to final M/N block indices
if constexpr (kIsTMAMulticastOnA) {
m_block_idx = in_group_idx / num_blocks_in_group;
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
} else {
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
n_block_idx = in_group_idx / num_blocks_in_group;
}
}
// Returns `block_idx * block_size` plus a per-group offset of `shape_dim` elements:
//   - `Normal`: no offset
//   - `GroupedContiguous`: offset by the layout entry of `m_block_idx` (clamped at 0), unless
//     `kIgnoreGroupedForGroupedContiguous` is set
//   - `GroupedMasked`: offset by `curr_group_idx`
template <bool kIgnoreGroupedForGroupedContiguous=true>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t& shape_dim, const uint32_t& block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx=0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if constexpr (kGemmType == GemmType::GroupedContiguous) {
auto offset = kIgnoreGroupedForGroupedContiguous ? 0 : max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M));
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::GroupedMasked) {
return curr_group_idx * shape_dim + block_idx * block_size;
}
}
// Advances this CTA to its next tile and writes its indices into the two output references.
// Linear tile indices are strided by `gridDim.x` across calls. Returns `false` once no tile is left.
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * gridDim.x + blockIdx.x;
if constexpr (kGemmType == GemmType::GroupedMasked) {
uint32_t num_m_blocks;
// Skip over whole groups until the one containing `next_block_idx` is reached
while (true) {
// End of the task
if (curr_group_idx == kNumGroups)
return false;
// Within the current group
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + curr_group_idx)), BLOCK_M);
auto current_m_block_cumsum = curr_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * kNumNBlocks)
break;
// Move to check the next group
curr_group_idx ++, curr_cumsum = current_m_block_cumsum;
}
// The tile index passed on is relative to the first tile of the current group
get_swizzled_block_idx(num_m_blocks, next_block_idx - curr_cumsum * kNumNBlocks, m_block_idx, n_block_idx);
} else {
if (next_block_idx >= num_blocks)
return false;
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
is_peer_cta_alive = kNumNBlocks % kNumTMAMulticast == 0 or // Always aligned on N (constant bypass)
num_aligned_m_blocks % kNumTMAMulticast == 0 or // Always aligned on M (constant bypass)
(next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
get_swizzled_block_idx(num_aligned_m_blocks, next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@@ -1,19 +0,0 @@
#pragma once
#include "utils.cuh"
namespace deep_gemm {
// TODO: move this function to other files
// Issues one 2D TMA load of the tile at coordinates `(crd_0, crd_1)` of the tensor described by
// `desc_ptr` into `smem_ptr`, signalling `barrier_ptr`.
// With `num_tma_multicast == 1` every calling CTA issues a plain load. Otherwise only the CTA
// whose rank in the cluster is 0 issues a multicast load, with the CTA mask
// `(1 << num_tma_multicast) - 1`; all other CTAs issue nothing.
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, uint64_t* barrier_ptr, void* smem_ptr,
         int32_t const& crd_0, int32_t const& crd_1, uint32_t num_tma_multicast) {
    constexpr auto cache_hint = static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL);

    // No multicast: load for this CTA only
    if (num_tma_multicast == 1) {
        cute::SM90_TMA_LOAD_2D::copy(desc_ptr, barrier_ptr, cache_hint, smem_ptr, crd_0, crd_1);
        return;
    }

    // Multicast: only the rank-0 CTA of the cluster issues the load
    if (cute::block_rank_in_cluster() != 0)
        return;
    cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, barrier_ptr, (1 << num_tma_multicast) - 1,
                                           cache_hint, smem_ptr, crd_0, crd_1);
}
} // namespace deep_gemm

View File

@@ -1,34 +0,0 @@
#pragma once
#ifdef __CLION_IDE__
// Stand-in that the `printf` macro below is redirected to; only compiled when
// `__CLION_IDE__` is defined. It ignores all of its arguments and just executes `trap;`
// through inline assembly.
// NOTE(review): presumably exists only so the IDE can index `printf` calls in device code - confirm.
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
asm volatile("trap;");
}
#define printf host_device_printf
#endif
// Runtime assertion: when `cond` evaluates to false, prints the file, line and the condition
// text with `printf`, then executes `trap;` through inline assembly.
// The `#ifndef` guard lets an earlier definition of the macro take precedence.
// The `do { ... } while (0)` wrapper makes the macro usable as a single statement.
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
// Compile-time assertion: forwards `cond` and the message `reason` to `static_assert`.
// The `#ifndef` guard lets an earlier definition of the macro take precedence.
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, reason) static_assert(cond, reason)
#endif
// Returns `a / b` rounded up, computed as `(a + b - 1) / b`.
// Usable in constant expressions and from both host and device code.
template <typename T>
__device__ __host__ constexpr T ceil_div(T a, T b) {
    // `auto` keeps the promoted type of the sum, exactly as in a single expression
    const auto biased = a + b - 1;
    return biased / b;
}
// Greatest common divisor by the Euclidean algorithm; `constexpr_gcd(a, 0)` is `a`.
// Usable in constant expressions and from both host and device code.
template <typename T>
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
    if (b == 0)
        return a;
    return constexpr_gcd(b, a % b);
}

View File

@@ -1,2 +0,0 @@
from .compiler import get_nvcc_compiler, build, NVCCCompiler, NVRTCCompiler
from .runtime import Runtime

View File

@@ -1,284 +0,0 @@
import functools
import hashlib
import os
import re
import subprocess
import time
import uuid
from typing import Any, Dict, List, Tuple, Type
import cuda.bindings
import cuda.bindings.nvrtc as nvrtc
from torch.utils.cpp_extension import CUDA_HOME
from . import interleave_ffma
from .runtime import Runtime, RuntimeCache
# Module-level `RuntimeCache` instance; `Compiler.build` queries it with `runtime_cache.get(...)`
# both before compiling (to reuse an existing kernel) and after compiling (to load the new one)
runtime_cache = RuntimeCache()
def hash_to_hex(s: str) -> str:
    """Return the first 12 hex characters of the MD5 digest of `s` (encoded as UTF-8)."""
    digest = hashlib.md5(s.encode('utf-8')).hexdigest()
    return digest[:12]
@functools.lru_cache(maxsize=None)
def get_jit_include_dir() -> str:
    """Return `<directory of this file>/../include` (not normalized); the result is cached."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, '..', 'include')
@functools.lru_cache(maxsize=None)
def get_deep_gemm_version() -> str:
    """Return a 12-character fingerprint of the JIT sources.

    The MD5 covers, in order, every `*.cuh` file directly inside `<include dir>/deep_gemm`
    (sorted by file name) followed by `interleave_ffma.py` next to this module.
    The result is cached, so the files are read only once per process.

    Raises:
        AssertionError: if the include directory does not exist.
    """
    hasher = hashlib.md5()

    def feed(file_path: str) -> None:
        # Hash the raw bytes so that any change to the file changes the fingerprint
        with open(file_path, 'rb') as f:
            hasher.update(f.read())

    # Kernel headers
    include_dir = os.path.join(get_jit_include_dir(), 'deep_gemm')
    assert os.path.exists(include_dir), f'Cannot find GEMM include directory {include_dir}'
    for filename in sorted(os.listdir(include_dir)):
        if filename.endswith('.cuh'):
            feed(os.path.join(include_dir, filename))

    # The SASS post-processing script also affects the produced binaries
    script_dir = os.path.dirname(os.path.realpath(__file__))
    feed(os.path.join(script_dir, 'interleave_ffma.py'))
    return hasher.hexdigest()[0:12]
@functools.lru_cache(maxsize=None)
def get_nvcc_compiler() -> Tuple[str, str]:
    """Locate a usable NVCC binary and return `(path, version)`.

    Candidates are tried in order: the `DG_JIT_NVCC_COMPILER` environment variable (if set
    and non-empty), then `<CUDA_HOME>/bin/nvcc` (if `CUDA_HOME` is known). The first candidate
    that exists on disk is queried with `--version` and decides the outcome.

    Returns:
        The compiler path and its version as a `'<major>.<minor>'` string.

    Raises:
        AssertionError: if the version cannot be parsed or is lower than 12.3.
        RuntimeError: if no candidate exists.
    """
    paths = []
    if os.getenv('DG_JIT_NVCC_COMPILER'):
        paths.append(os.getenv('DG_JIT_NVCC_COMPILER'))
    # `CUDA_HOME` may be `None`; joining it would raise `TypeError` instead of the error below
    if CUDA_HOME is not None:
        paths.append(os.path.join(CUDA_HOME, 'bin', 'nvcc'))

    def as_tuple(version_str: str) -> Tuple[int, ...]:
        # Compare versions numerically: as plain strings, '12.10' sorts before '12.3'
        return tuple(int(part) for part in version_str.split('.'))

    # Try to find the first available NVCC compiler
    least_version_required = '12.3'
    version_pattern = re.compile(r'release (\d+\.\d+)')
    for path in paths:
        if not os.path.exists(path):
            continue
        result = subprocess.run([path, '--version'], stdout=subprocess.PIPE,
                                stderr=subprocess.PIPE, text=True)
        match = version_pattern.search(result.stdout)
        # Check the match before using it, so that a failure reports the intended message
        assert match, f'Cannot get the version of NVCC compiler {path}'
        version = match.group(1)
        assert as_tuple(version) >= as_tuple(least_version_required), \
            f'NVCC {path} version {version} is lower than {least_version_required}'
        return path, version
    raise RuntimeError('Cannot find any available NVCC compiler')
@functools.lru_cache(maxsize=None)
def get_default_user_dir():
    """Return the root directory for JIT artifacts; the result is cached.

    If `DG_JIT_CACHE_DIR` is set, that directory is created when missing and returned.
    Otherwise `~/.deep_gemm` is returned without being created.
    """
    override = os.environ.get('DG_JIT_CACHE_DIR')
    if override is None:
        return os.path.join(os.path.expanduser('~'), '.deep_gemm')
    os.makedirs(override, exist_ok=True)
    return override
@functools.lru_cache(maxsize=None)
def get_tmp_dir():
    """Return the `tmp` sub-directory of the user directory (not created here); cached."""
    user_dir = get_default_user_dir()
    return os.path.join(user_dir, 'tmp')
@functools.lru_cache(maxsize=None)
def get_cache_dir():
    """Return the `cache` sub-directory of the user directory (not created here); cached."""
    user_dir = get_default_user_dir()
    return os.path.join(user_dir, 'cache')
def make_tmp_dir():
    """Create the temporary directory if it does not exist yet and return its path."""
    path = get_tmp_dir()
    os.makedirs(path, exist_ok=True)
    return path
def put(path, data):
    """Write `data` to `path` by way of a uniquely named temporary file.

    The data is first written into the temporary directory and then moved over `path`
    with `os.replace`. `bytes` data is written in binary mode, anything else in text mode.
    """
    staging_dir = make_tmp_dir()
    staging_path = os.path.join(staging_dir, f'file.tmp.{str(uuid.uuid4())}.{hash_to_hex(path)}')
    mode = 'wb' if isinstance(data, bytes) else 'w'
    with open(staging_path, mode) as f:
        f.write(data)
    # Write and do POSIX atomic replace
    os.replace(staging_path, path)
class Compiler:
    """Base class of the JIT compilers.

    Subclasses provide `signature`, `__version__` and `compile`; `build` drives the
    cache lookup, the compilation and the optional SASS post-processing.
    """

    @classmethod
    def signature(cls) -> str:
        """Return a string identifying the compiler; part of the kernel cache key."""
        pass

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the compiler version as `(major, minor)`."""
        pass

    @classmethod
    def compile(cls, name: str, code: str, target_path: str) -> None:
        """Compile `code` and write the resulting CUBIN to `target_path`."""
        pass

    @staticmethod
    def flags() -> List[str]:
        """Return the compiler flags shared by all backends.

        `DG_JIT_OVERRIDE_CPP_STANDARD` overrides the C++ standard (default 20); the mere
        presence of `DG_JIT_PTXAS_VERBOSE` in the environment adds `--verbose` to the
        PTXAS options.
        """
        cpp_standard = int(os.getenv('DG_JIT_OVERRIDE_CPP_STANDARD', 20))
        ptxas_options = '--ptxas-options=--register-usage-level=10'
        if 'DG_JIT_PTXAS_VERBOSE' in os.environ:
            ptxas_options += ',--verbose'
        return [
            f'-std=c++{cpp_standard}',
            ptxas_options,
            # Suppress some unnecessary warnings, such as unused variables for certain `constexpr` branch cases
            '--diag-suppress=39,161,174,177,186,940',
        ]

    @staticmethod
    def include_dirs() -> List[str]:
        """Return the include directories passed to the compiler."""
        return [get_jit_include_dir()]

    @classmethod
    def build(cls, name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
        """Return a runtime for `code`, compiling it only if no cached kernel exists.

        The cache key hashes the kernel name, the library fingerprint, the compiler
        signature, the flags, the SASS-optimization switch and the code itself.
        Setting `DG_JIT_DEBUG` to a non-zero integer prints cache hits and compile times.
        """
        flags = cls.flags()

        # FFMA interleaving is applied for compiler versions up to 12.8 unless disabled by the user
        use_sass_opt = cls.__version__() <= (12, 8) and not int(os.getenv('DG_JIT_DISABLE_FFMA_INTERLEAVE', 0))

        # Derive the kernel's cache directory from its build signature
        signature = f'{name}$${get_deep_gemm_version()}$${cls.signature()}$${flags}$${use_sass_opt}$${code}'
        name = f'kernel.{name}.{hash_to_hex(signature)}'
        path = os.path.join(get_cache_dir(), name)

        # Reuse an already loaded or already compiled kernel when possible
        cached_runtime = runtime_cache.get(path, runtime_cls, name, kwargs)
        if cached_runtime is not None:
            if int(os.getenv('DG_JIT_DEBUG', 0)):
                print(f'Using cached JIT runtime {name} during build')
            return cached_runtime

        # Compile into a temporary file first, so a half-written CUBIN is never visible in the cache
        os.makedirs(path, exist_ok=True)
        cubin_path = os.path.join(path, 'kernel.cubin')
        tmp_cubin_path = os.path.join(make_tmp_dir(), f'nvcc.tmp.{str(uuid.uuid4())}.{hash_to_hex(cubin_path)}.cubin')

        start_time = time.time()
        cls.compile(name, code, tmp_cubin_path)
        elapsed_time = time.time() - start_time
        if int(os.getenv('DG_JIT_DEBUG', 0)):
            print(f'Compilation of JIT runtime {name} took {elapsed_time:.2f} seconds.')

        # Interleave FFMA reuse
        if use_sass_opt:
            interleave_ffma.process(tmp_cubin_path)

        # Atomic replace files
        os.replace(tmp_cubin_path, cubin_path)

        # Put cache and return
        runtime = runtime_cache.get(path, runtime_cls, name, kwargs, force_enable_cache=True)
        assert runtime is not None
        return runtime
class NVCCCompiler(Compiler):
    """`Compiler` backend that invokes the NVCC executable found by `get_nvcc_compiler`."""

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the NVCC version as `(major, minor)`."""
        version = get_nvcc_compiler()[1]
        major, minor = (int(part) for part in version.split('.'))
        return major, minor

    @classmethod
    def signature(cls) -> str:
        """Return `<nvcc path>+<version tuple>`."""
        nvcc_path = get_nvcc_compiler()[0]
        return f'{nvcc_path}+{cls.__version__()}'

    @classmethod
    def flags(cls) -> List[str]:
        """Return the shared flags followed by include paths and the NVCC-specific options."""
        cxx_flags = ['-fPIC', '-O3', '-fconcepts', '-Wno-deprecated-declarations', '-Wno-abi']
        base_flags = super().flags()
        include_flags = [f'-I{d}' for d in cls.include_dirs()]
        return [
            *base_flags,
            *include_flags,
            '-gencode=arch=compute_90a,code=sm_90a',
            '-cubin', '-O3', '--expt-relaxed-constexpr', '--expt-extended-lambda',
            f'--compiler-options={",".join(cxx_flags)}',
        ]

    @classmethod
    def compile(cls, name: str, code: str, target_path: str) -> None:
        """Write `code` to `<cache dir>/<name>/kernel.cu` and compile it into `target_path`.

        The command is printed when `DG_JIT_DEBUG` or `DG_JIT_PRINT_COMPILER_COMMAND` is a
        non-zero integer. On a non-zero exit code the compiler output is printed and an
        `AssertionError` is raised.
        """
        # Write the code
        src_path = os.path.join(get_cache_dir(), name, 'kernel.cu')
        put(src_path, code)

        nvcc_path = get_nvcc_compiler()[0]
        command = [nvcc_path, src_path, '-o', target_path, *cls.flags()]
        if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
            print(f'Compiling JIT runtime {name} with command {command}')

        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
        if result.returncode == 0:
            return
        print(f'NVCC compilation failed: stdout: {result.stdout}, stderr: {result.stderr}')
        assert False, f'Failed to compile {src_path}'
class NVRTCCompiler(Compiler):
    """Compiler backend that builds the CUBIN in-process through the NVRTC bindings."""

    @staticmethod
    def __version__() -> Tuple[int, int]:
        """Return the NVRTC version as `(major, minor)`."""
        res, major, minor = nvrtc.nvrtcVersion()
        if res != nvrtc.nvrtcResult.NVRTC_SUCCESS:
            # Failed to get the actual NVRTC version, use cuda-bindings version instead
            major, minor = map(int, cuda.bindings.__version__.split('.')[:2])
        return major, minor

    @classmethod
    def signature(cls) -> str:
        """Return `nvrtc+<version tuple>`."""
        return f'nvrtc+{cls.__version__()}'

    @staticmethod
    def include_dirs() -> List[str]:
        """
        Return the include directories passed to NVRTC.

        Raises:
            RuntimeError: if `CUDA_HOME` is `None`.
        """
        if CUDA_HOME is None:
            raise RuntimeError('CUDA_HOME is required for NVRTC compilation')
        return [get_jit_include_dir(), os.path.join(CUDA_HOME, 'include')]

    @classmethod
    def flags(cls) -> List[str]:
        """Return the base-class flags plus include paths and the NVRTC-specific options."""
        flags = [*super().flags(), *[f'-I{d}' for d in cls.include_dirs()],
                 '--gpu-architecture=sm_90a', '-default-device']
        # NOTES: PCH is vital for compilation speed
        if cls.__version__() >= (12, 8):
            flags += ['--pch']
            if int(os.getenv('DG_JIT_DEBUG', 0)):
                flags += ['--pch-verbose=true']
        return flags

    @classmethod
    def compile(cls, name: str, code: str, target_path: str) -> None:
        """
        Compile `code` with NVRTC and write the resulting CUBIN to `target_path`.

        The NVRTC program handle is destroyed on every exit path, including failed
        compilations, and the destroy call is not placed inside an `assert` so that it
        still runs when assertions are disabled (`python -O`).

        Raises:
            AssertionError: if any NVRTC call reports a non-success status.
        """
        # Create program
        code_bytes = bytes(code, 'utf-8')
        result, program = nvrtc.nvrtcCreateProgram(
            code_bytes, bytes(name, 'utf-8'), 0, [], [])
        assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to create program: {result}'

        try:
            # Compile
            options = [bytes(flag, 'utf-8') for flag in cls.flags()]
            if int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_JIT_PRINT_COMPILER_COMMAND', 0)):
                print(f'Compiling JIT runtime {name} with options: {options}')
            compile_result = nvrtc.nvrtcCompileProgram(program, len(options), options)[0]

            # Print compiler log
            if int(os.getenv('DG_JIT_DEBUG', 0)) or compile_result != nvrtc.nvrtcResult.NVRTC_SUCCESS:
                result, log_size = nvrtc.nvrtcGetProgramLogSize(program)
                assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log size: {result}'
                log_bytes = bytes(log_size)
                result = nvrtc.nvrtcGetProgramLog(program, log_bytes)[0]
                assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get program log: {result}'
                print(f'Compiler log: {log_bytes.decode("utf-8")}')

            # Exit if failed
            assert compile_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to compile program: {compile_result}'

            # Create CUBIN
            result, cubin_size = nvrtc.nvrtcGetCUBINSize(program)
            assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN size: {result}'
            cubin_bytes = bytes(cubin_size)
            result = nvrtc.nvrtcGetCUBIN(program, cubin_bytes)[0]
            assert result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to get CUBIN: {result}'

            # Write into the file system
            put(target_path, cubin_bytes)
        finally:
            # Destroy handler; runs even if one of the steps above failed
            destroy_result = nvrtc.nvrtcDestroyProgram(program)[0]
        # Only reached when the steps above succeeded, so it cannot mask an earlier error
        assert destroy_result == nvrtc.nvrtcResult.NVRTC_SUCCESS, f'Failed to destroy program: {destroy_result}'
def build(name: str, code: str, runtime_cls: Type[Runtime], kwargs: Dict[str, Any] = None) -> Runtime:
    """
    Build a runtime for `code` with the compiler backend selected by `DG_JIT_USE_NVRTC`.

    A non-zero `DG_JIT_USE_NVRTC` selects `NVRTCCompiler`; otherwise `NVCCCompiler` is used.
    """
    if int(os.getenv('DG_JIT_USE_NVRTC', 0)):
        return NVRTCCompiler.build(name, code, runtime_cls, kwargs)
    return NVCCCompiler.build(name, code, runtime_cls, kwargs)

View File

@@ -1,137 +0,0 @@
import argparse
import mmap
import os
import re
import subprocess
from torch.utils.cpp_extension import CUDA_HOME
def run_cuobjdump(file_path):
    """Run `cuobjdump -sass` from `CUDA_HOME` on `file_path` and return its stdout as text."""
    cuobjdump = f'{CUDA_HOME}/bin/cuobjdump'
    proc = subprocess.run([cuobjdump, '-sass', file_path], capture_output=True, text=True)
    # A non-zero exit code is treated as a hard failure
    assert proc.returncode == 0
    return proc.stdout
def extract_ffma(sass):
    """
    Collect runs of FFMA instruction lines from a SASS dump.

    Every line containing `FFMA` is collected together with the line that follows it.
    A run is closed by the first line that is neither; runs with fewer than 16 collected
    lines are discarded.

    Args:
        sass: the SASS dump as a single string.

    Returns:
        A list of `(name, lines)` tuples, where `name` is `<arch>::<function>` taken from
        the most recent `code for` / `Function :` lines and `lines` is the collected run.
    """
    def text_after(line, marker):
        # NOTES: `str.lstrip(marker)` strips a *set of characters*, not a prefix, and would
        # eat the leading characters of names such as `compute_kernel`
        return line.partition(marker)[2].strip()

    collected = []
    current = []
    arch_name, func_name = 'N/A', 'N/A'
    skip_next_line = False
    for line in sass.splitlines():
        if 'code for' in line:
            arch_name = text_after(line, 'code for')
        elif 'Function :' in line:
            func_name = text_after(line, 'Function :')
        elif 'FFMA' in line:
            current.append(line)
            skip_next_line = True
        elif skip_next_line:
            # The line right after an FFMA line belongs to the same instruction
            current.append(line)
            skip_next_line = False
        else:
            if len(current) >= 16:
                assert len(current) % 2 == 0
                collected.append((f'{arch_name}::{func_name}', current))
            current = []

    if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
        print(f'Found {len(collected)} FFMA segments')
    return collected
def extract_hex_from_line(line):
    """Return the integer value of the first `/* 0x... */` hexadecimal comment found in `line`."""
    found = re.search(r'/\*\s*(0x[0-9a-fA-F]+)\s*\*/', line)
    # Every line passed in is expected to carry such a comment
    assert found
    return int(found[1], base=16)
def validate(m, offset, le_bytes, num_lines):
    """
    Check whether the 16-byte chunks of `m` starting at `offset` equal `le_bytes`.

    The first chunk is asserted to match (the caller located `offset` by searching for it);
    the function returns `False` as soon as any later chunk differs.
    """
    num_chunks = num_lines // 2
    assert len(le_bytes) == num_chunks
    assert m[offset:offset + 16] == le_bytes[0]
    return all(
        m[offset + idx * 16:offset + idx * 16 + 16] == le_bytes[idx]
        for idx in range(1, num_chunks)
    )
def parse_registers(line):
    """
    Return the register names appearing in a SASS line, in order.

    `/* ... */` comments and `;` are removed first. Any whitespace- or comma-separated word
    starting with `R` counts as a register, with everything from the first `.` dropped
    (e.g. `R1.reuse` becomes `R1`).
    """
    without_comments = re.sub(r'/\*.*?\*/', '', line)
    cleaned = without_comments.replace(';', '').strip()
    return [
        word.split('.')[0]
        for token in cleaned.split(',')
        for word in token.split()
        if word.startswith('R')
    ]
def modify_segment(m, name, ffma_lines):
    """
    Patch the encoded FFMA instructions of one segment in place inside `m`.

    Only the first `len(ffma_lines) * 9 // 16` lines (rounded down to an even count) are
    processed. Each instruction occupies two lines, whose hex comments form the low and
    high 64-bit words. For instructions with bit `0x0800000000000000` set in the high word,
    the bits `0x0800200000000000` are flipped when the register picked by
    `parse_registers(...)[-2]` is seen for the first time, or when the previous instruction
    kept that bit set and used the same register.

    Args:
        m: mutable byte buffer supporting `find` and slice assignment (e.g. an mmap).
        name: segment name, used only for the optional debug print.
        ffma_lines: the segment's lines, two per instruction.
    """
    num_lines = (len(ffma_lines) * 9 // 16) // 2 * 2
    assert num_lines % 2 == 0
    num_instructions = num_lines // 2
    reuse_bit = 0x0800000000000000
    reuse_and_yield_bits = 0x0800200000000000

    def pack(low, high):
        # 16-byte little-endian encoding of one instruction
        return low.to_bytes(8, 'little') + high.to_bytes(8, 'little')

    le_bytes, new_le_bytes = [], []
    reused_list = []
    dst_reg_set = set()
    last_reused, last_dst_reg = False, ''
    num_changed = 0
    for idx in range(num_instructions):
        low_line, high_line = ffma_lines[idx * 2], ffma_lines[idx * 2 + 1]
        dst_reg = parse_registers(low_line)[-2]
        low_hex = extract_hex_from_line(low_line)
        high_hex = extract_hex_from_line(high_line)
        le_bytes.append(pack(low_hex, high_hex))

        reused = (high_hex & reuse_bit) != 0
        if reused:
            if dst_reg not in dst_reg_set or (last_reused and dst_reg == last_dst_reg):
                # Modify the `reuse` and `yield` bits
                assert high_hex & reuse_and_yield_bits, f'{hex(high_hex)}'
                high_hex ^= reuse_and_yield_bits
                reused = False
                num_changed += 1
            else:
                reused_list.append(idx)
        dst_reg_set.add(dst_reg)
        new_le_bytes.append(pack(low_hex, high_hex))
        last_reused, last_dst_reg = reused, dst_reg

    if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
        print(f' > segment `{name}` new reused list ({num_changed} changed): {reused_list}')

    # Locate every occurrence of the first instruction, then keep the full matches
    first_chunk = le_bytes[0]
    candidates = []
    position = m.find(first_chunk)
    while position != -1:
        candidates.append(position)
        position = m.find(first_chunk, position + 1)
    offsets = [position for position in candidates if validate(m, position, le_bytes, num_lines)]

    # Overwrite each matched occurrence with the patched encoding
    for offset in offsets:
        for idx, chunk in enumerate(new_le_bytes):
            m[offset + idx * 16:offset + idx * 16 + 16] = chunk
def process(path):
    """
    Patch all FFMA segments of the binary at `path` in place.

    The file is disassembled with `run_cuobjdump`, its FFMA segments are extracted with
    `extract_ffma`, and each one is rewritten through a writable memory map of the file.
    The map is managed by a `with` block so it is closed even if a segment fails to patch.
    """
    if int(os.getenv('DG_JIT_PRINT_REG_REUSE', 0)):
        print(f'Processing {path}')
    output = run_cuobjdump(path)
    segments = extract_ffma(output)
    with open(path, 'r+b') as f:
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_WRITE) as mm:
            for segment in segments:
                modify_segment(mm, *segment)
if __name__ == '__main__':
    # Command-line entry: patch the file given via `--so`
    arg_parser = argparse.ArgumentParser(description='Interleave FFMA reg reuse')
    arg_parser.add_argument('--so', help='Path to the SO file')
    cli_args = arg_parser.parse_args()
    process(cli_args.so)

View File

@@ -1,105 +0,0 @@
import os
import subprocess
import time
import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict, Optional, Type
from torch.utils.cpp_extension import CUDA_HOME
class Runtime:
    """
    Handle to a compiled kernel stored as `kernel.cubin` inside the directory `path`.

    The library and kernel are loaded on the first call and reused afterwards.
    Subclasses are expected to override `generate` and `launch`.
    """

    def __init__(self, path: str) -> None:
        self.path = path
        # Both are filled in by the first `__call__`
        self.lib = None
        self.kernel = None
        assert self.is_path_valid(self.path)

    @staticmethod
    def is_path_valid(path: str) -> bool:
        """Return whether `path` is an existing directory containing `kernel.cubin`."""
        # Exists and is a directory
        if not os.path.exists(path) or not os.path.isdir(path):
            return False

        # Contains all necessary files
        files = ['kernel.cubin']
        return all(os.path.exists(os.path.join(path, file)) for file in files)

    @staticmethod
    def generate(kwargs: Dict[str, Any]) -> str:
        """Return source code for `kwargs`; must be overridden by subclasses."""
        # NOTES: `NotImplemented` is a comparison sentinel, not an exception; raising it
        # produces a confusing `TypeError`
        raise NotImplementedError

    @staticmethod
    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
        """Launch `kernel` with `kwargs`; must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, **kwargs) -> cbd.CUresult:
        """Load the kernel if it has not been loaded yet, then launch it with `kwargs`."""
        # Load CUBIN
        if self.kernel is None:
            start_time = time.time_ns()

            # Load CUBIN
            path = bytes(os.path.join(self.path, 'kernel.cubin'), 'utf-8')
            result, self.lib = cbd.cuLibraryLoadFromFile(path, [], [], 0, [], [], 0)
            assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load library: {result}'

            # Extract the kernel name
            # TODO: use `cuda-bindings` API to do this (requires at least 12.8)
            command = [f'{CUDA_HOME}/bin/cuobjdump', '-symbols', path]
            result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
            assert result.returncode == 0
            # Function symbols containing any of these substrings are not the kernel entry
            illegal_names = ['vprintf', '__instantiate_kernel', '__internal', '__assertfail']
            check_illegal = lambda line: any([name in line for name in illegal_names])
            kernel_names = [line.split()[-1] for line in result.stdout.splitlines()
                            if line.startswith('STT_FUNC') and not check_illegal(line)]
            assert len(kernel_names) == 1, f'Too many kernels in the library: {kernel_names}'

            # Load kernel from the library
            result, self.kernel = cbd.cuLibraryGetKernel(self.lib, bytes(kernel_names[0], encoding='utf-8'))
            assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to load kernel: {result}'

            end_time = time.time_ns()
            elapsed_time = (end_time - start_time) / 1e6
            if int(os.getenv('DG_JIT_DEBUG', 0)):
                print(f'Loading JIT runtime {self.path} took {elapsed_time:.2f} ms.')

        # noinspection PyArgumentList
        return self.launch(self.kernel, kwargs)

    def __del__(self) -> None:
        # Unload the library if it was ever loaded
        if self.lib is not None:
            res = cbd.cuLibraryUnload(self.lib)[0]
            if res != cbd.CUresult.CUDA_SUCCESS:
                raise Exception(f'Failed to unload library {self.path}: {res}')
class RuntimeCache:
    """In-process cache of `Runtime` objects keyed by their directory path."""

    def __init__(self) -> None:
        self.cache = {}

    def __setitem__(self, path: str, runtime: Runtime) -> None:
        self.cache[path] = runtime

    def get(self, path: str, runtime_cls: Type[Runtime],
            name: str = '', kwargs: Dict[str, Any] = None,
            force_enable_cache: bool = False) -> Optional[Runtime]:
        """
        Return the runtime for `path`, or `None` if it is neither cached nor usable from disk.

        A runtime already held in memory is returned directly. Otherwise, a valid on-disk
        directory is wrapped in `runtime_cls` and remembered, unless `DG_JIT_DISABLE_CACHE`
        is non-zero and `force_enable_cache` is false.
        """
        # In Python runtime
        if path in self.cache:
            return self.cache[path]

        # Already compiled
        use_cache = force_enable_cache or not int(os.getenv('DG_JIT_DISABLE_CACHE', 0))
        if not (use_cache and os.path.exists(path) and Runtime.is_path_valid(path)):
            return None

        # Print heuristic for the first time
        if name and (int(os.getenv('DG_JIT_DEBUG', 0)) or int(os.getenv('DG_PRINT_CONFIGS', 0))):
            simplified_kwargs = dict()
            for key, value in (kwargs.items() if kwargs is not None else ()):
                # Replace bulky objects with short type descriptions
                if isinstance(value, torch.Tensor):
                    value = f'torch.Tensor<{value.dtype}>'
                if isinstance(value, cbd.CUtensorMap):
                    value = f'cuda.bindings.driver.CUtensorMap'
                simplified_kwargs[key] = value
            print(f'Put kernel {name} with {simplified_kwargs} into runtime cache')

        loaded = runtime_cls(path)
        self.cache[path] = loaded
        return loaded

View File

@@ -1,14 +0,0 @@
from .gemm import gemm_fp8_fp8_bf16_nt
from .m_grouped_gemm import (
m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
m_grouped_gemm_fp8_fp8_bf16_nt_masked
)
from .wgrad_gemm import (
wgrad_gemm_fp8_fp8_fp32_nt,
k_grouped_wgrad_gemm_fp8_fp8_fp32_nt
)
from .utils import (
ceil_div, set_num_sms, get_num_sms,
get_col_major_tma_aligned_tensor,
get_m_alignment_for_contiguous_layout
)

View File

@@ -1,242 +0,0 @@
import math
import torch
from functools import lru_cache
from typing import Tuple
from ..jit import build
from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .utils import get_num_sms, ceil_div, get_col_major_tma_aligned_tensor, get_m_alignment_for_contiguous_layout
def is_tma_multicast_legal(shape_dim: int, block_dim: int, num_tma_multicast: int, num_sms: int,
                           require_divisible: bool = False) -> bool:
    """
    Decide whether a TMA multicast of width `num_tma_multicast` is allowed.

    The SM count must be a multiple of `num_tma_multicast`. When `require_divisible` is set,
    the number of blocks along the dimension must be a multiple of it as well.
    """
    num_blocks = ceil_div(shape_dim, block_dim)
    blocks_divisible = num_blocks % num_tma_multicast == 0
    if require_divisible and not blocks_divisible:
        return False
    return num_sms % num_tma_multicast == 0
def get_swizzle_mode(block_n: int) -> int:
    """
    Return the largest of 128, 64 or 32 bytes that evenly divides a row of `block_n`
    2-byte elements, or 0 when none does.
    """
    row_bytes = block_n * 2
    return next((mode_bytes for mode_bytes in (128, 64, 32) if row_bytes % mode_bytes == 0), 0)
def get_block_n_padding_for_smem_d(block_n: int) -> int:
    """
    Return the number of padding elements to append to a row of `block_n` 2-byte elements.

    The row length in 4-byte units plus the padding (also in 4-byte units) is congruent
    to 4 modulo 8; the result is converted back to a count of 2-byte elements.
    """
    # NOTES: padding is for solving bank conflicts, but wastes shared memory space
    elem_size, requirement = 2, (4, 8)
    bank_stride = (block_n * elem_size) // 4
    # Python's `%` with a positive modulus is always non-negative, so no sign fix-up is needed
    padding = (requirement[0] - bank_stride) % requirement[1]
    return (padding * 4) // elem_size
def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k: int = 128,
                    is_fp32_out: bool = False, is_wgrad: bool = False) -> Tuple[int, int, int]:
    """
    Compute the shared-memory footprint of a configuration.

    Returns:
        `(smem_size, swizzle_mode, block_n_padding)`; at most one of the swizzle mode and
        the padding is non-zero.
    """
    assert block_k == 128

    # Try swizzle first, as it does not waste shared memory
    swizzle_mode = get_swizzle_mode(block_n)
    block_n_padding = 0
    if swizzle_mode == 0:
        block_n_padding = get_block_n_padding_for_smem_d(block_n)

    # NOTES: `scales_b` in a total manner or per-stage manner
    if is_wgrad:
        smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k
        smem_scales_b = 0
    else:
        smem_scales_b_per_stage = 0
        smem_scales_b = ceil_div(k, block_k) * 4

    out_elem_bytes = 4 if is_fp32_out else 2
    smem_d = block_m * (block_n + block_n_padding) * out_elem_bytes
    # A tile, A scales, B tile and (for wgrad) B scales are all replicated per stage
    smem_per_stage = block_m * block_k + block_m * 4 + block_n * block_k + smem_scales_b_per_stage
    scales_b_copies = 1 if block_k % block_n == 0 else 2
    smem_barrier = (num_stages + int(is_wgrad)) * 8 * 2

    smem_size = (smem_d
                 + num_stages * smem_per_stage
                 + ceil_div(smem_scales_b * scales_b_copies, 8) * 8
                 + smem_barrier)

    # Swizzle and padding are not compatible
    assert int(swizzle_mode > 0) + int(block_n_padding > 0) <= 1

    return smem_size, swizzle_mode, block_n_padding
@lru_cache(maxsize=None)
def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
                     is_grouped_contiguous: bool = False, is_grouped_masked: bool = False,
                     is_fp32_out: bool = False, is_wgrad: bool = False) -> \
        Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]]:
    """
    Heuristically choose the kernel configuration for a problem shape.

    Returns:
        `(num_min_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config)`,
        where `tma_multicast_config` is `(num_multicast, is_multicast_on_a)` and
        `smem_config` is the tuple returned by `get_smem_config`.
    """
    if is_grouped_contiguous:
        block_ms = (get_m_alignment_for_contiguous_layout(), )
    else:
        block_ms = (64, 128, ) + ((256, ) if not is_fp32_out else ())
    block_ns = tuple(range(16, 129, 8)) + ((136, 152, ) if is_wgrad else (144, 160, ))

    # Avoid bank conflicts for FP32 output
    if is_fp32_out:
        block_ns = [x for x in block_ns if x % 16 == 8]

    def get_num_blocks(bm, bn):
        return ceil_div(m, bm) * ceil_div(n, bn) * num_groups

    def get_num_waves(bm, bn):
        return ceil_div(get_num_blocks(bm, bn), num_sms) if bm else None

    def get_last_wave_util(bm, bn):
        # A remainder of zero means the last wave is completely full
        remainder = get_num_blocks(bm, bn) % num_sms
        return num_sms if remainder == 0 else remainder

    # Decide block sizes by waves
    best_block_m, best_block_n = None, None
    for block_m in block_ms:
        for block_n in block_ns:
            # NOTES: the block sizes cannot be too large, so at least one dim less than 128
            if block_m > 128 and block_n > 128:
                continue

            num_waves = get_num_waves(block_m, block_n)
            best_num_waves = get_num_waves(best_block_m, best_block_n)
            if best_block_m is None or best_block_n is None:
                is_better = True
            elif num_waves != best_num_waves:
                is_better = num_waves < best_num_waves
            else:
                # Check last wave utilization
                util = get_last_wave_util(block_m, block_n)
                best_util = get_last_wave_util(best_block_m, best_block_n)
                is_better = util > best_util
                if util == best_util:
                    # Case 1: same `block_m`, smaller `block_n` (wasted)
                    is_better |= block_m == best_block_m and block_n < best_block_n
                    # Case 2: same `block_n`, smaller `block_m` (wasted)
                    is_better |= block_n == best_block_n and block_m < best_block_m
                    # Case 3: different for both `block_m` and `block_n`, `block_n` larger is better
                    is_better |= block_m != best_block_m and block_n > best_block_n
            if is_better:
                best_block_m, best_block_n = block_m, block_n
    assert best_block_m is not None and best_block_n is not None

    # Always pick the longest one
    # NOTES: for double B scales, the best number of stages may be reduced
    sm90_capacity = 232448
    max_num_stages = max(k // 128, 1)
    if 128 % best_block_n != 0 and 128 // math.gcd(128, best_block_n) <= 4:
        # Unrolling both stages and `num_former_iters` will cause large code size
        stage_pool = (4, 3, 2, 1)
    else:
        stage_pool = (8, 7, 6, 5, 4, 3, 2, 1)
    stage_candidates = tuple(s for s in stage_pool if s <= max_num_stages)

    best_num_stages, best_smem_config = None, None
    for num_stages in stage_candidates:
        best_smem_config = get_smem_config(num_stages, k, best_block_m, best_block_n, is_fp32_out=is_fp32_out, is_wgrad=is_wgrad)
        if best_smem_config[0] <= sm90_capacity:
            best_num_stages = num_stages
            break
    assert best_smem_config is not None
    assert best_num_stages is not None

    # Decide the number of TMA multicasts and whether broadcast on A
    best_tma_multicast_config = (1, True)

    # Try to multicast on the larger block side first
    # NOTES: currently, grouped masked GEMM only supports multicast on A and requires the number of blocks in the N-direction to be even
    is_multicast_legal = {
        'A': is_tma_multicast_legal(n, best_block_n, 2, num_sms, is_grouped_masked),
        'B': is_tma_multicast_legal(m, best_block_m, 2, num_sms) and not is_grouped_masked,
    }
    side_order = ('A', 'B') if best_block_m > best_block_n else ('B', 'A')
    for side in side_order:
        if m >= 512 and is_multicast_legal[side]:
            best_tma_multicast_config = (2, side == 'A')
            break

    # Recompute the minimal number of SMs required
    # NOTES: less L2 cache usage and less GPU frequency drop
    num_waves = get_num_waves(best_block_m, best_block_n)
    num_min_sms = ceil_div(get_num_blocks(best_block_m, best_block_n), num_waves)
    num_min_sms = ceil_div(num_min_sms, best_tma_multicast_config[0]) * best_tma_multicast_config[0]
    assert num_min_sms <= num_sms

    return num_min_sms, best_block_m, best_block_n, best_num_stages, best_tma_multicast_config, best_smem_config
def gemm_fp8_fp8_bf16_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                         rhs: Tuple[torch.Tensor, torch.Tensor],
                         out: torch.Tensor) -> None:
    """
    Run a plain GEMM on FP8 inputs producing a BF16 output, using 1x128 scaling for LHS
    and 128x128 scaling for RHS.

    Requirements:
        LHS, RHS and the output must have unit stride in dimension 1 (`stride(1) == 1`).
        The `stride(0)` of LHS and RHS must be a multiple of 16, and that of the output a multiple of 8.
        RHS and the RHS scaling factors must be transposed.
        The LHS scaling tensor must be in a TMA-aligned transposed layout; if it is not, this function
        converts it with a set of slow PyTorch operations.

    Arguments:
        lhs: pair of an FP8 tensor (`torch.float8_e4m3fn`) of shape `[m, k]` and
             an FP32 1x128 scaling tensor of shape `[m, ⌈k / 128⌉]`.
        rhs: pair of an FP8 tensor (`torch.float8_e4m3fn`) of shape `[n, k]` and
             an FP32 128x128 scaling tensor of shape `[⌈n / 128⌉, ⌈k / 128⌉]`.
        out: BF16 tensor of shape `[m, n]` that receives the result.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    n, k_ = rhs.shape
    m_, n_ = out.shape

    # Shape checks
    assert m == m_ and n == n_ and k == k_
    assert n > 0 and k > 0
    assert lhs_scales.shape == (m, ceil_div(k, 128))
    assert rhs_scales.shape == (ceil_div(n, 128), ceil_div(k, 128))

    # Type and layout checks
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1

    # LHS scales must be transposed for TMA loads, but not for RHS scales
    # NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Do nothing if `m` is zero
    if m == 0:
        return

    # K must be aligned to 128
    aligned_k = ceil_div(k, 128) * 128

    # Auto-tuning with compilation
    num_sms = get_num_sms()
    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(m, n, k, 1, num_sms)
    smem_size, swizzle_d_mode, block_n_padding = smem_config
    num_tma_multicast, is_tma_multicast_on_a = tma_multicast_config
    block_k = 128
    num_tma_threads = 128
    num_math_threads_per_group = 128

    gemm_type = GemmType.Normal
    tensor_map_a = make_2d_tma_a_desc(gemm_type, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
    tensor_map_b = make_2d_tma_b_desc(gemm_type, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
    tensor_map_d = make_2d_tma_d_desc(gemm_type, out, m, n, out.stride(0), block_m, block_n, 1, swizzle_d_mode)
    tensor_map_scales_a = make_2d_tma_scales_desc(gemm_type, lhs_scales, m, k, block_m, block_k, 1)

    # Templated arguments
    template_args = {
        'GEMM_TYPE': gemm_type,
        'NUM_TMA_THREADS': num_tma_threads,
        'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
        'M': m, 'N': n, 'K': aligned_k,
        'NUM_GROUPS': 1,
        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
        'SWIZZLE_D_MODE': swizzle_d_mode,
        'BLOCK_N_PADDING': block_n_padding,
        'NUM_STAGES': num_stages,
        'NUM_TMA_MULTICAST': num_tma_multicast,
        'IS_TMA_MULTICAST_ON_A': is_tma_multicast_on_a,
    }
    # Runtime arguments
    runtime_args = {
        'SCALES_B': rhs_scales,
        'GROUPED_LAYOUT': torch.empty(0, dtype=torch.int32, device=out.device),
        'NUM_SMS': num_sms,
        'SMEM_SIZE': smem_size,
        'TENSOR_MAP_A': tensor_map_a,
        'TENSOR_MAP_B': tensor_map_b,
        'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
        'TENSOR_MAP_D': tensor_map_d,
        'STREAM': torch.cuda.current_stream().cuda_stream,
        'DEVICE_INDEX': out.device.index,
    }
    kwargs = {**template_args, **runtime_args}

    # Generate, build and run the kernel
    code = FP8GemmRuntime.generate(kwargs)
    runtime = build('gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
    runtime(**kwargs)

View File

@@ -1,205 +0,0 @@
import torch
from typing import Tuple
from ..jit import build
from .gemm import get_best_configs
from .runtime import (
FP8GemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .utils import ceil_div, get_col_major_tma_aligned_tensor, get_num_sms
def m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(lhs: Tuple[torch.Tensor, torch.Tensor],
                                              rhs: Tuple[torch.Tensor, torch.Tensor],
                                              out: torch.Tensor, m_indices: torch.Tensor) -> None:
    """
    Run a grouped GEMM (contiguous layout) on FP8 inputs producing a BF16 output, using
    1x128 scaling for LHS and 128x128 scaling for RHS.

    Requirements:
        LHS, RHS, the RHS scaling factors and the output must be contiguous.
        RHS and the RHS scaling factors must be transposed.
        The LHS scaling tensor must be in a TMA-aligned transposed layout; if it is not, this function
        converts it with a set of slow PyTorch operations.
        Along the M axis, rows are grouped into batches whose sizes are aligned to
        `get_m_alignment_for_contiguous_layout()` (128).

    Arguments:
        lhs: pair of an FP8 tensor (`torch.float8_e4m3fn`) of shape `[m_sum, k]` and
             an FP32 1x128 scaling tensor of shape `[m_sum, ⌈k / 128⌉]`.
        rhs: pair of an FP8 tensor (`torch.float8_e4m3fn`) of shape `[num_groups, n, k]` and
             an FP32 128x128 scaling tensor of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
        out: BF16 tensor of shape `[m_sum, n]` that receives the result.
        m_indices: `torch.int` tensor of shape `[m_sum]`; `m_indices[i]` is the group of the i-th LHS row,
            i.e. that row is multiplied with `rhs[m_indices[i]]`.
            All values within one M-alignment block must be equal.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    m, k = lhs.shape
    num_groups, n, k_ = rhs.shape
    m_, n_ = out.shape
    m__ = m_indices.numel()

    # Shape checks
    assert m == m_ == m__ and k == k_ and n == n_
    assert lhs_scales.shape == (m, ceil_div(k, 128))
    assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))

    # Type and layout checks
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert m_indices.dtype == torch.int32
    assert lhs.is_contiguous() and rhs.is_contiguous()
    assert out.is_contiguous() and m_indices.is_contiguous()

    # LHS scales must be transposed for TMA load, but not for RHS scales
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Do nothing if `m` is zero
    if m == 0:
        return

    # Auto-tuning with compilation
    num_sms = get_num_sms()
    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
        m, n, k, 1, num_sms, is_grouped_contiguous=True)
    smem_size, swizzle_d_mode, block_n_padding = smem_config
    num_tma_multicast, is_tma_multicast_on_a = tma_multicast_config
    block_k = 128
    num_tma_threads = 128
    num_math_threads_per_group = 128

    gemm_type = GemmType.GroupedContiguous
    tensor_map_a = make_2d_tma_a_desc(gemm_type, lhs, m, k, k, block_m, block_k, num_groups)
    tensor_map_b = make_2d_tma_b_desc(gemm_type, rhs, n, k, k, block_n, block_k, num_groups)
    tensor_map_d = make_2d_tma_d_desc(gemm_type, out, m, n, n, block_m, block_n, num_groups, swizzle_d_mode)
    tensor_map_scales_a = make_2d_tma_scales_desc(gemm_type, lhs_scales, m, k, block_m, block_k, num_groups)

    # Templated arguments
    template_args = {
        'NUM_TMA_THREADS': num_tma_threads,
        'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
        'M': m, 'N': n, 'K': k,
        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
        'SWIZZLE_D_MODE': swizzle_d_mode,
        'BLOCK_N_PADDING': block_n_padding,
        'NUM_GROUPS': num_groups,
        'NUM_STAGES': num_stages,
        'NUM_TMA_MULTICAST': num_tma_multicast,
        'IS_TMA_MULTICAST_ON_A': is_tma_multicast_on_a,
        'GEMM_TYPE': gemm_type,
    }
    # Runtime arguments
    runtime_args = {
        'SCALES_B': rhs_scales,
        'GROUPED_LAYOUT': m_indices,
        'NUM_SMS': num_sms,
        'SMEM_SIZE': smem_size,
        'TENSOR_MAP_A': tensor_map_a,
        'TENSOR_MAP_B': tensor_map_b,
        'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
        'TENSOR_MAP_D': tensor_map_d,
        'STREAM': torch.cuda.current_stream().cuda_stream,
        'DEVICE_INDEX': out.device.index,
    }
    kwargs = {**template_args, **runtime_args}

    # Generate, build and run the kernel
    code = FP8GemmRuntime.generate(kwargs)
    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
    runtime(**kwargs)
def m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs: Tuple[torch.Tensor, torch.Tensor],
                                          rhs: Tuple[torch.Tensor, torch.Tensor],
                                          out: torch.Tensor, masked_m: torch.Tensor, expected_m: int) -> None:
    """
    Run a grouped GEMM (masked layout) on FP8 inputs producing a BF16 output, using
    1x128 scaling for LHS and 128x128 scaling for RHS.

    Requirements:
        LHS, RHS, the RHS scaling factors and the output must be contiguous.
        RHS and the RHS scaling factors must be transposed.
        The LHS scaling tensor must be in a TMA-aligned transposed layout; if it is not, this function
        converts it with a set of slow PyTorch operations.
        Unlike the contiguous-layout kernel, each batch must be transposed separately.

    Arguments:
        lhs: pair of an FP8 tensor (`torch.float8_e4m3fn`) of shape `[num_groups, m_max, k]` and
             an FP32 1x128 scaling tensor of shape `[num_groups, m_max, ⌈k / 128⌉]`.
        rhs: pair of an FP8 tensor (`torch.float8_e4m3fn`) of shape `[num_groups, n, k]` and
             an FP32 128x128 scaling tensor of shape `[num_groups, ⌈n / 128⌉, ⌈k / 128⌉]`.
        out: BF16 tensor of shape `[num_groups, m_max, n]` that receives the result.
        masked_m: tensor of shape `[num_groups]`; `masked_m[i]` is the number of rows of `lhs[i]`
            that are actually computed for the i-th group.
        expected_m: CPU-side hint for the expected M of each batch;
            an accurate value may lead to better performance.
    """
    lhs, lhs_scales = lhs
    rhs, rhs_scales = rhs
    num_groups, m, k = lhs.shape
    num_groups_, n, k_ = rhs.shape
    num_groups__, m_, n_ = out.shape
    num_groups___ = masked_m.numel()

    # Shape checks
    assert num_groups == num_groups_ == num_groups__ == num_groups___
    assert m == m_ and n == n_ and k == k_
    assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
    assert lhs_scales.shape == (num_groups, m, ceil_div(k, 128))
    assert rhs_scales.shape == (num_groups, ceil_div(n, 128), ceil_div(k, 128))

    # Type and layout checks
    assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
    assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
    assert out.dtype == torch.bfloat16
    assert masked_m.dtype == torch.int32
    assert lhs.is_contiguous() and rhs.is_contiguous()
    assert out.is_contiguous() and masked_m.is_contiguous()

    # LHS scales must be transposed for TMA load, but not for RHS scales
    lhs_scales = get_col_major_tma_aligned_tensor(lhs_scales)
    assert rhs_scales.is_contiguous()

    # Auto-tuning with compilation
    num_sms = get_num_sms()
    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
        expected_m, n, k, num_groups, num_sms, is_grouped_masked=True)
    smem_size, swizzle_d_mode, block_n_padding = smem_config
    num_tma_multicast, is_tma_multicast_on_a = tma_multicast_config

    # Extra checks for TMA store
    if num_groups > 1 and m > block_m:
        assert m % block_m == 0, f'For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m})'

    block_k = 128
    num_tma_threads = 128
    num_math_threads_per_group = 128

    gemm_type = GemmType.GroupedMasked
    tensor_map_a = make_2d_tma_a_desc(gemm_type, lhs, m, k, k, block_m, block_k, num_groups)
    tensor_map_b = make_2d_tma_b_desc(gemm_type, rhs, n, k, k, block_n, block_k, num_groups)
    tensor_map_d = make_2d_tma_d_desc(gemm_type, out, m, n, n, block_m, block_n, num_groups, swizzle_d_mode)
    tensor_map_scales_a = make_2d_tma_scales_desc(gemm_type, lhs_scales, m, k, block_m, block_k, num_groups)

    # Templated arguments
    template_args = {
        'NUM_TMA_THREADS': num_tma_threads,
        'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
        'M': m, 'N': n, 'K': k,
        'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
        'SWIZZLE_D_MODE': swizzle_d_mode,
        'BLOCK_N_PADDING': block_n_padding,
        'NUM_GROUPS': num_groups,
        'NUM_STAGES': num_stages,
        'NUM_TMA_MULTICAST': num_tma_multicast,
        'IS_TMA_MULTICAST_ON_A': is_tma_multicast_on_a,
        'GEMM_TYPE': gemm_type,
    }
    # Runtime arguments
    runtime_args = {
        'SCALES_B': rhs_scales,
        'GROUPED_LAYOUT': masked_m,
        'NUM_SMS': num_sms,
        'SMEM_SIZE': smem_size,
        'TENSOR_MAP_A': tensor_map_a,
        'TENSOR_MAP_B': tensor_map_b,
        'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
        'TENSOR_MAP_D': tensor_map_d,
        'STREAM': torch.cuda.current_stream().cuda_stream,
        'DEVICE_INDEX': out.device.index,
    }
    kwargs = {**template_args, **runtime_args}

    # Generate, build and run the kernel
    code = FP8GemmRuntime.generate(kwargs)
    runtime = build('m_grouped_gemm_fp8_fp8_bf16_nt', code, FP8GemmRuntime, kwargs)
    runtime(**kwargs)

View File

@@ -1,318 +0,0 @@
import ctypes
import os
import enum
import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict, Tuple
from .utils import get_tma_aligned_size
from ..jit.runtime import Runtime
class GemmType(enum.Enum):
    """Kind of GEMM problem layout; `str()` yields the bare member name."""
    Normal = 0
    GroupedContiguous = 1
    GroupedMasked = 2

    def __str__(self) -> str:
        # Member names coincide with the strings that should be produced
        return self.name
# Maps a torch dtype to the tensor-map data type used when encoding a TMA descriptor.
# NOTES: `torch.int8` / `torch.int16` map to the unsigned 8/16-bit types, and every
# FP8 dtype is described as `UINT8`.
tmap_type_map: Dict[Any, cbd.CUtensorMapDataType] = {
torch.int8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.int16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.int32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT32,
torch.int64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_INT64,
torch.uint8: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.uint16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT16,
torch.uint32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT32,
torch.uint64: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT64,
torch.float32: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT32,
torch.float16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
torch.bfloat16: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_BFLOAT16,
torch.float8_e4m3fn: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e4m3fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
torch.float8_e5m2fnuz: cbd.CUtensorMapDataType.CU_TENSOR_MAP_DATA_TYPE_UINT8,
}
# Maps a swizzle width in bytes (0 meaning no swizzle) to the tensor-map swizzle enum.
swizzle_type_map = {
0: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE,
32: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_32B,
64: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_64B,
128: cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B,
}
def get_num_math_warpgroups(block_m: int) -> int:
    """Return 1 when `block_m` is 64, and 2 for any other value."""
    if block_m == 64:
        return 1
    return 2
def get_num_threads_per_sm(num_tma_threads: int, num_math_threads_per_group: int, block_m: int) -> int:
    """Return the total thread count: all math warpgroups' threads plus the TMA threads."""
    assert num_math_threads_per_group == 128, 'Only support 128 threads per math group'
    num_math_threads = get_num_math_warpgroups(block_m) * num_math_threads_per_group
    return num_math_threads + num_tma_threads
def make_2d_tma_copy_desc(t: torch.Tensor,
                          gmem_dims: Tuple[cbd.cuuint64_t, cbd.cuuint64_t], gmem_outer_stride: cbd.cuuint64_t,
                          smem_dims: Tuple[cbd.cuuint32_t, cbd.cuuint32_t],
                          swizzle_type: cbd.CUtensorMapSwizzle) -> cbd.CUtensorMap:
    """
    Encode a rank-2 tiled tensor map for the memory of `t`.

    Raises:
        Exception: if `cuTensorMapEncodeTiled` does not return `CUDA_SUCCESS`.
    """
    rank = 2
    element_strides = (cbd.cuuint32_t(1), cbd.cuuint32_t(1))
    status, tensor_map = cbd.cuTensorMapEncodeTiled(
        tmap_type_map[t.dtype],
        rank,
        t.data_ptr(),
        gmem_dims,
        (gmem_outer_stride,),
        smem_dims,
        element_strides,
        cbd.CUtensorMapInterleave.CU_TENSOR_MAP_INTERLEAVE_NONE,
        swizzle_type,
        cbd.CUtensorMapL2promotion.CU_TENSOR_MAP_L2_PROMOTION_L2_256B,
        cbd.CUtensorMapFloatOOBfill.CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE,
    )
    if status == cbd.CUresult.CUDA_SUCCESS:
        return tensor_map
    raise Exception(f'Failed to encode tensor map: {status}')
def make_2d_tma_desc(t: torch.Tensor,
                     gmem_inner_dim: int, gmem_outer_dim: int, gmem_outer_stride: int,
                     smem_inner_dim: int, smem_outer_dim: int,
                     swizzle_type: cbd.CUtensorMapSwizzle = cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_128B) -> cbd.CUtensorMap:
    """
    Wrap plain integers into driver integer types and build a 2D tensor map for `t`.

    `gmem_outer_stride` is given in elements and converted to bytes with `t.element_size()`;
    the dimensions are forwarded in (inner, outer) order. Defaults to 128B swizzling.
    """
    global_shape = (cbd.cuuint64_t(gmem_inner_dim), cbd.cuuint64_t(gmem_outer_dim))
    box_shape = (cbd.cuuint32_t(smem_inner_dim), cbd.cuuint32_t(smem_outer_dim))
    outer_stride_bytes = cbd.cuuint64_t(gmem_outer_stride * t.element_size())
    return make_2d_tma_copy_desc(t, global_shape, outer_stride_bytes, box_shape, swizzle_type)
def make_2d_tma_a_desc(gemm_type: GemmType, t: torch.Tensor,
                       shape_m: int, shape_k: int, m_stride: int,
                       block_m: int, block_k: int,
                       num_groups: int) -> cbd.CUtensorMap:
    """
    Build the tensor map for operand A: K is the inner dimension, M the outer one.

    For `GemmType.GroupedMasked` the outer extent is `shape_m * num_groups`; every other
    GEMM type uses `shape_m` alone. The box is `block_k` x `block_m`.
    """
    group_factor = num_groups if gemm_type == GemmType.GroupedMasked else 1
    outer_extent = shape_m * group_factor
    return make_2d_tma_desc(t, shape_k, outer_extent, m_stride, block_k, block_m)
def make_2d_tma_b_desc(gemm_type: GemmType, t: torch.Tensor,
                       shape_n: int, shape_k: int, n_stride: int,
                       block_n: int, block_k: int,
                       num_groups: int) -> cbd.CUtensorMap:
    """
    Build the tensor map for operand B: K is the inner dimension, N the outer one.

    The outer extent is `shape_n * num_groups` for every GEMM type except `GemmType.Normal`,
    which uses `shape_n` alone. The box is `block_k` x `block_n`.
    """
    group_factor = num_groups if gemm_type != GemmType.Normal else 1
    outer_extent = shape_n * group_factor
    return make_2d_tma_desc(t, shape_k, outer_extent, n_stride, block_k, block_n)
def make_2d_tma_d_desc(gemm_type: GemmType, t: torch.Tensor,
                       shape_m: int, shape_n: int, m_stride: int,
                       block_m: int, block_n: int,
                       num_groups: int,
                       swizzle_mode: int) -> cbd.CUtensorMap:
    """
    Build the tensor map for the output D: N is the inner dimension, M the outer one.

    `swizzle_mode` is looked up in `swizzle_type_map`. For `GemmType.GroupedMasked` the outer
    extent is `shape_m * num_groups`, otherwise `shape_m`.
    """
    group_factor = num_groups if gemm_type == GemmType.GroupedMasked else 1
    outer_extent = shape_m * group_factor
    # With swizzling the inner box dim must be at most `kSwizzleDMode` bytes, so
    # `BLOCK_N * sizeof(T) / kSwizzleDMode` TMA stores are required
    if swizzle_mode == 0:
        inner_box = block_n
    else:
        inner_box = swizzle_mode // t.element_size()
    return make_2d_tma_desc(t,
                            shape_n, outer_extent, m_stride,
                            inner_box, block_m,
                            swizzle_type_map[swizzle_mode])
def make_2d_tma_scales_desc(gemm_type: GemmType, t: torch.Tensor,
                            shape_mn: int, shape_k: int,
                            block_mn: int, block_k: int,
                            num_groups: int) -> cbd.CUtensorMap:
    """
    Build the un-swizzled tensor map for a scaling-factor tensor.

    The inner extent (and the outer stride) is `shape_mn` padded by `get_tma_aligned_size`;
    the outer extent is the number of `block_k` blocks covering `shape_k`, multiplied by
    `num_groups` for `GemmType.GroupedMasked`. The box is `block_mn` x 1.
    """
    # Make TMA aligned to 16 bytes
    padded_mn = get_tma_aligned_size(shape_mn, t.element_size())
    group_factor = num_groups if gemm_type == GemmType.GroupedMasked else 1
    num_k_blocks = (shape_k + block_k - 1) // block_k
    return make_2d_tma_desc(t,
                            padded_mn, num_k_blocks * group_factor, padded_mn,
                            block_mn, 1,
                            cbd.CUtensorMapSwizzle.CU_TENSOR_MAP_SWIZZLE_NONE)
# Runtime wrapper for the JIT-built FP8 GEMM kernel: `generate` emits the CUDA source that
# instantiates `fp8_gemm_kernel`, and `launch` configures and launches the compiled kernel.
class FP8GemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path)
@staticmethod
def generate(kwargs: Dict[str, Any]) -> str:
# Build the source text; every template argument of `fp8_gemm_kernel` is read from `kwargs`,
# and `IS_TMA_MULTICAST_ON_A` is rendered as the C++ literal `true` or `false`
code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&fp8_gemm_kernel<
{kwargs['N']},
{kwargs['K']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['BLOCK_N_PADDING']},
{kwargs['SWIZZLE_D_MODE']},
{kwargs['NUM_GROUPS']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'},
GemmType::{kwargs['GEMM_TYPE']}
>);
}};
'''
# Print the generated source when the `DG_JIT_DEBUG` environment variable is a non-zero integer
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated FP8 GEMM code:\n{code}')
return code
# noinspection PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
# Thread counts are fixed here rather than read from `kwargs`
num_tma_threads = 128
num_math_threads_per_group = 128
# Raise the kernel's dynamic shared memory limit to `SMEM_SIZE` on the target device
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
# Cluster dimension launch attribute: `NUM_TMA_MULTICAST` along x, 1 along y and z
attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute()
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
# 1D grid of `NUM_SMS` blocks; the block size comes from `get_num_threads_per_sm`
config = cbd.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = kwargs['NUM_SMS']
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = kwargs['SMEM_SIZE']
config.hStream = kwargs['STREAM']
# Kernel arguments; `arg_values` and `arg_types` are matched by position
arg_values = (
kwargs['SCALES_B'].data_ptr(),
kwargs['GROUPED_LAYOUT'].data_ptr(),
kwargs['M'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_D'],
)
# NOTE(review): the `None` entries line up with the tensor-map arguments; presumably the
# binding derives their type from the objects themselves - confirm against the cuda-python docs
arg_types = (
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_uint32,
None,
None,
None,
None,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)
# Runtime wrapper for the JIT-built FP8 weight-gradient GEMM kernel: `generate` emits the CUDA
# source that instantiates `fp8_wgrad_gemm_kernel`, and `launch` configures and launches it.
class FP8WGradGemmRuntime(Runtime):
def __init__(self, path: str) -> None:
super().__init__(path)
@staticmethod
def generate(kwargs: Dict[str, Any]) -> str:
# Build the source text; every template argument of `fp8_wgrad_gemm_kernel` is read from
# `kwargs`, and `IS_TMA_MULTICAST_ON_A` is rendered as the C++ literal `true` or `false`
code = f'''
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#include <string>
#endif
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <deep_gemm/fp8_wgrad_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&fp8_wgrad_gemm_kernel<
{kwargs['M']},
{kwargs['N']},
{kwargs['BLOCK_M']},
{kwargs['BLOCK_N']},
{kwargs['BLOCK_K']},
{kwargs['NUM_STAGES']},
{kwargs['NUM_LAST_STAGES']},
{kwargs['NUM_TMA_THREADS']},
{kwargs['NUM_MATH_THREADS_PER_GROUP']},
{kwargs['NUM_TMA_MULTICAST']},
{'true' if kwargs['IS_TMA_MULTICAST_ON_A'] else 'false'}
>);
}};
'''
# Print the generated source when the `DG_JIT_DEBUG` environment variable is a non-zero integer
if int(os.getenv('DG_JIT_DEBUG', 0)):
print(f'Generated FP8 WGrad GEMM code:\n{code}')
return code
# noinspection PyMethodOverriding
@staticmethod
def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
# Thread counts are fixed here rather than read from `kwargs`
num_tma_threads = 128
num_math_threads_per_group = 128
# Raise the kernel's dynamic shared memory limit to `SMEM_SIZE` on the target device
result = cbd.cuKernelSetAttribute(cbd.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
kwargs['SMEM_SIZE'], kernel, cbd.CUdevice(kwargs['DEVICE_INDEX']))[0]
assert result == cbd.CUresult.CUDA_SUCCESS, f'Failed to set max dynamic shared memory size: {result}'
# Cluster dimension launch attribute: `NUM_TMA_MULTICAST` along x, 1 along y and z
attr_val = cbd.CUlaunchAttributeValue()
attr_val.clusterDim.x = kwargs['NUM_TMA_MULTICAST']
attr_val.clusterDim.y = 1
attr_val.clusterDim.z = 1
attr = cbd.CUlaunchAttribute()
attr.id = cbd.CUlaunchAttributeID.CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION
attr.value = attr_val
# 1D grid of `NUM_SMS` blocks; the block size comes from `get_num_threads_per_sm`
config = cbd.CUlaunchConfig()
config.numAttrs = 1
config.attrs = [attr]
config.gridDimX = kwargs['NUM_SMS']
config.gridDimY = 1
config.gridDimZ = 1
config.blockDimX = get_num_threads_per_sm(num_tma_threads, num_math_threads_per_group, kwargs['BLOCK_M'])
config.blockDimY = 1
config.blockDimZ = 1
config.sharedMemBytes = kwargs['SMEM_SIZE']
config.hStream = kwargs['STREAM']
# Kernel arguments; `arg_values` and `arg_types` are matched by position
arg_values = (
kwargs['K'],
kwargs['TENSOR_MAP_A'],
kwargs['TENSOR_MAP_B'],
kwargs['TENSOR_MAP_SCALES_A'],
kwargs['TENSOR_MAP_SCALES_B'],
kwargs['TENSOR_MAP_D'],
)
# NOTE(review): the `None` entries line up with the tensor-map arguments; presumably the
# binding derives their type from the objects themselves - confirm against the cuda-python docs
arg_types = (
ctypes.c_uint32,
None,
None,
None,
None,
None,
)
return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)

View File

@@ -1,109 +0,0 @@
import torch
_num_sms = None
def set_num_sms(num_sms: int) -> None:
    """
    Set the upper bound on the SM count that GEMM kernels may use.

    Arguments:
        num_sms: desired maximum SM count; must be positive and no larger than the
            `multi_processor_count` reported for the `'cuda'` device.
    """
    global _num_sms
    assert num_sms > 0 and num_sms <= torch.cuda.get_device_properties(device='cuda').multi_processor_count
    _num_sms = num_sms
def get_num_sms() -> int:
    """
    Return the current upper bound on the SM count that GEMM kernels may use.

    On first use without a prior `set_num_sms` call, the bound is initialised to the
    `multi_processor_count` of the `'cuda'` device and cached in the module-level `_num_sms`.

    Returns:
        The current maximum SM count.
    """
    global _num_sms
    if _num_sms is not None:
        return _num_sms
    _num_sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
    return _num_sms
def ceil_div(x: int, y: int) -> int:
    """
    Divide two integers, rounding the quotient up.

    Args:
        x: the dividend.
        y: the divisor.

    Returns:
        `(x + y - 1) // y`.
    """
    rounded_up_dividend = x + y - 1
    return rounded_up_dividend // y
def get_m_alignment_for_contiguous_layout():
    """
    Return the per-group M-axis alignment required by the grouped contiguous layout.

    In a contiguous grouped GEMM the LHS is split into batches along M, and each GEMM block
    works with exactly one RHS sub-matrix, so every batch size must align with the block shape.

    Returns:
        The alignment requirement, which is always 128.
    """
    m_alignment = 128
    return m_alignment
def get_tma_aligned_size(x: int, element_size: int) -> int:
    """
    Round an element count up so that it spans a whole number of 16-byte units.

    TMA requires 16-byte aligned global addresses; the LHS scaling tensor is column-major,
    so its M axis is padded with this helper.

    Arguments:
        x: original M-axis length of the LHS scaling tensor, in elements.
        element_size: size of one element in bytes; must divide 16.

    Returns:
        The padded M-axis length, in elements.
    """
    tma_alignment_bytes = 16
    assert tma_alignment_bytes % element_size == 0
    elements_per_unit = tma_alignment_bytes // element_size
    num_units = ceil_div(x, elements_per_unit)
    return num_units * elements_per_unit
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """
    Return `x` in the TMA-aligned, column-major (transposed) layout.

    The input is returned unchanged when its strides already describe that layout; otherwise
    a padded buffer is allocated and `x` is copied into a transposed view of it.

    Arguments:
        x: a 2D or 3D tensor, usually the LHS scaling tensor of a GEMM.

    Returns:
        A tensor with the same shape as `x` whose M axis has stride 1 and whose last axis has
        a stride equal to the 16-byte aligned M length.
    """
    # NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
    assert x.dim() in (2, 3)
    rows, cols = x.shape[-2], x.shape[-1]
    padded_rows = get_tma_aligned_size(rows, x.element_size())
    was_2d = x.dim() == 2
    if was_2d:
        # Already column-major and aligned: nothing to do
        if x.stride(0) == 1 and x.stride(1) == padded_rows:
            return x
        x = x.unsqueeze(0)

    num_batches = x.shape[0]

    # The last kernel gives a column-major TMA aligned layout
    if x.stride(0) == padded_rows * cols and x.stride(1) == 1 and x.stride(2) == padded_rows:
        return x.squeeze(0) if was_2d else x

    # Normal layout requires transposing
    padded_storage = torch.empty((num_batches, cols, padded_rows), device=x.device, dtype=x.dtype)
    transposed_view = torch.transpose(padded_storage, 1, 2)
    transposed_view[:, :rows, :] = x
    result = transposed_view[:, :rows, :]
    return result.squeeze(0) if was_2d else result

View File

@@ -1,158 +0,0 @@
import torch
from typing import List, Tuple
from ..jit import build
from .runtime import (
FP8WGradGemmRuntime, GemmType,
make_2d_tma_a_desc, make_2d_tma_b_desc,
make_2d_tma_d_desc, make_2d_tma_scales_desc)
from .gemm import get_best_configs
from .utils import ceil_div, get_num_sms, get_col_major_tma_aligned_tensor, get_tma_aligned_size
def wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
rhs: Tuple[torch.Tensor, torch.Tensor],
out: torch.Tensor):
"""
Perform a weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
Results will be accumulated into the output tensor.
Requirements:
LHS, RHS, and output tensors must be contiguous in dimension 1, i.e., stride(1) = 1.
The stride(0) of LHS and RHS must be a multiple of 16, and the stride(0) of output must be a multiple of 4.
RHS and RHS scaling factors are required to be transposed.
The LHS scaling and RHS scaling tensor require a TMA-aligned transposed format.
If your input does not match the requirement, this function will do a transposing with a set of slow PyTorch operations.
Arguments:
lhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[m, k]`,
the second element is an FP32 1x128 scaling tensor for LHS of shape `[m, ⌈k / 128⌉]`.
rhs: the first element is an FP8 tensor (typed `torch.float8_e4m3fn`) of shape `[n, k]`,
the second element is an FP32 1x128 scaling tensor for RHS of shape `[n, ⌈k / 128⌉]`.
out: the FP32 output tensor of shape `[m, n]`, which will be accumulated.
"""
# Split each operand into its FP8 data and its scaling factors
lhs, lhs_scales = lhs
rhs, rhs_scales = rhs
m, k = lhs.shape
n, k_ = rhs.shape
m_, n_ = out.shape
# Type and shape checks
assert m == m_ and n == n_ and k == k_
assert n > 0 and m > 0
# Scales are accepted either as `[mn, ⌈k / 128⌉]` or already transposed as `[⌈k / 128⌉, mn]`
assert lhs_scales.shape == (m, ceil_div(k, 128)) or lhs_scales.shape == (ceil_div(k, 128), m)
assert rhs_scales.shape == (n, ceil_div(k, 128)) or rhs_scales.shape == (ceil_div(k, 128), n)
assert lhs.dtype == torch.float8_e4m3fn and lhs_scales.dtype == torch.float32
assert rhs.dtype == torch.float8_e4m3fn and rhs_scales.dtype == torch.float32
assert out.dtype == torch.float
assert lhs.stride(1) == 1 and out.stride(1) == 1 and rhs.stride(1) == 1
# LHS and RHS scales must be transposed for TMA load
# NOTES: `get_col_major_tma_aligned_tensor` may launch a kernel if not processed by previous kernels
def get_valid_scales(scales: torch.Tensor, mn: int):
# Transposed input is only permuted back (no copy) and must already have an aligned stride;
# any other input goes through `get_col_major_tma_aligned_tensor`
if scales.shape == (ceil_div(k, 128), mn):
# For k-grouped GEMMs
scales = scales.permute(1, 0)
assert get_tma_aligned_size(mn, 4) == scales.stride(1) == mn
else:
scales = get_col_major_tma_aligned_tensor(scales)
return scales
lhs_scales = get_valid_scales(lhs_scales, m)
rhs_scales = get_valid_scales(rhs_scales, n)
# Do nothing if `k` is zero
if k == 0:
return
# K must be aligned to 128
aligned_k = ceil_div(k, 128) * 128
# Auto-tuning with compilation
num_sms = get_num_sms()
num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = get_best_configs(
m, n, aligned_k, 1, num_sms, is_fp32_out=True, is_wgrad=True)
# Number of K blocks modulo the stage count, passed to the kernel as `NUM_LAST_STAGES`
num_last_stages = ceil_div(k, 128) % num_stages
block_k = 128
num_tma_threads = 128
num_math_threads_per_group = 128
# TMA descriptors are built with the unpadded `k`, while the kernel template below gets `aligned_k`
tensor_map_a = make_2d_tma_a_desc(GemmType.Normal, lhs, m, k, lhs.stride(0), block_m, block_k, 1)
tensor_map_b = make_2d_tma_b_desc(GemmType.Normal, rhs, n, k, rhs.stride(0), block_n, block_k, 1)
tensor_map_d = make_2d_tma_d_desc(GemmType.Normal, out, m, n, out.stride(0), block_m, block_n, 1, smem_config[1])
tensor_map_scales_a = make_2d_tma_scales_desc(GemmType.Normal, lhs_scales, m, k, block_m, block_k, 1)
tensor_map_scales_b = make_2d_tma_scales_desc(GemmType.Normal, rhs_scales, n, k, block_n, block_k, 1)
kwargs = {
# Templated arguments
'GEMM_TYPE': GemmType.Normal,
'NUM_TMA_THREADS': num_tma_threads,
'NUM_MATH_THREADS_PER_GROUP': num_math_threads_per_group,
'M': m, 'N': n, 'K': aligned_k,
'NUM_GROUPS': 1,
'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k,
'NUM_STAGES': num_stages,
'NUM_LAST_STAGES': num_last_stages,
'NUM_TMA_MULTICAST': tma_multicast_config[0],
'IS_TMA_MULTICAST_ON_A': tma_multicast_config[1],
# Runtime arguments
'NUM_SMS': num_sms,
'SMEM_SIZE': smem_config[0],
'TENSOR_MAP_A': tensor_map_a,
'TENSOR_MAP_B': tensor_map_b,
'TENSOR_MAP_SCALES_A': tensor_map_scales_a,
'TENSOR_MAP_SCALES_B': tensor_map_scales_b,
'TENSOR_MAP_D': tensor_map_d,
'STREAM': torch.cuda.current_stream().cuda_stream,
'DEVICE_INDEX': out.device.index
}
# Generate, build and run the kernel
code = FP8WGradGemmRuntime.generate(kwargs)
runtime = build('wgrad_gemm_fp8_fp8_fp32_nt', code, FP8WGradGemmRuntime, kwargs)
runtime(**kwargs)
def k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(lhs: Tuple[torch.Tensor, torch.Tensor],
                                         rhs: Tuple[torch.Tensor, torch.Tensor],
                                         out: torch.Tensor,
                                         batch_sizes: List[int]):
    """
    Perform a k-grouped weight gradient GEMM with FP8 inputs and FP32 output, with 1x128 LHS scaling and 1x128 RHS scaling.
    Results will be accumulated into the output tensor.

    Requirements:
        This function handles multiple batches with varying k-dimensions, processing each batch sequentially.
        Each batch's LHS, RHS, and output tensors must be contiguous.
        The RHS and RHS scaling factors are required to be transposed.
        The LHS scaling and RHS scaling tensors require a TMA-aligned transposed format.

    Arguments:
        lhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of LHS data,
             and the flattened shape is `[sum(m * k for k in batch_sizes)]`, where m is the number of rows.
             The second element is an FP32 scaling tensor for LHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), m]`,
             representing the per-128-channel scaling factors.
        rhs: The first element is a flattened FP8 tensor (typed `torch.float8_e4m3fn`) containing all batches of RHS data,
             and the flattened shape is `[sum(n * k for k in batch_sizes)]`, where n is the number of rows.
             The second element is an FP32 scaling tensor for RHS with shape `[sum(⌈k / 128⌉ for k in batch_sizes), n]`,
             representing the per-128-channel scaling factors.
        out: The FP32 output tensor of shape `[num_batches, m, n]`, which will be accumulated.
        batch_sizes: A list of integers specifying the k-dimension for each batch; its length must equal `num_batches`.

    Raises:
        AssertionError: if `len(batch_sizes)` differs from `out.shape[0]`.
    """
    lhs, lhs_scales = lhs[0].view(-1), lhs[1]
    rhs, rhs_scales = rhs[0].view(-1), rhs[1]
    num_batches, m, n = out.shape

    # Check up front: a short list would otherwise fail with `IndexError` only after earlier
    # batches had already been accumulated into `out`, and a long list would be silently truncated
    assert len(batch_sizes) == num_batches, \
        f'Expected {num_batches} batch sizes, but got {len(batch_sizes)}'

    lhs_offset, rhs_offset, scales_offset = 0, 0, 0
    for i in range(num_batches):
        k = batch_sizes[i]
        num_k_blocks = ceil_div(k, 128)

        # Carve this batch out of the flattened operands and the stacked scaling factors
        lhs_slice = lhs[lhs_offset:lhs_offset + m * k].view(m, k)
        rhs_slice = rhs[rhs_offset:rhs_offset + n * k].view(n, k)
        lhs_scales_slice = lhs_scales[scales_offset:scales_offset + num_k_blocks]
        rhs_scales_slice = rhs_scales[scales_offset:scales_offset + num_k_blocks]
        wgrad_gemm_fp8_fp8_fp32_nt((lhs_slice, lhs_scales_slice), (rhs_slice, rhs_scales_slice), out[i])

        lhs_offset += m * k
        rhs_offset += n * k
        scales_offset += num_k_blocks

View File

@@ -0,0 +1,3 @@
from . import bench, numeric
from .bench import *
from .numeric import *

View File

@@ -1,8 +1,6 @@
import os
import sys
import time
import torch
import torch.distributed as dist
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
@@ -31,7 +29,7 @@ def bench(fn, num_warmups: int = 5, num_tests: int = 10,
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_tests
return start_event.elapsed_time(end_event) / num_tests / 1e3
class empty_suppress:
@@ -77,8 +75,9 @@ class suppress_stdout_stderr:
self.errnull_file.close()
def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output: bool = False,
trace_path: str = None, barrier_comm_profiling: bool = False, flush_l2: bool = True,
def bench_kineto(fn, kernel_names, num_tests: int = 30,
suppress_kineto_output: bool = False,
trace_path: str = None, flush_l2: bool = True,
with_multiple_kernels: bool = False):
# Conflict with Nsight Systems
using_nsys = int(os.environ.get('DG_NSYS_PROFILING', 0))
@@ -96,12 +95,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
with profiler:
for i in range(2):
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
if barrier_comm_profiling:
lhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
rhs = torch.randn((8192, 8192), dtype=torch.float, device='cuda')
lhs @ rhs
dist.all_reduce(torch.ones(1, dtype=torch.float, device='cuda'))
for _ in range(num_tests):
if flush_l2:
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
@@ -116,7 +109,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
# Parse the profiling table
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tupled = isinstance(kernel_names, tuple)
is_tuple = isinstance(kernel_names, tuple)
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
assert all([isinstance(name, str) for name in kernel_names])
@@ -145,21 +138,4 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30, suppress_kineto_output:
break
kernel_times.append(total_time / total_num)
return tuple(kernel_times) if is_tupled else kernel_times[0]
def calc_diff(x, y):
    """
    Return `1 - 2 * sum(x * y) / sum(x * x + y * y)`, computed in float64.

    The result is 0 for identical inputs and grows as they diverge. When both inputs are
    entirely zero the ratio is 0/0; they are identical, so a zero tensor is returned
    instead of NaN.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # All elements of both tensors are zero
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
def count_bytes(tensors):
    """
    Return the total size in bytes of every tensor in `tensors`.

    Tuple items are traversed recursively; every other item must provide `numel()` and
    `element_size()`.
    """
    def item_bytes(item):
        if isinstance(item, tuple):
            return count_bytes(item)
        return item.numel() * item.element_size()

    return sum(item_bytes(item) for item in tensors)
return tuple(kernel_times) if is_tuple else kernel_times[0]

View File

@@ -0,0 +1,19 @@
import torch
from typing import Iterable
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """
    Return `1 - 2 * sum(x * y) / sum(x * x + y * y)`, computed in float64.

    The result is 0 for identical inputs and grows as they diverge. When both inputs are
    entirely zero the ratio is 0/0; they are identical, so a zero tensor is returned
    instead of NaN.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # All elements of both tensors are zero
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
def count_bytes(*tensors):
    """
    Return the total size in bytes of all tensors passed as arguments.

    Tuples and lists are traversed recursively and `None` entries contribute nothing; every
    other item must provide `numel()` and `element_size()`.
    """
    def item_bytes(item):
        if isinstance(item, (tuple, list)):
            return count_bytes(*item)
        if item is None:
            return 0
        return item.numel() * item.element_size()

    return sum(item_bytes(item) for item in tensors)

View File

@@ -0,0 +1,3 @@
from . import math, layout
from .layout import *
from .math import *

11
deep_gemm/utils/layout.py Normal file
View File

@@ -0,0 +1,11 @@
from deep_gemm_cpp import (
get_tma_aligned_size,
get_mk_alignment_for_contiguous_layout,
get_mn_major_tma_aligned_tensor,
get_mn_major_tma_aligned_packed_ue8m0_tensor,
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
)
# Some aliases
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout
get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout

48
deep_gemm/utils/math.py Normal file
View File

@@ -0,0 +1,48 @@
import torch
from typing import Tuple
def ceil_div(x: int, y: int) -> int:
    """Divide `x` by `y`, rounding the quotient up; computed as `(x + y - 1) // y`."""
    rounded_up_dividend = x + y - 1
    return rounded_up_dividend // y
def align(x: int, y: int) -> int:
    """Round `x` up to a multiple of `y`, using `ceil_div`."""
    num_units = ceil_div(x, y)
    return num_units * y
def ceil_to_ue8m0(x: torch.Tensor):
    """
    Round the magnitude of every element up to a power of two.

    Asserts that the largest element of `x` is positive, then returns
    `2 ** ceil(log2(|x|))` element-wise.
    """
    largest = x.view(-1).amax().item()
    assert largest > 0
    exponents = torch.ceil(torch.log2(x.abs()))
    return torch.pow(2.0, exponents)
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), sf
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(0) % 128 == 0
m, n = x.shape
x_view = x.view(-1, 128, n)
x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf
def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize a 2D tensor to `torch.float8_e4m3fn` with one scale per 128x128 block.

    The input is zero-padded to multiples of 128 in both dimensions before the block
    maxima are taken; the padding is cut off again in the returned tensor.

    Arguments:
        x: 2D tensor of any size.
        use_ue8m0: when true, scales are passed through `ceil_to_ue8m0`.

    Returns:
        `(quantized, scales)`: a contiguous FP8 tensor with the shape of `x`, and the scales
        of shape `[ceil(m / 128), ceil(n / 128)]`.
    """
    assert x.dim() == 2
    num_rows, num_cols = x.shape
    padded = torch.zeros((align(num_rows, 128), align(num_cols, 128)), dtype=x.dtype, device=x.device)
    padded[:num_rows, :num_cols] = x
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scales = block_amax / 448.0
    if use_ue8m0:
        scales = ceil_to_ue8m0(scales)
    quantized = (blocks * (1.0 / scales)).to(torch.float8_e4m3fn)
    unpadded = quantized.view_as(padded)[:num_rows, :num_cols].contiguous()
    return unpadded, scales.view(blocks.size(0), blocks.size(2))

25
develop.sh Executable file
View File

@@ -0,0 +1,25 @@
# Build the extension in place for development and link the resulting shared object
# into the project root. Exits with status 1 if any required step fails.

# Change current directory into project root
original_dir=$(pwd)
script_dir=$(realpath "$(dirname "$0")")
# Stop here if the project root is unreachable, so the `rm -rf` below never runs elsewhere
cd "$script_dir" || exit 1

# Link CUTLASS includes (quoted so that paths containing spaces work)
ln -sf "$script_dir/third-party/cutlass/include/cutlass" deep_gemm/include
ln -sf "$script_dir/third-party/cutlass/include/cute" deep_gemm/include

# Remove old dist files, then build
rm -rf build dist
rm -rf ./*.egg-info
python setup.py build || exit 1

# Find the .so file in build directory and create symlink in current directory
so_file=$(find build -name "*.so" -type f | head -n 1)
if [ -n "$so_file" ]; then
    ln -sf "$so_file" .
else
    echo "Error: No SO file found in build directory" >&2
    exit 1
fi

# Return to the user's original directory
cd "$original_dir"

Binary file not shown.

Before

Width:  |  Height:  |  Size: 558 KiB

View File

@@ -1,8 +0,0 @@
#include "deep_gemm/fp8_gemm.cuh"
#include "deep_gemm/fp8_wgrad_gemm.cuh"
using namespace deep_gemm;
int main() {
return 0;
}

13
install.sh Executable file
View File

@@ -0,0 +1,13 @@
# Build a wheel from a clean tree and install it with pip.
# Exits with status 1 if the project root is unreachable or the wheel build fails.

# Change current directory into project root
original_dir=$(pwd)
script_dir=$(realpath "$(dirname "$0")")
# Stop here if the project root is unreachable, so the `rm -rf` below never runs elsewhere
cd "$script_dir" || exit 1

# Remove old dist files, build, and install
rm -rf build dist
rm -rf ./*.egg-info
# Do not run pip on a missing wheel when the build fails
python setup.py bdist_wheel || exit 1
pip install dist/*.whl

# Return to the user's original directory
cd "$original_dir"

View File

@@ -2,34 +2,28 @@ import os
import setuptools
import shutil
import subprocess
from setuptools import find_packages
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
from torch.utils.cpp_extension import CppExtension, CUDA_HOME
current_dir = os.path.dirname(os.path.realpath(__file__))
jit_include_dirs = ('deep_gemm/include/deep_gemm', )
third_party_include_dirs = (
cxx_flags = ['-std=c++20', '-O3', '-fPIC', '-Wno-psabi']
sources = ['csrc/python_api.cpp']
build_include_dirs = [
f'{CUDA_HOME}/include',
'deep_gemm/include',
'third-party/cutlass/include',
'third-party/fmt/include',
]
build_libraries = ['cuda', 'cudart']
build_library_dirs = [
f'{CUDA_HOME}/lib64',
f'{CUDA_HOME}/lib64/stub'
]
third_party_include_dirs = [
'third-party/cutlass/include/cute',
'third-party/cutlass/include/cutlass',
)
class PostDevelopCommand(develop):
    """`develop` command that also links the third-party include directories into the package."""

    def run(self):
        """Run the regular `develop` step, then create the JIT include symlinks."""
        develop.run(self)
        self.make_jit_include_symlinks()

    @staticmethod
    def make_jit_include_symlinks():
        """
        Link every entry of `third_party_include_dirs` into `deep_gemm/include`.

        An existing link with the same name is replaced. A non-link entry at the destination
        is treated as an error, so real files are never removed.
        """
        # Make symbolic links of third-party include directories
        for d in third_party_include_dirs:
            dirname = d.split('/')[-1]
            src_dir = f'{current_dir}/{d}'
            dst_dir = f'{current_dir}/deep_gemm/include/{dirname}'
            assert os.path.exists(src_dir)
            # `lexists` also reports dangling symlinks, which `exists` treats as missing and
            # which would make `os.symlink` fail with `FileExistsError`
            if os.path.lexists(dst_dir):
                assert os.path.islink(dst_dir)
                os.unlink(dst_dir)
            os.symlink(src_dir, dst_dir, target_is_directory=True)
]
class CustomBuildPy(build_py):
@@ -37,9 +31,21 @@ class CustomBuildPy(build_py):
# First, prepare the include directories
self.prepare_includes()
# Then run the regular build
# Second, write the cluster's default cache settings into `envs.py`
self.generate_default_envs()
# Finally, run the regular build
build_py.run(self)
def generate_default_envs(self):
    """
    Write `deep_gemm/envs.py` into the build directory.

    The generated module defines a `persistent_envs` dict holding the build-time value of
    each of `DG_JIT_CACHE_DIR`, `DG_JIT_PRINT_COMPILER_COMMAND` and
    `DG_JIT_DISABLE_SHORTCUT_CACHE` that is set in the environment; unset names are left out.
    """
    code = '# Pre-installed environment variables\n'
    code += 'persistent_envs = dict()\n'
    for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_DISABLE_SHORTCUT_CACHE'):
        if name in os.environ:
            # `!r` emits a valid Python string literal even when the value contains
            # quotes or backslashes
            code += f"persistent_envs['{name}'] = {os.environ[name]!r}\n"

    with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f:
        f.write(code)
def prepare_includes(self):
# Create temporary build directory instead of modifying package directory
build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include')
@@ -67,19 +73,28 @@ if __name__ == '__main__':
except:
revision = ''
# noinspection PyTypeChecker
setuptools.setup(
name='deep_gemm',
version='1.0.0' + revision,
packages=['deep_gemm', 'deep_gemm/jit', 'deep_gemm/jit_kernels'],
version='2.0.0' + revision,
packages=find_packages('.'),
package_data={
'deep_gemm': [
'include/deep_gemm/*',
'include/deep_gemm/**/*',
'include/cute/**/*',
'include/cutlass/**/*',
]
},
ext_modules=[
CppExtension(name='deep_gemm_cpp',
sources=sources,
include_dirs=build_include_dirs,
libraries=build_libraries,
library_dirs=build_library_dirs,
extra_compile_args=cxx_flags)
],
zip_safe=False,
cmdclass={
'develop': PostDevelopCommand,
'build_py': CustomBuildPy,
},
)

212
tests/generators.py Normal file
View File

@@ -0,0 +1,212 @@
import enum
import random
import torch
from typing import Generator, Tuple, List
from deep_gemm.utils import (
align, ceil_div,
per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8,
get_mk_alignment_for_contiguous_layout
)
class KernelType(enum.Enum):
    """Kernel variants for SM100 GEMMs."""

    # For SM100 GEMMs
    Kernel1D1D = 0
    Kernel1D2D = 1

    def is_1d1d(self):
        """Return whether this member is `Kernel1D1D`."""
        return self is KernelType.Kernel1D1D

    def is_1d2d(self):
        """Return whether this member is `Kernel1D2D`."""
        return self is KernelType.Kernel1D2D
class MajorTypeAB(enum.Enum):
    """Memory-layout major of a GEMM operand: along K, or along M/N."""

    KMajor = 0
    MNMajor = 1

    def is_k_major(self):
        """Return whether this member is `KMajor`."""
        return self is MajorTypeAB.KMajor

    def is_mn_major(self):
        """Return whether this member is `MNMajor`."""
        return self is MajorTypeAB.MNMajor
def get_arch_major() -> int:
    """Return the major part of the compute capability reported by `torch.cuda`."""
    arch_major, _arch_minor = torch.cuda.get_device_capability()
    return arch_major
def get_ue8m0_usage(kernel_type: KernelType) -> bool:
    """Return whether UE8M0 scales are used: never on arch major 9, otherwise only for 1D1D kernels."""
    return get_arch_major() != 9 and kernel_type.is_1d1d()
def get_kernel_types() -> tuple:
    """Return the kernel types to test: only 1D2D on arch major 9, otherwise 1D1D and 1D2D."""
    if get_arch_major() == 9:
        return (KernelType.Kernel1D2D, )
    return (KernelType.Kernel1D1D, KernelType.Kernel1D2D)
def get_out_dtype() -> tuple:
    """Return the output dtypes to test: only BF16 on arch major 9, otherwise BF16 and FP32."""
    if get_arch_major() == 9:
        return (torch.bfloat16, )
    return (torch.bfloat16, torch.float)
def get_major_ab(freeze_a: bool) -> tuple:
    """
    Return the `(major_a, major_b)` layout combinations to test.

    Arch major 9 only gets K-major for both operands. Otherwise B takes both layouts, and A
    does too unless `freeze_a` is set, in which case A stays K-major.
    """
    k_major, mn_major = MajorTypeAB.KMajor, MajorTypeAB.MNMajor
    if get_arch_major() == 9:
        return ((k_major, k_major), )
    a_k_major_combos = ((k_major, k_major), (k_major, mn_major))
    if freeze_a:
        return a_k_major_combos
    return a_k_major_combos + ((mn_major, k_major), (mn_major, mn_major))
def enumerate_normal() -> Generator:
    """
    Yield `(kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype)` cases for normal GEMMs.

    Accumulation is only enumerated for non-BF16 outputs on kernels that are not 1D2D.
    """
    nk_shapes = [(2112, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)]
    for kernel_type in get_kernel_types():
        for m in (128, 4096):
            for n, k in nk_shapes:
                for major_a, major_b in get_major_ab(False):
                    for out_dtype in get_out_dtype():
                        if out_dtype == torch.bfloat16 or kernel_type.is_1d2d():
                            accumulate_options = (False, )
                        else:
                            accumulate_options = (False, True)
                        for accumulate in accumulate_options:
                            yield kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype
def enumerate_m_grouped_contiguous() -> Generator:
    """Yield `(kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b)` cases with A frozen to K-major."""
    group_shapes = ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048))
    for kernel_type in get_kernel_types():
        for num_groups, expected_m_per_group, n, k in group_shapes:
            for major_a, major_b in get_major_ab(True):
                yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b
def enumerate_m_grouped_masked() -> Generator:
    """Yield `(kernel_type, num_groups, max_m, m, n, k)` cases for masked grouped GEMMs; `max_m` is always 4096."""
    max_m = 4096
    group_configs = ((1, 1024), (2, 512), (4, 256))
    nk_shapes = ((4096, 7168), (7168, 2048), )
    for kernel_type in get_kernel_types():
        for num_groups, m in group_configs:
            for n, k in nk_shapes:
                yield kernel_type, num_groups, max_m, m, n, k
def enumerate_k_grouped_contiguous():
    """
    Yield `(num_groups, m, n, ks, expected_k_per_group)` cases for k-grouped GEMMs.

    Each entry of `ks` is drawn from 0.7x to 1.3x of `expected_k_per_group` and rounded up
    to the contiguous-layout alignment. Nothing is yielded on arch major 9.
    """
    # TODO: support SM90 kernels
    if get_arch_major() == 9:
        return []

    # These cases require FP32 accumulation and 1D1D kernels
    configs = (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192),   # EP64
               ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096),   # EP32
               (16, 4096, 7168, 2048), (16, 7168, 2048, 2048))   # EP16
    for num_groups, m, n, expected_k_per_group in configs:
        ks = []
        for _ in range(num_groups):
            sampled_k = int(expected_k_per_group * random.uniform(0.7, 1.3))
            ks.append(align(sampled_k, get_mk_alignment_for_contiguous_layout()))
        yield num_groups, m, n, ks, expected_k_per_group
def enumerate_sf_layout():
    """
    Yield `(mn, k, with_transpose, num_groups)` cases for scaling-factor layout tests.

    Multiple groups are only enumerated with transposition. Grouped cases whose per-group
    element count `mn * ceil_div(k, 128)` is not a multiple of 4 are skipped, as are
    non-transposed cases whose `mn` is not a multiple of 4.
    """
    for with_transpose in (True, False):
        for mn in (4096, 4097, 8192):
            for k in (128, 7168, 7296):
                group_options = (1, 2, 4) if with_transpose else (1, )
                for num_groups in group_options:
                    misaligned_groups = num_groups > 1 and (mn * ceil_div(k, 128)) % 4 != 0
                    misaligned_mn = not with_transpose and mn % 4 != 0
                    if misaligned_groups or misaligned_mn:
                        continue
                    yield mn, k, with_transpose, num_groups
def enumerate_k_grouped_sf_layout():
    """
    Yield `(mn, ks, num_groups)` cases for k-grouped scaling-factor layout tests.

    Each entry of `ks` is drawn from 0.7x to 1.3x of the average K and rounded up to the
    contiguous-layout alignment, which must itself be a multiple of 128.
    """
    alignment = get_mk_alignment_for_contiguous_layout()
    assert alignment % 128 == 0
    group_configs = ((16, 2048), (8, 4096), (72, 384), (128, 256))
    for mn in (4096, 7168):
        for num_groups, avg_k in group_configs:
            ks = []
            for _ in range(num_groups):
                sampled_k = int(random.uniform(0.7, 1.3) * avg_k)
                ks.append(align(sampled_k, alignment))
            yield mn, ks, num_groups
def generate_normal(m: int, n: int, k: int,
                    major_a: MajorTypeAB, major_b: MajorTypeAB,
                    accumulate: bool, out_dtype: torch.dtype,
                    use_ue8m0: bool):
    """
    Create random inputs and the reference output for a normal GEMM test on the CUDA device.

    Returns:
        `(a_fp8, b_fp8, c, d, ref_d)`: the per-token quantized A and per-block quantized B
        (each a `(data, scales)` pair, re-laid out when the requested major is not K-major),
        the accumulation input `c` (the same tensor as `d`, or `None` without accumulation),
        the output buffer `d`, and the reference result computed in FP32 and cast to
        `out_dtype`.
    """
    a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    if accumulate:
        d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32
        c = d
    else:
        d = torch.empty((m, n), device='cuda', dtype=out_dtype)
        c = None
    addend = c if accumulate else 0
    ref_d = (a.float() @ b.float().t() + addend).to(out_dtype)

    a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
    b_fp8 = per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0)
    # Same logical shape, but with the MN axis contiguous in memory
    if not major_a.is_k_major():
        a_fp8 = (a_fp8[0].T.contiguous().T, a_fp8[1])
    if not major_b.is_k_major():
        b_fp8 = (b_fp8[0].T.contiguous().T, b_fp8[1])
    return a_fp8, b_fp8, c, d, ref_d
def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int,
                                  major_a: MajorTypeAB, major_b: MajorTypeAB, use_ue8m0: bool) -> \
        Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build inputs and a reference output for the m-grouped contiguous GEMM test.

    Each group's row count is `expected_m_per_group` scaled by a uniform random
    factor in [0.7, 1.3]; the padded count comes from `align(...)`. Rows added
    by padding get `m_indices == -1` and are zeroed in the reference output.

    Returns:
        `(m, a_fp8, b_fp8, m_indices, d, ref_d)` where `m` is the sum of the
        padded group sizes.
    """
    # Draw all group sizes before aligning any of them, as the original did
    actual_ms = []
    for _ in range(num_groups):
        actual_ms.append(int(expected_m_per_group * random.uniform(0.7, 1.3)))
    aligned_ms = []
    for actual_m in actual_ms:
        aligned_ms.append(align(actual_m, get_mk_alignment_for_contiguous_layout()))
    m = sum(aligned_ms)

    a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
    d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)

    offset = 0
    for group_idx, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
        valid_stop = offset + actual_m
        padded_stop = offset + aligned_m
        m_indices[offset:valid_stop] = group_idx
        m_indices[valid_stop:padded_stop] = -1
        ref_d[offset:padded_stop] = a[offset:padded_stop] @ b[group_idx].t()
        offset = padded_stop
    # Padding rows carry no valid result; force them to zero in the reference
    is_padding_row = (m_indices == -1).unsqueeze(1)
    ref_d = torch.where(is_padding_row, torch.zeros_like(ref_d), ref_d)

    assert major_a.is_k_major()
    a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0)
    b_data = torch.empty_like(b, dtype=torch.float8_e4m3fn)
    b_sf = torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)
    for group_idx in range(num_groups):
        b_data[group_idx], b_sf[group_idx] = per_block_cast_to_fp8(b[group_idx], use_ue8m0=use_ue8m0)
    if major_b.is_k_major():
        b_fp8 = (b_data, b_sf)
    else:
        b_fp8 = (b_data.mT.contiguous().mT, b_sf)
    return m, a_fp8, b_fp8, m_indices, d, ref_d
def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, use_ue8m0: bool) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build inputs and a reference output for the m-grouped masked GEMM test.

    Every group holds `max_m` rows; `masked_m[j]` is `expected_m_per_group`
    scaled by a uniform random factor in [0.7, 1.3] and truncated to `int`.
    The function asserts that no `masked_m` entry exceeds `max_m`.

    Returns:
        `(a_fp8, b_fp8, masked_m, d, ref_d)`.
    """
    a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
    ref_d = torch.einsum('gmk,gnk->gmn', a, b)

    num_k_blocks = ceil_div(k, 128)
    a_data = torch.empty_like(a, dtype=torch.float8_e4m3fn)
    a_sf = torch.empty((num_groups, max_m, num_k_blocks), device='cuda', dtype=torch.float)
    b_data = torch.empty_like(b, dtype=torch.float8_e4m3fn)
    b_sf = torch.empty((num_groups, ceil_div(n, 128), num_k_blocks), device='cuda', dtype=torch.float)
    for group_idx in range(num_groups):
        a_data[group_idx], a_sf[group_idx] = per_token_cast_to_fp8(a[group_idx], use_ue8m0=use_ue8m0)
        b_data[group_idx], b_sf[group_idx] = per_block_cast_to_fp8(b[group_idx], use_ue8m0=use_ue8m0)

    # Draw the per-group valid row counts on the host, then upload them in one go
    host_masked_m = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
    masked_m = torch.tensor(host_masked_m, device='cuda', dtype=torch.int)
    assert masked_m.amax().item() <= max_m
    return (a_data, a_sf), (b_data, b_sf), masked_m, d, ref_d
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, ks: List[int], use_ue8m0: bool):
    """Build inputs and a reference output for the k-grouped contiguous GEMM test.

    `a` is created as `(sum(ks), m)` and `b` as `(sum(ks), n)`; group `i` owns
    the consecutive `ks[i]` rows of both. The reference for group `i` is
    `c[i] + a_i.T @ b_i`. `d` is the same tensor object as `c`.

    Returns:
        `(k, a_fp8, b_fp8, c, d, ref_d)` where `k == sum(ks)`.
    """
    assert get_mk_alignment_for_contiguous_layout() % 128 == 0
    k = sum(ks)
    a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16)
    b = torch.randn((k, n), device='cuda', dtype=torch.bfloat16)
    c = torch.randn((num_groups, m, n), device='cuda', dtype=torch.float) * 32
    d = c
    ref_d = torch.empty_like(c)

    row_begin = 0
    for group_idx, group_k in enumerate(ks):
        rows = slice(row_begin, row_begin + group_k)
        ref_d[group_idx] = c[group_idx] + (a[rows].T @ b[rows])
        row_begin += group_k

    a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
    b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
    return k, a_fp8, b_fp8, c, d, ref_d

View File

@@ -1,297 +1,161 @@
# PyTorch has its own NVRTC, which may have a lower version than the system
# So try to disable PyTorch's NVRTC, or import NVRTC before PyTorch
import cuda.bindings.nvrtc as nvrtc
print(f'NVRTC version: {nvrtc.nvrtcVersion()[1:]}')
import copy
import random
import time
import torch
from typing import List, Tuple
import deep_gemm
from deep_gemm import bench_kineto, calc_diff, ceil_div, get_col_major_tma_aligned_tensor
from deep_gemm.jit_kernels.utils import get_m_alignment_for_contiguous_layout
from deep_gemm.testing import (
bench, bench_kineto,
calc_diff, count_bytes
)
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
pad_size = (128 - (n % 128)) % 128
x = torch.nn.functional.pad(x, (0, pad_size), value=0) if pad_size > 0 else x
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per 128x128 block.

    The input is copied into a zero-filled buffer whose sides are rounded up to
    multiples of 128; the padding is sliced off the returned data.

    Args:
        x: 2-D tensor of shape `(m, n)`.

    Returns:
        `(fp8_data, scales)`: `fp8_data` is a contiguous `(m, n)` tensor of
        dtype `torch.float8_e4m3fn`; `scales` holds `amax / 448` per block
        (with `amax` clamped to at least 1e-4), one row per 128-row block and
        one column per 128-column block.
    """
    assert x.dim() == 2
    m, n = x.shape
    padded_shape = (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128)
    x_padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    # (row_block, row_in_block, col_block, col_in_block)
    blocks = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    fp8_padded = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)
    fp8_data = fp8_padded.view_as(x_padded)[:m, :n].contiguous()
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))
    return fp8_data, scales
def construct(m: int, k: int, n: int) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]:
    """Build FP8 inputs, an output buffer and a BF16 reference for a plain GEMM.

    Returns:
        `(x_fp8, y_fp8, out, ref_out)`; the scale tensor of `x_fp8` has already
        been passed through `get_col_major_tma_aligned_tensor`.
    """
    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    ref_out = x @ y.t()

    x_data, x_scales = per_token_cast_to_fp8(x)
    y_fp8 = per_block_cast_to_fp8(y)
    # Transpose earlier so that the testing will not trigger transposing kernels
    x_fp8 = (x_data, get_col_major_tma_aligned_tensor(x_scales))
    return x_fp8, y_fp8, out, ref_out
def construct_contiguous_grouped(num_groups: int, expected_m_per_group: int, k: int, n: int) -> \
        Tuple[int, Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build inputs and a reference output for the grouped contiguous GEMM test.

    Each group's row count is `expected_m_per_group` scaled by a uniform random
    factor in [0.7, 1.3], then rounded up to a multiple of
    `get_m_alignment_for_contiguous_layout()`. Rows added by the rounding get
    `m_indices == -1` and are zeroed in the reference output.

    Returns:
        `(m, x_fp8, y_fp8, m_indices, out, ref_out)` where `m` is the total
        padded row count.
    """
    alignment = get_m_alignment_for_contiguous_layout()
    group_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)]
    padded_ms = [ceil_div(group_m, alignment) * alignment for group_m in group_ms]
    m = sum(padded_ms)

    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    m_indices = torch.empty(m, device='cuda', dtype=torch.int32)
    out = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
    ref_out = torch.randn((m, n), device='cuda', dtype=torch.bfloat16)

    offset = 0
    for group_idx, (group_m, padded_m) in enumerate(zip(group_ms, padded_ms)):
        valid_stop = offset + group_m
        padded_stop = offset + padded_m
        m_indices[offset:valid_stop] = group_idx
        m_indices[valid_stop:padded_stop] = -1
        ref_out[offset:padded_stop] = x[offset:padded_stop] @ y[group_idx].t()
        offset = padded_stop
    is_padding_row = (m_indices == -1).unsqueeze(1)
    ref_out = torch.where(is_padding_row, torch.zeros_like(ref_out), ref_out)

    assert m % 4 == 0, f'TMA alignment error: {m}'
    x_fp8 = per_token_cast_to_fp8(x)
    y_data = torch.empty_like(y, dtype=torch.float8_e4m3fn)
    y_scales = torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)
    for group_idx in range(num_groups):
        y_data[group_idx], y_scales[group_idx] = per_block_cast_to_fp8(y[group_idx])
    return m, x_fp8, (y_data, y_scales), m_indices, out, ref_out
def construct_masked_grouped(num_groups: int, max_m: int, expected_m_per_group: int, k: int, n: int) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build inputs and a reference output for the grouped masked GEMM test.

    Every group holds `max_m` rows (asserted to be a multiple of 4);
    `masked_m[j]` is `expected_m_per_group` scaled by a uniform random factor
    in [0.7, 1.3] and truncated to `int`, asserted not to exceed `max_m`.

    Returns:
        `(x_fp8, y_fp8, masked_m, out, ref_out)`; the scale tensor of `x_fp8`
        has already been passed through `get_col_major_tma_aligned_tensor`.
    """
    x = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16)
    out = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16)
    ref_out = torch.einsum('gmk,gnk->gmn', x, y)

    assert max_m % 4 == 0, f'TMA alignment error: {max_m}'
    x_data = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    x_scales = torch.empty((num_groups, max_m, k // 128), device='cuda', dtype=torch.float)
    y_data = torch.empty_like(y, dtype=torch.float8_e4m3fn)
    y_scales = torch.empty((num_groups, ceil_div(n, 128), k // 128), device='cuda', dtype=torch.float)
    for group_idx in range(num_groups):
        x_data[group_idx], x_scales[group_idx] = per_token_cast_to_fp8(x[group_idx])
        y_data[group_idx], y_scales[group_idx] = per_block_cast_to_fp8(y[group_idx])

    # Transpose earlier so that the testing will not trigger transposing kernels
    x_fp8 = (x_data, get_col_major_tma_aligned_tensor(x_scales))

    # Construct mask
    masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
    for group_idx in range(num_groups):
        masked_m[group_idx] = int(expected_m_per_group * random.uniform(0.7, 1.3))
    assert masked_m.amax().item() <= max_m
    return x_fp8, (y_data, y_scales), masked_m, out, ref_out
def construct_wgrad(m: int, k: int, n: int) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build inputs and a reference output for the weight-gradient GEMM test.

    `out` starts as a copy of a random FP32 `residual` (scaled by 10); the
    reference is `residual + x @ y.T`, computed in FP32.

    Returns:
        `(x_fp8, y_fp8, residual, out, ref_out)`.
    """
    x = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
    y = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
    residual = torch.randn((m, n), device='cuda', dtype=torch.float) * 10
    out = residual.clone()
    product = x.float() @ y.float().t()
    ref_out = residual + product

    # Both operands use the per-token cast here
    x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_token_cast_to_fp8(y)

    # NOTES: please do inplace add on the `out` later
    return x_fp8, y_fp8, residual, out, ref_out
def construct_k_grouped_wgrad(m: int, n: int, k_sizes: List[int]) -> \
        Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor, List[int]]:
    """Build inputs and reference outputs for the k-grouped weight-gradient GEMM test.

    For each `k` in `k_sizes`, a random `(m, k)` chunk and `(n, k)` chunk are
    generated, flattened, and stored back to back in 1-D buffers. The scales of
    each chunk are stored transposed, `ceil_div(k, 128)` rows per group.

    Returns:
        `((x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes)`
        where `out` is zero-initialized and `ref_out[i]` is the FP32 product of
        group `i`'s chunks.
    """
    num_groups, total_k = len(k_sizes), sum(k_sizes)
    x_flat = torch.empty((m * total_k,), device='cuda', dtype=torch.bfloat16)
    y_flat = torch.empty((n * total_k,), device='cuda', dtype=torch.bfloat16)
    out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)
    ref_out = torch.zeros((num_groups, m, n), device='cuda', dtype=torch.float)

    # Fill tensors with data and compute reference output
    x_begin = y_begin = 0
    for group_idx, k in enumerate(k_sizes):
        x_chunk = torch.randn((m, k), device='cuda', dtype=torch.bfloat16)
        y_chunk = torch.randn((n, k), device='cuda', dtype=torch.bfloat16)
        x_span = slice(x_begin, x_begin + m * k)
        y_span = slice(y_begin, y_begin + n * k)
        x_flat[x_span].copy_(x_chunk.flatten())
        y_flat[y_span].copy_(y_chunk.flatten())
        ref_out[group_idx] = x_chunk.float() @ y_chunk.float().t()
        x_begin, y_begin = x_span.stop, y_span.stop

    x_fp8_flat = torch.empty_like(x_flat, dtype=torch.float8_e4m3fn)
    y_fp8_flat = torch.empty_like(y_flat, dtype=torch.float8_e4m3fn)

    total_scale_factors = sum(ceil_div(k, 128) for k in k_sizes)
    x_scales = torch.empty((total_scale_factors, m), device='cuda', dtype=torch.float)
    y_scales = torch.empty((total_scale_factors, n), device='cuda', dtype=torch.float)

    # Cast to FP8 and prepare scale factors
    x_begin = y_begin = scale_begin = 0
    for k in k_sizes:
        x_span = slice(x_begin, x_begin + m * k)
        y_span = slice(y_begin, y_begin + n * k)
        x_fp8_chunk, x_scale_chunk = per_token_cast_to_fp8(x_flat[x_span].view(m, k))
        y_fp8_chunk, y_scale_chunk = per_token_cast_to_fp8(y_flat[y_span].view(n, k))
        x_fp8_flat[x_span].copy_(x_fp8_chunk.flatten())
        y_fp8_flat[y_span].copy_(y_fp8_chunk.flatten())

        num_scales = ceil_div(k, 128)
        scale_span = slice(scale_begin, scale_begin + num_scales)
        x_scales[scale_span].copy_(x_scale_chunk.T)
        y_scales[scale_span].copy_(y_scale_chunk.T)

        x_begin, y_begin, scale_begin = x_span.stop, y_span.stop, scale_span.stop

    return (x_fp8_flat, x_scales), (y_fp8_flat, y_scales), out, ref_out, k_sizes
from generators import (
KernelType, get_ue8m0_usage,
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
)
def test_gemm() -> None:
# Checks FP8 GEMM results against a reference output: every case asserts `calc_diff(...) < 0.001`,
# then timings obtained from `bench_kineto` are printed.
# NOTE(review): this span appears to interleave lines from two revisions of the test (two sets of
# loop headers, two `test_func` definitions, both `x_fp8`/`out` and `a`/`d` naming) and carries no
# indentation — confirm statement order and nesting against the repository before editing.
print('Testing GEMM:')
for m in (64, 128, 4096):
for k, n in [(576, 7168), (7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)]:
x_fp8, y_fp8, out, ref_out = construct(m, k, n)
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal():
# Two-letter layout tag: first letter from `major_a`, second from `major_b`
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
out_opt = 'FP32' if out_dtype == torch.float else 'BF16'
acc_opt = f'acc={int(accumulate)}'
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
# noinspection PyShadowingNames
def test_func():
deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
# `test_alias=False` always calls `fp8_gemm_nt`; `test_alias=True` calls the entry point named after the layout tag
for test_alias in (False, True):
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}'
if test_alias:
a = a if major_a.is_k_major() else (a[0].T, a[1].T)
b = b if major_b.is_k_major() else (b[0].T, b[1].T)
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
diff = calc_diff(d, ref_d)
assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, '
f'{diff:.5f}, alias={test_alias}')
a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, use_ue8m0=use_ue8m0)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
# Test launch overhead
# Host-side wall time around the call is taken before `torch.cuda.synchronize()`
launch_start_t = time.time_ns()
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
launch_end_t = time.time_ns()
torch.cuda.synchronize()
# noinspection PyShadowingNames
def test_func():
deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}):'
f' launch {(launch_end_t - launch_start_t) / 1e3:4.0f} us | {t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_contiguous() -> None:
# Checks the m-grouped contiguous FP8 GEMM against a reference output (rows with `m_indices == -1`
# are zeroed before comparing), then prints timings obtained from `bench_kineto`.
# NOTE(review): this span appears to interleave lines from two revisions of the test (two banner
# prints, two loop headers, two calls inside `test_func`, two perf prints) and carries no
# indentation — confirm statement order and nesting against the repository before editing.
print('Testing grouped contiguous GEMM:')
print('Testing m-grouped contiguous GEMM:')
for num_groups, expected_m_per_group, k, n in ((4, 8192, 7168, 4096), (4, 8192, 2048, 7168),
(8, 4096, 7168, 4096), (8, 4096, 2048, 7168),
(32, 256, 7168, 4096), (32, 256, 2048, 7168)):
# NOTES: we should mask the unfilled part before calculating difference
m, x_fp8, y_fp8, m_indices, out, ref_out = construct_contiguous_grouped(num_groups, expected_m_per_group, k, n)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
out = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(out), out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous():
# Two-letter layout tag: first letter from `major_a`, second from `major_b`
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
# `test_alias=False` always calls the `nt` entry point; `test_alias=True` calls the one named after the layout tag
for test_alias in (False, True):
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
func_name = f"m_grouped_fp8_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous"
if test_alias:
assert major_a.is_k_major()
b = b if major_b.is_k_major() else (b[0].mT, b[1].mT)
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
# Zero the rows marked `-1` in `m_indices` before comparing with the reference
d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d)
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}'
m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0)
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(x_fp8, y_fp8, out, m_indices)
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast)
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
valid_m = (m_indices != -1).sum().item()
print(f' > Perf ({num_groups=:2}, {expected_m_per_group=:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, {kernel_opt}, layout={major_opt}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s')
print()
def test_m_grouped_gemm_masked() -> None:
print('Testing grouped masked GEMM:')
print('Testing m-grouped masked GEMM:')
for num_groups, expected_m_per_group in ((1, 1024), (2, 512), (4, 256)):
for k, n in ((7168, 4096), (2048, 7168), ):
# Test correctness
for i in range(10):
x_fp8, y_fp8, masked_m, out, ref_out = construct_masked_grouped(num_groups, 4096, expected_m_per_group, k, n)
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)
for j in range(num_groups):
diff = calc_diff(out[j, :masked_m[j].item()], ref_out[j, :masked_m[j].item()])
assert diff < 0.001, f'{expected_m_per_group=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}'
# TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease.
for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked():
kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D'
use_ue8m0 = get_ue8m0_usage(kernel_type)
disable_ue8m0_cast = not use_ue8m0
# noinspection PyShadowingNames
def test_func():
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(x_fp8, y_fp8, out, masked_m, expected_m_per_group)
# Test correctness
for i in range(10):
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
for j in range(num_groups):
diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()])
assert diff < 0.001, f'{m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}'
# Test performance with fixed shapes
# noinspection PyUnboundLocalVariable
valid_m = masked_m.sum().item()
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): {t * 1e6:4.0f} us | '
f'throughput: {2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(valid_m * k + num_groups * k * n + valid_m * n * 2) / 1e9 / t:4.0f} GB/s')
# Construct full cases
a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0)
# noinspection PyShadowingNames
def test_func():
deep_gemm.fp8_m_grouped_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast)
# Test performance with fixed shapes
valid_m = masked_m.sum().item()
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): '
f'{t * 1e6:4.0f} us | '
f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s')
print()
def test_wgrad_gemm():
print('Testing weight gradient GEMM:')
def test_k_grouped_gemm_contiguous() -> None:
print('Testing k-grouped contiguous GEMM:')
for k in (4096, 8192):
for m, n in ((7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096), (2048, 7168)):
# Test correctness
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
diff = calc_diff(out, ref_out)
assert diff < 0.001, f'{m=}, {k=}, {n=}, {diff:.5f}'
for num_groups, m, n, ks, expected_k_per_group in enumerate_k_grouped_contiguous():
use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D)
# Construct new tensors only once to avoid L2 cache acceleration (creating them puts them in L2)
x_fp8, y_fp8, residual, out, ref_out = construct_wgrad(m, k, n)
for test_empty_groups in (False, True):
new_ks = copy.deepcopy(ks)
if test_empty_groups:
new_ks[random.randint(0, num_groups - 1)] = 0
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, new_ks, use_ue8m0=use_ue8m0)
new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda')
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c=c)
diff = calc_diff(d, ref_d)
assert diff < 0.001, f'{m=}, {n=}, {k=}, {i=}, {diff:.5f}'
# noinspection PyShadowingNames
def test_func():
deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out)
# Test performance
k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, ks, use_ue8m0=use_ue8m0)
ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True)
print(f' > Performance (m={m:5}, n={n:5}, k={k:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * m * n * k / t / 1e12:4.0f} TFLOPS, '
f'{(m * k + k * n + m * n * 2) / 1e9 / t:4.0f} GB/s')
print()
# noinspection PyShadowingNames
def test_func():
deep_gemm.k_grouped_fp8_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c=c)
def test_k_grouped_wgrad_gemm():
# Checks the k-grouped weight-gradient GEMM per group against a reference output
# (`calc_diff(...) < 0.001`), then prints timings obtained from `bench_kineto`.
# NOTE(review): the last `bench_kineto(..., 'fp8_gemm', ...)` call and the print that follows it
# reference `a`, `b`, `c`, `d` and `k`, none of which are assigned in this span; they look like
# lines from a different test interleaved here. The span also carries no indentation — confirm
# statement order and nesting against the repository before editing.
print('Testing grouped weight gradient GEMM:')
for num_groups, base_k in ((4, 4096), (4, 8192), (8, 4096)):
for m, n in ((7168, 4096), (2048, 7168)):
# Vary k sizes around base_k
# The last size is chosen so that the sizes sum to `base_k * num_groups`
k_sizes = [base_k + random.randint(-1, 1) * 128 for _ in range(num_groups - 1)]
k_sizes.append(base_k * num_groups - sum(k_sizes))
# Test correctness
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
for idx in range(num_groups):
diff = calc_diff(out[idx], ref_out[idx])
assert diff < 0.001, f'{num_groups=}, {m=}, {n=}, k={k_sizes[idx]}, batch={idx}, {diff:.5f}'
# Construct new tensors to avoid L2 cache acceleration
x_fp8, y_fp8, out, ref_out, k_sizes = construct_k_grouped_wgrad(m, n, k_sizes)
total_k = sum(k_sizes)
def test_func():
deep_gemm.k_grouped_wgrad_gemm_fp8_fp8_fp32_nt(x_fp8, y_fp8, out, k_sizes)
# The value returned by `bench_kineto` is multiplied by `num_groups` here
t = bench_kineto(test_func, 'fp8_wgrad_gemm', suppress_kineto_output=True, with_multiple_kernels=True) * num_groups
print(f' > Performance ({num_groups=}, m={m:5}, n={n:5}, avg_k={total_k//num_groups:5}): {t * 1e6:4.0f} us | '
f'throughput: {2 * num_groups * m * n * (total_k/num_groups) / t / 1e12:4.0f} TFLOPS, '
f'{(m * total_k + n * total_k + num_groups * m * n * 2) / 1e9 / t:4.0f} GB/s')
t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True)
print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s')
print()
@@ -307,6 +171,4 @@ if __name__ == '__main__':
test_gemm()
test_m_grouped_gemm_contiguous()
test_m_grouped_gemm_masked()
test_wgrad_gemm()
test_k_grouped_wgrad_gemm()
test_k_grouped_gemm_contiguous()

View File

@@ -1,98 +0,0 @@
import ctypes
import os
import torch
import cuda.bindings.driver as cbd
from typing import Any, Dict
from deep_gemm import jit
# Essential debugging settings
os.environ['DG_JIT_DEBUG'] = os.getenv('DG_JIT_DEBUG', '1')
os.environ['DG_JIT_DISABLE_CACHE'] = os.getenv('DG_JIT_DISABLE_CACHE', '1')
class VectorAddRuntime(jit.Runtime):
    """JIT runtime for a templated element-wise vector-add kernel.

    `generate` renders the CUDA source for element type `kwargs['T']`;
    `launch` starts the compiled kernel over the 1-D tensors `kwargs['A']`,
    `kwargs['B']` and `kwargs['C']` on the stream `kwargs['STREAM']`.
    """

    def __init__(self, path: str) -> None:
        super().__init__(path)

    @staticmethod
    def generate(kwargs: Dict[str, Any]) -> str:
        """Return the kernel source with `vector_add` instantiated for `kwargs['T']`."""
        # Doubled braces are literal braces in the f-string; only `kwargs['T']` is substituted
        source = f"""
#ifdef __CUDACC_RTC__
#include <deep_gemm/nvrtc_std.cuh>
#else
#include <cuda.h>
#endif
#include <cuda_fp8.h>
#include <cuda_bf16.h>
template <typename T>
__global__ void vector_add(T* a, T* b, T* c, uint32_t n) {{
uint32_t i = blockDim.x * blockIdx.x + threadIdx.x;
if (i < n) {{
c[i] = a[i] + b[i];
}}
}}
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&vector_add<{kwargs['T']}>);
}}
"""
        return source

    # noinspection PyShadowingNames,PyMethodOverriding
    @staticmethod
    def launch(kernel: cbd.CUkernel, kwargs: Dict[str, Any]) -> cbd.CUresult:
        """Launch `kernel` with 128 threads per block, enough blocks to cover `A`.

        Asserts that `A`, `B` and `C` share shape and device and that `A` is 1-D.
        Returns the first element of the `cuLaunchKernelEx` result.
        """
        assert kwargs['A'].shape == kwargs['B'].shape == kwargs['C'].shape
        assert kwargs['A'].device == kwargs['B'].device == kwargs['C'].device
        assert kwargs['A'].dim() == 1

        config = cbd.CUlaunchConfig()
        # One thread per element, rounded up to whole 128-thread blocks
        num_blocks = (kwargs['A'].numel() + 127) // 128
        config.gridDimX, config.gridDimY, config.gridDimZ = num_blocks, 1, 1
        config.blockDimX, config.blockDimY, config.blockDimZ = 128, 1, 1
        config.hStream = kwargs['STREAM']

        # (value, ctypes type) pairs in kernel-parameter order
        kernel_args = (
            (kwargs['A'].data_ptr(), ctypes.c_void_p),
            (kwargs['B'].data_ptr(), ctypes.c_void_p),
            (kwargs['C'].data_ptr(), ctypes.c_void_p),
            (kwargs['A'].numel(), ctypes.c_uint32),
        )
        arg_values = tuple(value for value, _ in kernel_args)
        arg_types = tuple(ctype for _, ctype in kernel_args)
        return cbd.cuLaunchKernelEx(config, kernel, (arg_values, arg_types), 0)[0]
if __name__ == '__main__':
    # Render the kernel source once; the same text is built by every compiler below
    kwargs = {'T': 'float'}
    code = VectorAddRuntime.generate(kwargs)
    print('Generated code:')
    print(code)
    print()

    for compiler_name in ('NVCC', 'NVRTC'):
        # Get compiler
        compiler_cls = getattr(jit, f'{compiler_name}Compiler')
        print(f'Compiler: {compiler_name}, version: {compiler_cls.__version__()}')

        # Build
        print('Building ...')
        func = compiler_cls.build('test_func', code, VectorAddRuntime, kwargs)

        # Run and check
        a = torch.randn((1024, ), dtype=torch.float32, device='cuda')
        b = torch.randn((1024, ), dtype=torch.float32, device='cuda')
        c = torch.empty_like(a)
        ret = func(A=a, B=b, C=c, STREAM=torch.cuda.current_stream().cuda_stream)
        assert ret == cbd.CUresult.CUDA_SUCCESS, ret
        # Result must match PyTorch's own element-wise sum
        torch.testing.assert_close(c, a + b)
        print(f'JIT test for {compiler_name} passed\n')

104
tests/test_layout.py Normal file
View File

@@ -0,0 +1,104 @@
import time
import torch
import random
from deep_gemm.testing import bench_kineto, count_bytes
from deep_gemm.utils import (
align, ceil_div,
per_token_cast_to_fp8, per_channel_cast_to_fp8,
get_tma_aligned_size,
get_mn_major_tma_aligned_packed_ue8m0_tensor,
get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor
)
from generators import (
enumerate_sf_layout,
enumerate_k_grouped_sf_layout
)
def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) -> torch.Tensor:
    """Pure-PyTorch reference for packing FP32 scale factors into UE8M0 words.

    Steps, for a 2-D `(mn, k)` or 3-D `(b, mn, k)` FP32 input:
      1. Reinterpret each value as `int`, shift right by 23 and keep the low
         byte as `uint8`.
      2. Zero-pad to `(get_tma_aligned_size(mn, 4), align(k, 4))` and
         reinterpret every 4 consecutive bytes along `k` as one `int`.
      3. Copy into a buffer allocated as `(b, aligned_k // 4, aligned_mn)` and
         viewed through `.mT`, then slice back to `mn` rows.

    Returns:
        An `int` tensor with `mn` rows and `aligned_k // 4` columns per batch;
        the batch dimension is dropped again for 2-D input.
    """
    assert x.dtype == torch.float and x.dim() in (2, 3)

    # First, convert into UE8M0 `uint8_t`
    ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)

    # Second, make padded packed tensors
    mn, k = x.shape[-2:]
    remove_dim = x.dim() == 2
    if remove_dim:
        x = x.unsqueeze(0)
    num_batches = x.shape[0]
    aligned_mn = get_tma_aligned_size(mn, 4)
    aligned_k = align(k, 4)
    packed_k = aligned_k // 4
    padded = torch.zeros((num_batches, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8)
    padded[:, :mn, :k] = ue8m0_tensor
    padded = padded.view(-1).view(dtype=torch.int).view(num_batches, aligned_mn, packed_k)

    # Finally, transpose
    transposed = torch.zeros((num_batches, packed_k, aligned_mn), device=x.device, dtype=torch.int).mT
    transposed.copy_(padded)
    aligned_x = transposed[:, :mn, :]
    if remove_dim:
        return aligned_x.squeeze(0)
    return aligned_x
def test_sf_layout_kernels() -> None:
    """Compare `get_mn_major_tma_aligned_packed_ue8m0_tensor` with the PyTorch reference.

    For every case from `enumerate_sf_layout()`, the packed result must equal
    the reference in values, shape and strides. Launch time and the
    `bench_kineto` timing are then printed.
    """
    print('Testing SF layout kernels:')
    for mn, k, with_transpose, num_groups in enumerate_sf_layout():
        x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda')
        x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=True)
        if num_groups != 1:
            fp32_sf = fp32_sf.view(num_groups, mn, -1)
        if not with_transpose:
            # Same logical tensor, but re-laid-out through a transposed contiguous copy
            fp32_sf = fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2)

        # Correctness
        packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
        ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf)
        assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}'
        assert packed_sf.shape == ref_packed_sf.shape
        assert packed_sf.stride() == ref_packed_sf.stride()

        # Test launch overhead
        launch_start_t = time.time_ns()
        get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)
        launch_end_t = time.time_ns()
        launch_us = (launch_end_t - launch_start_t) / 1e3

        # Performance
        def pack_once():
            return get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf)

        t = bench_kineto(pack_once, 'pack_fp32_into_ue8m0')
        print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}): '
              f'launch {launch_us:3.0f} us | {t * 1e6:4.0f} us | '
              f'{count_bytes(fp32_sf, packed_sf) / 1e9 / t:4.0f} GB/s')
    print()
def test_k_grouped_sf_layout_kernels() -> None:
    """Compare the k-grouped UE8M0 packing against the per-group PyTorch reference.

    For every case from `enumerate_k_grouped_sf_layout()`, the packed output
    and the FP32 scales are split per group (`ceil_div(k, 512)` and `k // 128`
    rows respectively) and each pair is checked for equality. The
    `bench_kineto` timing is then printed.
    """
    print('Testing k-grouped SF layout kernels:')
    for mn, ks, num_groups in enumerate_k_grouped_sf_layout():
        sf_ks = []
        packed_sf_ks = []
        for k in ks:
            sf_ks.append(k // 128)
            packed_sf_ks.append(ceil_div(k, 512))
        ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda')
        x = torch.randn((sum(ks), mn), dtype=torch.bfloat16, device='cuda')
        x, fp32_sf = per_channel_cast_to_fp8(x, use_ue8m0=True)

        # Correctness
        packed_sf = get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks)
        split_packed_sf = packed_sf.split(packed_sf_ks)
        split_fp32_sf = fp32_sf.split(sf_ks)
        for i in range(num_groups):
            # The reference takes the group's scales transposed, and its result is transposed back
            ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(split_fp32_sf[i].T).T
            assert torch.equal(split_packed_sf[i], ref_packed_sf), f'{i=}'

        # Performance
        def pack_once():
            return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks)

        t = bench_kineto(pack_once, 'pack_fp32_into_ue8m0')
        print(f' > Perf ({num_groups=:3}, {mn=:5}, sum_k={sum(ks):5}):'
              f'{t * 1e6:4.0f} us | '
              f'{count_bytes(fp32_sf, packed_sf, ks_tensor) / 1e9 / t:4.0f} GB/s')
    print()
if __name__ == '__main__':
    # Fixed seeds for both PyTorch and the `random` module
    random.seed(1)
    torch.manual_seed(1)

    # Allow TF32 in cuDNN and CUDA matmul
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True

    for run_test in (test_sf_layout_kernels, test_k_grouped_sf_layout_kernels):
        run_test()

1
third-party/fmt vendored Submodule

Submodule third-party/fmt added at 553ec11ec0