[Public release 26/04] Introducing Mega MoE, FP4 Indexer and other features/fixes (#304)

* Merge with private repo

* Update README

* Update README

* Update README

* Add PyTorch requirements

* Fix sync scopes for MQA logits (#256)

* Update README
This commit is contained in:
Chenggang Zhao
2026-04-17 09:45:14 +08:00
committed by GitHub
parent d30fc36c8f
commit 7f2a703ed5
109 changed files with 12101 additions and 3219 deletions

View File

@@ -3,8 +3,7 @@ cmake_minimum_required(VERSION 3.10)
project(deep_gemm LANGUAGES CXX CUDA)
set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi -Wno-deprecated-declarations")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
@@ -18,11 +17,11 @@ find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/targets/x86_64-linux/include ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)
# The main Python API entrance

120
README.md
View File

@@ -1,13 +1,16 @@
# DeepGEMM
DeepGEMM is a library designed for clean and efficient General Matrix Multiplications (GEMMs). It supports FP8 and BF16 (working in progress) for both normal and Mix-of-Experts (MoE) grouped scenarios. Written in CUDA, the library has no kernel compilation need during installation, by compiling all kernels at runtime using a lightweight Just-In-Time (JIT) module.
DeepGEMM is a unified, high-performance tensor core kernel library that brings together the key computation primitives of modern large language models — GEMMs (FP8, FP4, BF16), fused MoE with overlapped communication (Mega MoE), MQA scoring for the lightning indexer, HyperConnection (HC), and more — into a single, cohesive CUDA codebase. All kernels are compiled at runtime via a lightweight Just-In-Time (JIT) module, requiring no CUDA compilation during installation.
DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), it avoids heavy reliance on their templates or algebras. Instead, the library is designed for simplicity, with only a limited number of core kernel functions. This makes it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), but avoids heavy reliance on their templates or algebras. The library is designed for simplicity, with only a limited number of core kernel functions, making it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
## News
- 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more.
- Performance comparison will be posted later.
- Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2.
- Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
@@ -19,27 +22,6 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases).
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
## Roadmap
- [x] More correctness tests for grouped-contiguous layout
- [x] Shared memory swizzling for output
- [x] MoE scheduler with TMA multicast compatibility
- [x] Fix TMA multicast compatibility for indivisible shapes
- [x] Skip useless computation on M
- [x] NVRTC as a faster compiler
- [x] Sanitizer for testing
- [x] Weight gradient kernels for dense models
- [x] Weight gradient kernels for MoE models
- [ ] Better `get_best_configs` modeling
- [ ] CUDA PDL support
- [ ] Larger TMA multicast size for some shapes
- [x] MMA template refactor with CUTLASS
- [x] Remove shape limitations on N and K
- [x] BF16 kernels
- [ ] Split/stream-k optimizations
- [ ] Ampere kernels
- [ ] Polish docs
## Quick start
### Requirements
@@ -65,11 +47,6 @@ cd DeepGEMM
# Link some essential includes and build the CPP JIT module
cat develop.sh
./develop.sh
# Test all GEMM implements
python tests/test_layout.py
python tests/test_attention.py
python tests/test_core.py
```
### Installation
@@ -134,17 +111,47 @@ out_ij = out_ij.sum() # Scalar
For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`.
#### Mega MoE
Mega MoE fuses and overlaps EP dispatch, linear 1 (FP8xFP4), SwiGLU, linear 2 (FP8xFP4), and EP combine into a single mega-kernel, overlapping NVLink communication and tensor core computation. It requires multi-process launch with symmetric memory. Usage:
```python
# Allocate symmetric memory buffer
# NOTES: requires PyTorch >= 2.9
buffer = deep_gemm.get_symm_buffer_for_mega_moe(
group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden
)
# Transform weights (FP4 with UE8M0 SF) into the required layout
transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
# Copy inputs into the buffer before each call
# You may fuse these into previous kernels
buffer.x[:num_tokens].copy_(x_fp8)
buffer.x_sf[:num_tokens].copy_(x_sf)
buffer.topk_idx[:num_tokens].copy_(topk_idx)
buffer.topk_weights[:num_tokens].copy_(topk_weights)
# Run the fused mega MoE kernel
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
deep_gemm.fp8_fp4_mega_moe(y, transformed_l1, transformed_l2, buffer)
```
For the full example with multi-process setup and benchmarking, please refer to `tests/test_mega_moe.py`.
#### Utilities
The library provides some utility functions besides the above kernels:
- `deep_gemm.set_num_sms`: set the maximum SM count to use
- `deep_gemm.get_num_sms`: get the current SM maximum count (return the device SM count if not set)
- `deep_gemm.set_tc_util`: set an approximated tensor core utilization ratio
- `deep_gemm.get_tc_util`: get the current tensor core utilization ratio
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into required layout
- `deep_gemm.set_num_sms` / `get_num_sms`: set/get the maximum SM count to use
- `deep_gemm.set_tc_util` / `get_tc_util`: set/get an approximated tensor core utilization ratio
- `deep_gemm.set_pdl` / `get_pdl`: enable/disable Programmatic Dependent Launch (PDL)
- `deep_gemm.set_mk_alignment_for_contiguous_layout` / `get_mk_alignment_for_contiguous_layout`: set/get the group-level M/K alignment for contiguous layout
- `deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout`: get the theoretical minimum M/K alignment
- `deep_gemm.set_ignore_compile_dims`: configure dimensions to ignore during JIT compilation
- `deep_gemm.set_block_size_multiple_of`: constrain block sizes to be multiples of a given value
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into the required layout
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
- `deep_gemm.get_mk_alignment_for_contiguous_layout`: get the group-level alignment requirement for grouped contiguous layout
- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor
- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0)
- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel
@@ -152,17 +159,30 @@ The library provides some utility functions besides the above kernels:
The library also provides some environment variables, which may be useful:
- General
- `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default
- JIT cache related
- `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default
- NVCC/NVRTC selections
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default
- `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default
- Compiler options
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default
- Heuristic selection
- `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- JIT cache
- `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
- Compiler selection
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default
- `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
- `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
- Compiler output
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
- `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
- `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
- Debug and profiling
- `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
- `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
- `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
- `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
- `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default
- `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
- Build options
- `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default
- `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default
- `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
@@ -173,3 +193,15 @@ DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project
## License
This code repository is released under [the MIT License](LICENSE).
## Citation
```bibtex
@misc{deepgemm2025,
title={DeepGEMM: clean and efficient BLAS kernel library on GPU},
author={Chenggang Zhao and Zhean Xu and Liang Zhao and Jiashi Li and Chenhao Xu and Anyi Xu and Shengyu Liu and Kexing Zhou and Kuai Yu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}},
}
```

View File

@@ -5,9 +5,9 @@
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/smxx_fp8_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_clean_logits.hpp"
#endif
@@ -24,8 +24,8 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
const auto major_a = get_major_type_ab(a.first);
const auto major_b = get_major_type_ab(b.first);
if (fp8_requires_k_major()) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
@@ -35,9 +35,9 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
check_major_type_cd(d);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a.first);
const auto& [n , k_] = get_shape<2>(b.first);
const auto& [m_, n_] = get_shape<2>(d);
const auto [m , k ] = get_shape<2>(a.first);
const auto [n , k_] = get_shape<2>(b.first);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0);
DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
@@ -45,7 +45,7 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
// Check head splits and N
const auto& [left, mid, right] = head_splits;
const auto [left, mid, right] = head_splits;
DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);
// Do nothing if the problem is empty
@@ -53,16 +53,16 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
return;
// Transform SFA and SFB into compute-required layout
const auto& [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, disable_ue8m0_cast);
DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
const auto& epilogue_type = fmt::format("EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
const auto arch_major = device_runtime->get_arch_major();
const auto epilogue_type = fmt::format("epilogue::transform::EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
const auto& major_sfb = get_major_type_ab(sfb);
const auto major_sfb = get_major_type_ab(sfb);
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
// NOTES: Only granularity 128 and FP8 are exposed in the API
@@ -73,59 +73,113 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tens
}
}
static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
const std::pair<torch::Tensor, torch::Tensor>& kv,
const torch::Tensor& weights,
const torch::Tensor& cu_seq_len_k_start,
const torch::Tensor& cu_seq_len_k_end,
const bool& clean_logits,
const int& max_seqlen_k) {
const auto& [seq_len, num_heads, head_dim] = get_shape<3>(q);
const auto& [seq_len_kv, head_dim_] = get_shape<2>(kv.first);
const auto& [seq_len_, num_heads_] = get_shape<2>(weights);
const auto& [seq_len_kv_] = get_shape<1>(kv.second);
static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::optional<torch::Tensor>>& q,
const std::tuple<torch::Tensor, torch::Tensor>& kv,
const torch::Tensor& weights,
const torch::Tensor& cu_seq_len_k_start,
const torch::Tensor& cu_seq_len_k_end,
const bool& clean_logits,
const int& max_seqlen_k,
const at::ScalarType& logits_dtype) {
const auto [q_fp, q_sf] = q;
const auto [kv_fp, kv_sf] = kv;
const bool is_fp4 = q_sf.has_value();
int seq_len, seq_len_kv, num_heads, head_dim;
DG_HOST_ASSERT(seq_len == seq_len_);
DG_HOST_ASSERT(num_heads == num_heads_ and head_dim == head_dim_);
DG_HOST_ASSERT(seq_len_kv == seq_len_kv_);
DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len);
DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len);
if (is_fp4) {
// Check FP4 Q
std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp);
head_dim *= 2;
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);
DG_HOST_ASSERT(q.is_contiguous() and kv.first.is_contiguous());
DG_HOST_ASSERT(kv.second.is_contiguous());
DG_HOST_ASSERT(weights.is_contiguous());
DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous());
DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous());
// Check SF Q
auto [_seq_len, _num_heads] = get_shape<2>(q_sf.value());
DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads);
DG_HOST_ASSERT(q_sf.value().is_contiguous());
DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32);
DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(kv.first.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(kv.second.scalar_type() == torch::kFloat);
// Check FP4 KV
int _head_dim;
std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp);
_head_dim *= 2;
DG_HOST_ASSERT(head_dim == _head_dim);
DG_HOST_ASSERT(kv_fp.is_contiguous());
DG_HOST_ASSERT(kv_fp.scalar_type() == kPackedFP4);
// Check SF KV
auto [_seq_len_kv] = get_shape<1>(kv_sf);
DG_HOST_ASSERT(seq_len_kv == _seq_len_kv);
DG_HOST_ASSERT(kv_sf.is_contiguous());
DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kInt32);
} else {
// Check FP8 Q
std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);
// Check FP4 KV
int _head_dim;
std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp);
DG_HOST_ASSERT(head_dim == _head_dim);
DG_HOST_ASSERT(kv_fp.is_contiguous());
DG_HOST_ASSERT(kv_fp.scalar_type() == torch::kFloat8_e4m3fn);
// Check SF KV
auto [_seq_len_kv] = get_shape<1>(kv_sf);
DG_HOST_ASSERT(seq_len_kv == _seq_len_kv);
DG_HOST_ASSERT(kv_sf.is_contiguous());
DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kFloat);
}
// Check weights
auto [_seq_len, _num_heads] = get_shape<2>(weights);
DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads);
DG_HOST_ASSERT(weights.stride(1) == 1);
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
// Check cu_seq_len_k_start
DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len);
DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous());
DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt);
// Check cu_seq_len_k_end
DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len);
DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous());
DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt);
constexpr int seq_len_alignment = 4;
// Allocate output
constexpr int block_qh = 128;
constexpr int block_kv = 256;
const auto aligned_seq_len = align(seq_len, seq_len_alignment);
const int block_q = block_qh / num_heads;
DG_HOST_ASSERT(block_qh % num_heads == 0);
torch::Tensor logits;
int stride_logits;
int aligned_seq_len = align(seq_len, block_q), stride_logits;
if (max_seqlen_k == 0) {
stride_logits = align(seq_len_kv + block_kv, 4);
logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat));
// Logits stride must be 16-byte aligned
stride_logits = align(seq_len_kv + block_kv, 8);
logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype));
logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)});
} else {
stride_logits = align(max_seqlen_k, block_kv);
logits = torch::empty({aligned_seq_len, stride_logits}, q.options().dtype(torch::kFloat));
logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype));
logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, max_seqlen_k)});
DG_HOST_ASSERT(not clean_logits);
}
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 or arch_major == 10) {
smxx_fp8_mqa_logits(q, kv.first, kv.second, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, seq_len_alignment);
const auto arch_major = device_runtime->get_arch_major();
if (is_fp4 and arch_major == 10) {
sm100_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
smxx_fp8_mqa_logits(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
@@ -137,23 +191,21 @@ static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
}
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms) {
const bool is_context_lens_2d = context_lens.dim() == 2;
int batch_size = 0, next_n = 0;
if (is_context_lens_2d) {
batch_size = context_lens.size(0);
next_n = context_lens.size(1);
} else {
DG_HOST_ASSERT(context_lens.dim() == 1);
batch_size = context_lens.size(0);
}
// NOTES: Only 2D context lens is supported for now
DG_HOST_ASSERT(context_lens.dim() == 2);
const bool is_context_lens_2d = true;
const int batch_size = context_lens.size(0);
const int next_n = context_lens.size(1);
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
DG_HOST_ASSERT(context_lens.is_contiguous());
// Create metadata tensor
auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options());
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9 or arch_major == 10) {
DG_HOST_ASSERT(block_kv == 64 or (arch_major == 10 and block_kv == 32));
smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
@@ -162,85 +214,145 @@ static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_
return schedule_metadata;
}
static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
const torch::Tensor& fused_kv_cache,
const torch::Tensor& weights,
const torch::Tensor& context_lens,
const torch::Tensor& block_table,
const torch::Tensor& schedule_meta,
const int& max_context_len,
const bool& clean_logits) {
const auto& [batch_size, next_n, num_heads, head_dim] = get_shape<4>(q);
const auto& [num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf] = get_shape<4>(fused_kv_cache);
const auto& [batch_size_next_n, num_heads_] = get_shape<2>(weights);
const auto& [batch_size_, max_block_len] = get_shape<2>(block_table);
const auto& [schedule_meta_size, meta_info_size] = get_shape<2>(schedule_meta);
const auto& num_sms = device_runtime->get_num_sms();
const auto& kv_cache_stride_bytes = fused_kv_cache.stride(0);
const auto& block_table_stride = block_table.stride(0);
static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, std::optional<torch::Tensor>>& q,
const torch::Tensor& fused_kv_cache,
const torch::Tensor& weights,
const torch::Tensor& context_lens,
const torch::Tensor& block_table,
const torch::Tensor& schedule_meta,
const int& max_context_len,
const bool& clean_logits,
const at::ScalarType& logits_dtype) {
const auto [q_fp, q_sf] = q;
const bool is_fp4 = q_sf.has_value();
const bool is_context_lens_2d = context_lens.dim() == 2;
if (is_context_lens_2d) {
const auto& [batch_size__, next_n_] = get_shape<2>(context_lens);
DG_HOST_ASSERT(batch_size == batch_size__ and next_n == next_n_);
torch::Tensor kv_cache, kv_cache_sf;
int batch_size, next_n, num_heads, head_dim;
int num_kv_blocks, block_kv;
int kv_cache_stride_bytes;
int block_table_stride = block_table.stride(0);
int num_sms = device_runtime->get_num_sms();
if (is_fp4) {
// Check FP4 Q
std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
head_dim *= 2;
DG_HOST_ASSERT(next_n >= 1);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);
// Check SF Q
auto [_batch_size, _next_n, _num_heads] = get_shape<3>(q_sf.value());
DG_HOST_ASSERT(batch_size == _batch_size and next_n == _next_n and num_heads == _num_heads);
DG_HOST_ASSERT(q_sf.value().is_contiguous());
DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32);
// Check fused KV cache
int num_heads_kv, fp4_with_sf_bytes;
std::tie(num_kv_blocks, block_kv, num_heads_kv, fp4_with_sf_bytes) = get_shape<4>(fused_kv_cache);
DG_HOST_ASSERT(block_kv == 32 or block_kv == 64);
DG_HOST_ASSERT(num_heads_kv == 1 and fp4_with_sf_bytes == head_dim / 2 + static_cast<int>(sizeof(int)));
DG_HOST_ASSERT(fused_kv_cache.stride(1) == fp4_with_sf_bytes and fused_kv_cache.stride(3) == 1);
DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);
// Derive FP4 values and SF tensor
kv_cache_stride_bytes = fused_kv_cache.stride(0);
DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(int) == 0);
kv_cache = torch::from_blob(
fused_kv_cache.data_ptr(),
{num_kv_blocks, block_kv, head_dim / 2},
{kv_cache_stride_bytes, head_dim / 2, 1},
torch::TensorOptions().dtype(kPackedFP4)
);
kv_cache_sf = torch::from_blob(
fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim / 2,
{num_kv_blocks, block_kv},
{kv_cache_stride_bytes / static_cast<int>(sizeof(int)), 1},
torch::TensorOptions().dtype(torch::kInt32)
);
} else {
DG_HOST_ASSERT(context_lens.dim() == 1);
const auto& [batch_size__] = get_shape<1>(context_lens);
DG_HOST_ASSERT(batch_size == batch_size__);
// Check FP8 Q
std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
DG_HOST_ASSERT(next_n >= 1);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);
// Check fused KV cache
int num_heads_kv, head_dim_with_sf;
std::tie(num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf) = get_shape<4>(fused_kv_cache);
DG_HOST_ASSERT(block_kv == 32 or block_kv == 64);
DG_HOST_ASSERT(num_heads_kv == 1 and head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf and fused_kv_cache.stride(3) == 1);
DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);
// Derive FP8 values and SF tensor
kv_cache_stride_bytes = fused_kv_cache.stride(0);
DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0);
kv_cache = torch::from_blob(
fused_kv_cache.data_ptr(),
{num_kv_blocks, block_kv, head_dim},
{kv_cache_stride_bytes, head_dim, 1},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn)
);
kv_cache_sf = torch::from_blob(
fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim,
{num_kv_blocks, block_kv},
{kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 1},
torch::TensorOptions().dtype(torch::kFloat32)
);
// Weights must be contiguous for FP8
DG_HOST_ASSERT(weights.is_contiguous());
}
DG_HOST_ASSERT(batch_size == batch_size_);
DG_HOST_ASSERT(batch_size_next_n == batch_size * next_n);
DG_HOST_ASSERT(num_heads == num_heads_ and num_heads_kv == 1);
DG_HOST_ASSERT(head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
DG_HOST_ASSERT(schedule_meta_size == num_sms + 1 and meta_info_size == 2);
DG_HOST_ASSERT(next_n == 1 or next_n == 2);
DG_HOST_ASSERT(block_kv == 64);
DG_HOST_ASSERT(q.is_contiguous());
DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0);
DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf);
DG_HOST_ASSERT(fused_kv_cache.stride(2) == head_dim_with_sf);
DG_HOST_ASSERT(fused_kv_cache.stride(3) == 1);
DG_HOST_ASSERT(weights.is_contiguous());
DG_HOST_ASSERT(context_lens.is_contiguous());
DG_HOST_ASSERT(block_table.stride(1) == 1);
DG_HOST_ASSERT(schedule_meta.is_contiguous());
DG_HOST_ASSERT(q.scalar_type() == torch::kFloat8_e4m3fn);
DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);
// Check weights
auto [_batch_size_next_n, _num_heads] = get_shape<2>(weights);
DG_HOST_ASSERT(_batch_size_next_n == batch_size * next_n and _num_heads == num_heads);
DG_HOST_ASSERT(weights.stride(1) == 1);
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
// Check block table
auto [_batch_size, _max_block_len] = get_shape<2>(block_table);
DG_HOST_ASSERT(_batch_size == batch_size);
DG_HOST_ASSERT(block_table.stride(1) == 1);
DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt);
// Check schedule metadata
auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta);
DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2);
DG_HOST_ASSERT(schedule_meta.is_contiguous());
DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt);
// Derive FP8 values and SF tensor from KV cache
const auto& kv_cache = torch::from_blob(
fused_kv_cache.data_ptr(),
{num_kv_blocks, block_kv, head_dim},
{kv_cache_stride_bytes, head_dim, 1},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn)
);
const auto& kv_cache_scales = torch::from_blob(
fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim,
{num_kv_blocks, block_kv},
{kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 1},
torch::TensorOptions().dtype(torch::kFloat32)
);
// Check context lengths
// NOTES: Only 2D context lens is supported for now
DG_HOST_ASSERT(context_lens.dim() == 2);
const bool is_context_lens_2d = true;
const auto [__batch_size, _next_n] = get_shape<2>(context_lens);
DG_HOST_ASSERT(batch_size == __batch_size and next_n == _next_n);
DG_HOST_ASSERT(context_lens.is_contiguous());
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
// Allocate output
constexpr int split_kv = 256;
const auto& aligned_max_context_len = align(max_context_len, split_kv);
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q.options().dtype(torch::kFloat));
const auto aligned_max_context_len = align(max_context_len, split_kv);
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q_fp.options().dtype(logits_dtype));
logits = logits.slice(-1, 0, max_context_len);
DG_HOST_ASSERT(logits_dtype == torch::kFloat32 or logits_dtype == torch::kBFloat16);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
if (arch_major == 9 or arch_major == 10) {
smxx_fp8_paged_mqa_logits(q, kv_cache, kv_cache_scales, weights, context_lens, logits, block_table, schedule_meta,
batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
kv_cache_stride_bytes, aligned_max_context_len, block_table_stride, num_sms, split_kv);
const auto arch_major = device_runtime->get_arch_major();
if (is_fp4 and arch_major == 10) {
sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
@@ -253,6 +365,32 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
return logits;
}
// Legacy API wrappers
static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
const std::tuple<torch::Tensor, torch::Tensor>& kv,
const torch::Tensor& weights,
const torch::Tensor& cu_seq_len_k_start,
const torch::Tensor& cu_seq_len_k_end,
const bool& clean_logits,
const int& max_seqlen_k) {
return fp8_fp4_mqa_logits(std::make_tuple(q, std::nullopt), kv, weights,
cu_seq_len_k_start, cu_seq_len_k_end,
clean_logits, max_seqlen_k, torch::kFloat);
}
static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
const torch::Tensor& fused_kv_cache,
const torch::Tensor& weights,
const torch::Tensor& context_lens,
const torch::Tensor& block_table,
const torch::Tensor& schedule_meta,
const int& max_context_len,
const bool& clean_logits) {
return fp8_fp4_paged_mqa_logits(std::make_tuple(q, std::nullopt), fused_kv_cache, weights,
context_lens, block_table, schedule_meta,
max_context_len, clean_logits, torch::kFloat);
}
#endif
static void register_apis(pybind11::module_& m) {
@@ -262,13 +400,26 @@ static void register_apis(pybind11::module_& m) {
py::arg("recipe") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_fp4_mqa_logits", &fp8_fp4_mqa_logits,
py::arg("q"), py::arg("kv"), py::arg("weights"),
py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"),
py::arg("clean_logits") = true,
py::arg("max_seqlen_k") = 0,
py::arg("logits_dtype") = torch::kFloat32);
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"));
m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits,
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
py::arg("max_context_len"),
py::arg("clean_logits") = false,
py::arg("logits_dtype") = torch::kFloat32);
// Legacy API
m.def("fp8_mqa_logits", &fp8_mqa_logits,
py::arg("q"), py::arg("kv"), py::arg("weights"),
py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"),
py::arg("clean_logits") = true,
py::arg("max_seqlen_k") = 0);
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"));
m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),

View File

@@ -29,7 +29,7 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(not c.has_value());
const auto& workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
const auto workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(),
c10::cuda::getCurrentCUDAStream()));
bmk_bnk_mn(a, b, workspace, workspace);
@@ -43,12 +43,12 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor
DG_HOST_ASSERT(b.is_contiguous());
DG_HOST_ASSERT(d.is_contiguous());
const auto& [s , m, k ] = get_shape<3>(a);
const auto& [s_, n, k_] = get_shape<3>(b);
const auto [s , m, k ] = get_shape<3>(a);
const auto [s_, n, k_] = get_shape<3>(b);
DG_HOST_ASSERT(s == s_ and k == k_);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
} else if (arch_major == 10) {
@@ -59,9 +59,9 @@ static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const tor
}
static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) {
const auto& [b , h , r ] = get_shape<3>(A);
const auto& [h_, d , r_] = get_shape<3>(B);
const auto& [b_, h__, d_] = get_shape<3>(D);
const auto [b , h , r ] = get_shape<3>(A);
const auto [h_, d , r_] = get_shape<3>(B);
const auto [b_, h__, d_] = get_shape<3>(D);
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
@@ -69,7 +69,7 @@ static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const to
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (use_cublaslt) {
cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
} else if (arch_major == 9) {
@@ -82,9 +82,9 @@ static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const to
}
static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) {
const auto& [b , h , d ] = get_shape<3>(A);
const auto& [h_, d_ , r ] = get_shape<3>(B);
const auto& [b_, h__, r_] = get_shape<3>(D);
const auto [b , h , d ] = get_shape<3>(A);
const auto [h_, d_ , r ] = get_shape<3>(B);
const auto [b_, h__, r_] = get_shape<3>(D);
DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);
DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
@@ -92,7 +92,7 @@ static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const to
DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (use_cublaslt) {
cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
} else if (arch_major == 9) {
@@ -142,16 +142,16 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
std::optional<std::tuple<int, int, int>> recipe,
const std::string& compiled_dims) {
// Shape must be `[B, M, K] @ [B, N, K].T`
const auto& major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
const auto& major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
const auto major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
const auto major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
DG_HOST_ASSERT(a.stride(-1) == 1 or a.stride(-2) == 1);
DG_HOST_ASSERT(b.stride(-1) == 1 or b.stride(-2) == 1);
DG_HOST_ASSERT(d.stride(-1) == 1);
// Type and shape checks
const auto& [batch_size , m , k ] = get_shape<3>(a);
const auto& [batch_size_ , n , k_] = get_shape<3>(b);
const auto& [batch_size__, m_, n_] = get_shape<3>(d);
const auto [batch_size , m , k ] = get_shape<3>(a);
const auto [batch_size_ , n , k_] = get_shape<3>(b);
const auto [batch_size__, m_, n_] = get_shape<3>(d);
DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size_);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn);
@@ -163,15 +163,16 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
return;
// Transform scaling factors
const auto& [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
const auto [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false);
// Dispatch implementation
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 10) {
sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, compiled_dims);
sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
} else {
const auto& major_sfb = get_major_type_ab(sfb);
const auto major_sfb = get_major_type_ab(sfb);
DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);
sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims);
}
}
@@ -187,26 +188,26 @@ static void fp8_einsum(const std::string& expr,
if (expr == "bhr,hdr->bhd") {
// Permute dims to satisfy the order of (batch_size, m, n, k)
// (batch_size, m, n, k): (h, b, d, r)
const auto& perm_a = a.first.permute({1, 0, 2});
const auto& perm_sfa = a.second.permute({1, 0, 2});
const auto& perm_d = d.permute({1, 0, 2});
const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
const auto perm_a = a.first.permute({1, 0, 2});
const auto perm_sfa = a.second.permute({1, 0, 2});
const auto perm_d = d.permute({1, 0, 2});
const auto perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk");
} else if (expr == "bhd,hdr->bhr" and arch_major == 10) {
// (batch_size, m, n, k): (h, b, r, d)
const auto& perm_a = a.first.permute({1, 0, 2});
const auto& perm_sfa = a.second.permute({1, 0, 2});
const auto& perm_b = b.first.permute({0, 2, 1});
const auto& perm_sfb = b.second.permute({0, 2, 1});
const auto& perm_d = d.permute({1, 0, 2});
const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
const auto perm_a = a.first.permute({1, 0, 2});
const auto perm_sfa = a.second.permute({1, 0, 2});
const auto perm_b = b.first.permute({0, 2, 1});
const auto perm_sfb = b.second.permute({0, 2, 1});
const auto perm_d = d.permute({1, 0, 2});
const auto perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt;
fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk");
} else if (expr == "bhd,bhr->hdr" and arch_major == 10) {
// (batch_size, m, n, k): (h, d, r, b)
const auto& perm_a = a.first.permute({1, 2, 0});
const auto& perm_sfa = a.second.permute({1, 2, 0});
const auto& perm_b = b.first.permute({1, 2, 0});
const auto& perm_sfb = b.second.permute({1, 2, 0});
const auto perm_a = a.first.permute({1, 2, 0});
const auto perm_sfa = a.second.permute({1, 2, 0});
const auto perm_b = b.first.permute({1, 2, 0});
const auto perm_sfb = b.second.permute({1, 2, 0});
fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, d, c, recipe, "mn");
} else {
DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));

View File

@@ -6,7 +6,7 @@
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
#endif
@@ -23,7 +23,7 @@ static bool early_return(const int& m, const int &n, const int& k,
return true;
// Checks
const bool& is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr();
const bool is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr();
if (is_cd_same)
DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
if (c.has_value()) {
@@ -57,8 +57,8 @@ static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
const auto major_a = get_major_type_ab(a.first);
const auto major_b = get_major_type_ab(b.first);
if (fp8_requires_k_major()) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
@@ -89,7 +89,7 @@ static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
if (gran_n == 1) {
sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
} else {
const auto& major_sfb = get_major_type_ab(sfb);
const auto major_sfb = get_major_type_ab(sfb);
sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
}
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
@@ -152,8 +152,8 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
// Shape must be `[M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
const auto major_a = get_major_type_ab(a.first);
const auto major_b = get_major_type_ab(b.first);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
if (fp8_requires_k_major())
DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
@@ -171,10 +171,10 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor,
// Layout checks
if (use_psum_layout) {
const auto& [num_groups_] = get_shape<1>(grouped_layout);
const auto [num_groups_] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(num_groups == num_groups_);
} else {
const auto& [m__] = get_shape<1>(grouped_layout);
const auto [m__] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(m == m__);
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
}
@@ -192,10 +192,10 @@ static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor,
// Dispatch implementation
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
const auto& major_sfb = get_major_type_ab(sfb);
DG_HOST_ASSERT(not use_psum_layout);
const auto major_sfb = get_major_type_ab(sfb);
sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout,
num_groups, m, n, k, major_a, major_b, major_sfb, compiled_dims);
num_groups, m, n, k, major_a, major_b, major_sfb,
compiled_dims, use_psum_layout, expected_m_for_psum_layout);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b,
@@ -230,8 +230,8 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
const std::string& compiled_dims,
const bool& disable_ue8m0_cast) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a.first);
const auto& major_b = get_major_type_ab(b.first);
const auto major_a = get_major_type_ab(a.first);
const auto major_b = get_major_type_ab(b.first);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(masked_m.is_contiguous());
@@ -256,7 +256,7 @@ static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torc
// Dispatch implementation
if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
const auto& major_sfb = get_major_type_ab(sfb);
const auto major_sfb = get_major_type_ab(sfb);
sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
} else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
@@ -277,12 +277,15 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
const std::tuple<int, int, int>& recipe,
const std::string& compiled_dims) {
// Must be 1D1D kernel
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1);
const int gran_k = std::get<2>(recipe);
DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);
// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& [sum_k_ , m_] = get_shape<2>(a.first);
const auto& [sum_k__, n_] = get_shape<2>(b.first);
const auto [num_groups, m, n] = get_shape<3>(d);
const auto [sum_k_ , m_] = get_shape<2>(a.first);
const auto [sum_k__, n_] = get_shape<2>(b.first);
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
@@ -297,13 +300,13 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torc
return;
// Transform SF with padding
const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 10) {
sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor,
sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, gran_k,
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
@@ -322,9 +325,9 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& sum_mk = a.first.numel();
const auto& sum_nk = b.first.numel();
const auto [num_groups, m, n] = get_shape<3>(d);
const auto sum_mk = a.first.numel();
const auto sum_nk = b.first.numel();
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(sum_k) * m);
DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(sum_k) * n);
@@ -340,17 +343,17 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torc
return;
// Transform SF with padding
const auto& sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
const auto& sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);
// Allocate tensormap buffer
// `4` means the double buffering for both A and B operands (2 * 2)
const auto& num_sms = device_runtime->get_num_sms();
const auto& tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
a.first.options().dtype(torch::kByte));
const auto num_sms = device_runtime->get_num_sms();
const auto tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
a.first.options().dtype(torch::kByte));
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
@@ -367,16 +370,16 @@ static void bf16_gemm_nt(const torch::Tensor& a,
const std::optional<torch::Tensor>& c,
const std::string& compiled_dims) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a);
const auto& major_b = get_major_type_ab(b);
const auto major_a = get_major_type_ab(a);
const auto major_b = get_major_type_ab(b);
// C/D must be N-major
check_major_type_cd(d);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a);
const auto& [n , k_] = get_shape<2>(b);
const auto& [m_, n_] = get_shape<2>(d);
const auto [m , k ] = get_shape<2>(a);
const auto [n , k_] = get_shape<2>(b);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
@@ -387,7 +390,7 @@ static void bf16_gemm_nt(const torch::Tensor& a,
return;
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
} else if (arch_major == 10) {
@@ -427,15 +430,15 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
// Shape must be `[M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a);
const auto& major_b = get_major_type_ab(b);
const auto major_a = get_major_type_ab(a);
const auto major_b = get_major_type_ab(b);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(grouped_layout.is_contiguous());
// Type and shape checks
const auto& [m, k] = get_shape<2>(a);
const auto& [num_groups, n, k_] = get_shape<3>(b);
const auto& [m_, n_] = get_shape<2>(d);
const auto [m, k] = get_shape<2>(a);
const auto [num_groups, n, k_] = get_shape<3>(b);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
@@ -445,10 +448,10 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
// Layout checks
if (use_psum_layout) {
const auto& [num_groups_] = get_shape<1>(grouped_layout);
const auto [num_groups_] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(num_groups == num_groups_);
} else {
const auto& [m__] = get_shape<1>(grouped_layout);
const auto [m__] = get_shape<1>(grouped_layout);
DG_HOST_ASSERT(m == m__);
DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
}
@@ -461,11 +464,11 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc
return;
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
DG_HOST_ASSERT(not use_psum_layout);
sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
num_groups, m, n, k, major_a, major_b, compiled_dims);
num_groups, m, n, k, major_a, major_b, compiled_dims,
use_psum_layout, expected_m_for_psum_layout);
} else if (arch_major == 10) {
sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
num_groups, m, n, k, major_a, major_b, compiled_dims,
@@ -487,16 +490,16 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
const torch::Tensor& d, const torch::Tensor& masked_m,
const int& expected_m, const std::string& compiled_dims) {
// Shape must be `[G, M, K] @ [G, N, K].mT`
const auto& major_a = get_major_type_ab(a);
const auto& major_b = get_major_type_ab(b);
const auto major_a = get_major_type_ab(a);
const auto major_b = get_major_type_ab(b);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(masked_m.is_contiguous());
// Type and shape checks
const auto& [num_groups, m, k] = get_shape<3>(a);
const auto& [num_groups_, n, k_] = get_shape<3>(b);
const auto& [num_groups__, m_, n_] = get_shape<3>(d);
const auto& num_groups___ = static_cast<int>(masked_m.numel());
const auto [num_groups, m, k] = get_shape<3>(a);
const auto [num_groups_, n, k_] = get_shape<3>(b);
const auto [num_groups__, m_, n_] = get_shape<3>(d);
const auto num_groups___ = static_cast<int>(masked_m.numel());
DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
@@ -509,7 +512,7 @@ static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::T
check_major_type_cd(d);
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m,
num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
@@ -529,9 +532,9 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
const std::optional<torch::Tensor>& c,
const std::string& compiled_dims) {
// Shape checks
const auto& [num_groups, m, n] = get_shape<3>(d);
const auto& [sum_k_ , m_] = get_shape<2>(a);
const auto& [sum_k__, n_] = get_shape<2>(b);
const auto [num_groups, m, n] = get_shape<3>(d);
const auto [sum_k_ , m_] = get_shape<2>(a);
const auto [sum_k__, n_] = get_shape<2>(b);
const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);
@@ -546,7 +549,7 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
return;
// Dispatch implementation
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor,
cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
@@ -562,20 +565,20 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
// Shape must be `[M, K] @ [N, K].T`
const auto& major_a = get_major_type_ab(a);
const auto& major_b = get_major_type_ab(b);
const auto major_a = get_major_type_ab(a);
const auto major_b = get_major_type_ab(b);
// Type and shape checks
const auto& [m , k ] = get_shape<2>(a);
const auto& [n , k_] = get_shape<2>(b);
const auto& [m_, n_] = get_shape<2>(d);
const auto [m , k ] = get_shape<2>(a);
const auto [n , k_] = get_shape<2>(b);
const auto [m_, n_] = get_shape<2>(d);
DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
// Early return for trivial cases
if (early_return(m, n, k, d, c))
return;
cublaslt_gemm(a, b, c, d, m, n, k, major_a, major_b);
cublaslt_gemm(a, b, d, m, n, k, major_a, major_b, c.has_value());
}
static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,

View File

@@ -24,16 +24,16 @@ static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
DG_HOST_ASSERT(sqr_sum.is_contiguous());
// Type and shape checks
const auto& [m, k ] = get_shape<2>(a);
const auto& [n, k_] = get_shape<2>(b);
const auto [m, k ] = get_shape<2>(a);
const auto [n, k_] = get_shape<2>(b);
if (num_splits.has_value()) {
const auto& [num_splits_, m_, n_] = get_shape<3>(d);
const auto& [num_splits__, m__] = get_shape<2>(sqr_sum);
const auto [num_splits_, m_, n_] = get_shape<3>(d);
const auto [num_splits__, m__] = get_shape<2>(sqr_sum);
DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1);
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
} else {
const auto& [m_, n_] = get_shape<2>(d);
const auto& [m__] = get_shape<1>(sqr_sum);
const auto [m_, n_] = get_shape<2>(d);
const auto [m__] = get_shape<1>(sqr_sum);
DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
}
DG_HOST_ASSERT(n > 0 and k > 0);
@@ -47,7 +47,7 @@ static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
return;
// Dispatch into different implements
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
if (arch_major == 9) {
sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1);
} else if (arch_major == 10) {

View File

@@ -1,5 +1,6 @@
#pragma once
#include "../jit_kernels/heuristics/runtime.hpp"
#include "../utils/layout.hpp"
#include "../utils/compatibility.hpp"
@@ -12,21 +13,24 @@ namespace deep_gemm::layout {
#if DG_TENSORMAP_COMPATIBLE
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
const int& mn, const int& k,
const std::optional<std::tuple<int, int, int>>& recipe,
const std::optional<std::tuple<int, int>>& recipe_ab,
const std::variant<std::tuple<int, int, int>,
std::tuple<int, int>>& recipe,
const std::optional<int>& num_groups,
const bool& is_sfa,
const std::optional<bool>& is_sfa,
const bool& disable_ue8m0_cast) {
const auto& arch_major = device_runtime->get_arch_major();
const auto arch_major = device_runtime->get_arch_major();
// Get granularity MN/K from recipe
int gran_mn, gran_k;
if (recipe.has_value()) {
DG_HOST_ASSERT(not recipe_ab.has_value());
gran_mn = is_sfa ? std::get<0>(recipe.value()) : std::get<1>(recipe.value());
gran_k = std::get<2>(recipe.value());
if (auto p = std::get_if<std::tuple<int, int, int>>(&recipe)) {
DG_HOST_ASSERT(is_sfa.has_value());
gran_mn = is_sfa.value() ? std::get<0>(*p) : std::get<1>(*p);
gran_k = std::get<2>(*p);
} else if (auto p = std::get_if<std::tuple<int, int>>(&recipe)) {
DG_HOST_ASSERT(not is_sfa.has_value());
std::tie(gran_mn, gran_k) = *p;
} else {
DG_HOST_ASSERT(recipe_ab.has_value());
std::tie(gran_mn, gran_k) = recipe_ab.value();
DG_HOST_UNREACHABLE("Invalid recipe");
}
// Pre-transform checks
@@ -43,8 +47,8 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
// (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major
if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) {
DG_HOST_ASSERT(not disable_ue8m0_cast);
const auto& broadcasted = gran_mn == 1 ? sf :
sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn));
const auto broadcasted = gran_mn == 1 ? sf :
sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn));
return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted);
}
@@ -64,11 +68,19 @@ static std::tuple<torch::Tensor, torch::Tensor, int, int> transform_sf_pair_into
const std::optional<int>& num_groups_a,
const std::optional<int>& num_groups_b,
const bool& disable_ue8m0_cast = false) {
DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value());
// Use default recipe, if none is specified
if (not recipe_a.has_value() and not recipe.has_value())
recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type());
const auto transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe, recipe_a, num_groups_a, true, disable_ue8m0_cast);
const auto transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe, recipe_b, num_groups_b, false, disable_ue8m0_cast);
// Must be either 'recipe' or the 'recipe_a' + 'recipe_b' pair.
DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value());
DG_HOST_ASSERT(recipe_a.has_value() != recipe.has_value());
// Transform SFA and SFB layout
const auto transformed_sfa = recipe.has_value() ? transform_sf_into_required_layout(sfa, m, k, recipe.value(), num_groups_a, true, disable_ue8m0_cast)
: transform_sf_into_required_layout(sfa, m, k, recipe_a.value(), num_groups_a, std::nullopt, disable_ue8m0_cast);
const auto transformed_sfb = recipe.has_value() ? transform_sf_into_required_layout(sfb, n, k, recipe.value(), num_groups_b, false, disable_ue8m0_cast)
: transform_sf_into_required_layout(sfb, n, k, recipe_b.value(), num_groups_b, std::nullopt, disable_ue8m0_cast);
const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value());
const int gran_k_b = recipe_b.has_value() ? std::get<1>(recipe_b.value()) : std::get<2>(recipe.value());
return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b);
@@ -79,8 +91,12 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
const torch::Tensor& ks_tensor,
const std::tuple<int, int, int>& recipe) {
DG_HOST_ASSERT(sf.dim() == 2);
DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));
const auto& arch_major = device_runtime->get_arch_major();
DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1);
const int gran_k = std::get<2>(recipe);
DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);
const auto arch_major = device_runtime->get_arch_major();
// FP32 on SM90
if (sf.scalar_type() == torch::kFloat and arch_major == 9)
@@ -88,7 +104,7 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
// FP32 on SM100
if (sf.scalar_type() == torch::kFloat and arch_major == 10)
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks);
return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks, gran_k);
// INT on SM100
if (sf.scalar_type() == torch::kInt and arch_major == 10)
@@ -100,12 +116,11 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
#endif
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
py::arg("sf"), py::arg("mn"), py::arg("k"),
py::arg("recipe") = std::nullopt, py::arg("recipe_ab") = std::nullopt,
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
py::arg("num_groups") = std::nullopt,
py::arg("is_sfa") = std::nullopt,
py::arg("disable_ue8m0_cast") = false);
m.def("get_tma_aligned_size", &get_tma_aligned_size);
@@ -114,7 +129,15 @@ static void register_apis(pybind11::module_& m) {
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
#endif
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
m.def("set_mk_alignment_for_contiguous_layout", [&](const int& new_value) {
heuristics_runtime->set_mk_alignment_for_contiguous_layout(new_value);
});
m.def("get_mk_alignment_for_contiguous_layout", [&]() {
return heuristics_runtime->get_mk_alignment_for_contiguous_layout();
});
m.def("get_theoretical_mk_alignment_for_contiguous_layout", [&](const std::optional<int>& expected_m) {
return heuristics_runtime->get_theoretical_mk_alignment_for_contiguous_layout(expected_m);
}, py::arg("expected_m") = std::nullopt);
}
} // namespace deep_gemm::layout

216
csrc/apis/mega.hpp Normal file
View File

@@ -0,0 +1,216 @@
#pragma once
#include <functional>
#include <pybind11/functional.h>
#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp"
namespace deep_gemm::mega {
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
get_symm_buffer_size_for_mega_moe(
const int& num_ranks, const int& num_experts,
const int& num_max_tokens_per_rank, const int& num_topk,
const int& hidden, const int& intermediate_hidden,
const bool& use_fp8_dispatch, const std::string& activation) {
DG_HOST_ASSERT(num_experts % num_ranks == 0);
// Workspace bytes
const auto block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk, block_m);
// Layouts
const auto fp8_token_layout = layout::Data(hidden);
const auto bf16_token_layout = layout::Data(hidden * 2);
const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
const auto fp8_sf_layout = layout::Data(hidden / 32);
const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);
// Input buffers
const auto input_token_buffer = layout::Buffer(
fp8_token_layout, 1, num_max_tokens_per_rank,
workspace.get_end_ptr());
const auto input_sf_buffer = layout::Buffer(
fp8_sf_layout, 1, num_max_tokens_per_rank,
input_token_buffer.get_end_ptr());
const auto input_topk_idx_buffer = layout::Buffer(
input_topk_idx_layout, 1, num_max_tokens_per_rank,
input_sf_buffer.get_end_ptr());
const auto input_topk_weights_buffer = layout::Buffer(
input_topk_weights_layout, 1, num_max_tokens_per_rank,
input_topk_idx_buffer.get_end_ptr());
// Buffer configs
const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
const auto num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m);
// L1 input buffer
const auto l1_token_buffer = layout::Buffer(
fp8_token_layout, 1, num_max_pool_tokens,
input_topk_weights_buffer.get_end_ptr());
const auto l1_sf_buffer = layout::Buffer(
fp8_sf_layout, 1, num_padded_sf_pool_tokens,
l1_token_buffer.get_end_ptr());
const auto l1_topk_weights_buffer = layout::Buffer(
l1_topk_weights_layout, 1, num_max_pool_tokens,
l1_sf_buffer.get_end_ptr());
// L2 input buffer
const auto l2_token_buffer = layout::Buffer(
fp8_intermediate_token_layout, 1, num_max_pool_tokens,
l1_topk_weights_buffer.get_end_ptr());
const auto l2_sf_buffer = layout::Buffer(
fp8_intermediate_sf_layout, 1, num_padded_sf_pool_tokens,
l2_token_buffer.get_end_ptr());
// Combine input buffer: BF16 tokens for cross-rank combine
const auto combine_token_buffer = layout::Buffer(
bf16_token_layout, num_topk, num_max_tokens_per_rank,
l2_sf_buffer.get_end_ptr());
// Check SF buffer requirements
DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
DG_HOST_ASSERT(num_padded_sf_pool_tokens % 4 == 0);
// Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
// NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
auto slice_input_buffers = [=](const torch::Tensor& buffer) {
auto x = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
{num_max_tokens_per_rank, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto x_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
{num_max_tokens_per_rank, hidden / 128},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
auto topk_idx = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
{num_max_tokens_per_rank, num_topk},
torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
auto topk_weights = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
{num_max_tokens_per_rank, num_topk},
torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
auto l1_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
{num_max_pool_tokens, hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto l1_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
{num_padded_sf_pool_tokens, hidden / 128},
{1, num_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
auto l2_acts = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
{num_max_pool_tokens, intermediate_hidden},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
auto l2_acts_sf = torch::from_blob(
math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
{num_padded_sf_pool_tokens, intermediate_hidden / 128},
{1, num_padded_sf_pool_tokens},
torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
};
return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
}
static void fp8_fp4_mega_moe(
const torch::Tensor& y,
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_,
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_,
const torch::Tensor& sym_buffer,
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
const int& num_max_tokens_per_rank,
const int& num_experts, const int& num_topk,
const std::tuple<int, int, int>& recipe,
const std::string& activation,
const std::optional<float>& activation_clamp_opt,
const bool& fast_math) {
const auto [l1_weights, l1_weights_sf] = l1_weights_;
const auto [l2_weights, l2_weights_sf] = l2_weights_;
// Config checks
const auto num_tokens = static_cast<int>(y.size(0));
const auto [rm, rn, rk] = recipe;
DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32);
DG_HOST_ASSERT(activation == "swiglu");
// Activation checks
const auto activation_clamp =
activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
DG_HOST_ASSERT(activation_clamp >= 0);
// Tensor checks
DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
const auto arch_major = device_runtime->get_arch_major();
const auto [num_experts_per_rank, intermediate_hidden_2, hidden] =
check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major);
const auto [num_experts_per_rank_, hidden_, intermediate_hidden] =
check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major);
DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank);
DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
DG_HOST_ASSERT(hidden == hidden_);
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
constexpr int kGranMN = 1, kGranK = 32;
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
// Check buffer bytes
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
const auto num_experts_ = num_experts_per_rank * num_ranks;
const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe(
num_ranks, num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden,
true, "swiglu");
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
DG_HOST_ASSERT(num_experts == num_experts_);
// Already registered tensors
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);
// Dispatch into different architectures
if (arch_major == 10) {
sm100_fp8_fp4_mega_moe(y,
l1_acts, l1_acts_sf,
l2_acts, l2_acts_sf,
l1_weights, l2_weights,
l1_weights_sf, l2_weights_sf,
sym_buffer_ptrs,
rank_idx, num_max_tokens_per_rank,
num_experts_per_rank,
num_tokens, num_topk,
hidden, intermediate_hidden,
activation_clamp, fast_math);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
// Zero the entire symmetric buffer for debug mode
// NOTES: caller must re-copy inputs into the buffer before each kernel call
if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
sym_buffer.zero_();
}
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("get_block_m_for_mega_moe", &get_block_m_for_mega_moe);
m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe);
m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
#endif
}
} // namespace deep_gemm::mega

View File

@@ -4,6 +4,7 @@
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/heuristics/runtime.hpp"
namespace deep_gemm::runtime {
@@ -20,10 +21,29 @@ static void register_apis(pybind11::module_& m) {
m.def("get_tc_util", [&]() {
return device_runtime->get_tc_util();
});
m.def("set_pdl", [&](const bool& new_enable_pdl) {
device_runtime->set_pdl(new_enable_pdl);
});
m.def("get_pdl", [&]() {
return device_runtime->get_pdl();
});
m.def("set_ignore_compile_dims", [&](const bool& new_value) {
heuristics_runtime->set_ignore_compile_dims(new_value);
});
m.def("set_block_size_multiple_of", [&](const std::variant<int, std::tuple<int, int>>& new_value) {
if (std::holds_alternative<int>(new_value)) {
auto x = std::get<int>(new_value);
heuristics_runtime->set_block_size_multiple_of(x, x);
} else {
auto [x, y] = std::get<std::tuple<int, int>>(new_value);
heuristics_runtime->set_block_size_multiple_of(x, y);
}
});
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
#if DG_TENSORMAP_COMPATIBLE
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
KernelRuntime::prepare_init(cuda_home_path_by_python);
IncludeParser::prepare_init(library_root_path);
#endif
});
}

View File

@@ -3,12 +3,14 @@
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh>
// Attention kernels
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp4_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp8_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh>
// Einsum kernels
@@ -23,6 +25,9 @@
#include <deep_gemm/impls/smxx_layout.cuh>
#include <deep_gemm/impls/smxx_clean_logits.cuh>
// Mega kernels
#include <deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh>
using namespace deep_gemm;
int main() {

View File

@@ -17,7 +17,7 @@ public:
std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
// Hit the runtime cache
if (const auto& iterator = cache.find(dir_path); iterator != cache.end())
if (const auto iterator = cache.find(dir_path); iterator != cache.end())
return iterator->second;
if (KernelRuntime::check_validity(dir_path))

View File

@@ -2,6 +2,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <fcntl.h>
#include <filesystem>
#include <fstream>
#include <nvrtc.h>
@@ -15,6 +16,7 @@
#include "../utils/system.hpp"
#include "cache.hpp"
#include "device_runtime.hpp"
#include "include_parser.hpp"
namespace deep_gemm {
@@ -23,29 +25,13 @@ public:
static std::filesystem::path library_root_path;
static std::filesystem::path library_include_path;
static std::filesystem::path cuda_home;
static std::string library_version;
static std::filesystem::path cuobjdump_path;
static std::string get_library_version() {
std::vector<char> buffer;
for (const auto& f: collect_files(library_include_path / "deep_gemm")) {
std::ifstream in(f, std::ios::binary);
DG_HOST_ASSERT(in.is_open());
// Append into the buffer
buffer.insert(buffer.end(),
std::istreambuf_iterator<char>(in),
std::istreambuf_iterator<char>());
}
return get_hex_digest(buffer);
}
static void prepare_init(const std::string& library_root_path,
const std::string& cuda_home_path_by_python) {
Compiler::library_root_path = library_root_path;
Compiler::library_include_path = Compiler::library_root_path / "include";
Compiler::cuda_home = cuda_home_path_by_python;
Compiler::library_version = get_library_version();
Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump";
}
@@ -57,12 +43,11 @@ public:
DG_HOST_ASSERT(not library_root_path.empty());
DG_HOST_ASSERT(not library_include_path.empty());
DG_HOST_ASSERT(not cuda_home.empty());
DG_HOST_ASSERT(not library_version.empty());
DG_HOST_ASSERT(not cuobjdump_path.empty());
// Cache settings
cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
if (const auto& env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty())
if (const auto env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty())
cache_dir_path = env_cache_dir_path;
// The compiler flags applied to all derived compilers
@@ -82,58 +67,79 @@ public:
return make_dirs(cache_dir_path / "tmp");
}
std::filesystem::path get_tmp_file_path() const {
return make_tmp_dir() / get_uuid();
static void fsync_path(const std::filesystem::path& path) {
const auto fd = ::open(path.c_str(), O_RDONLY);
if (fd >= 0) {
::fsync(fd);
::close(fd);
}
}
void put(const std::filesystem::path& path, const std::string& data) const {
const auto tmp_file_path = get_tmp_file_path();
// Recursively fsync a directory: files and subdirectories first (bottom-up), then the directory itself
// NOTES: ensures data and directory entries are visible on other nodes in distributed filesystems
static void fsync_dir(const std::filesystem::path& dir_path) { // NOLINT(*-no-recursion)
for (const auto& entry: std::filesystem::directory_iterator(dir_path)) {
if (entry.is_directory())
fsync_dir(entry.path());
else if (entry.is_regular_file())
fsync_path(entry.path());
}
fsync_path(dir_path);
}
// Write into the temporary file
std::ofstream out(tmp_file_path, std::ios::binary);
static void put(const std::filesystem::path& path, const std::string& data) {
std::ofstream out(path, std::ios::binary);
DG_HOST_ASSERT(out.write(data.data(), data.size()));
out.close();
// Atomically replace
std::filesystem::rename(tmp_file_path, path);
// NOTES: fsync to ensure the data is visible to other processes (e.g., NVCC)
// on distributed filesystems, where `close()` alone does not guarantee persistence
fsync_path(path);
}
std::shared_ptr<KernelRuntime> build(const std::string& name, const std::string& code) const {
const auto kernel_signature = fmt::format("{}$${}$${}$${}$${}", name, library_version, signature, flags, code);
const auto kernel_signature = fmt::format("{}$${}$${}$${}", name, signature, flags, code);
const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature));
// Hit the runtime cache
if (const auto& runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr)
if (const auto runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr)
return runtime;
// Create the kernel directory
make_dirs(dir_path);
// Compile into a temporary directory, then atomically rename the whole directory
// NOTES: renaming a directory is atomic on both local and distributed filesystems,
// avoiding the stale inode issue that occurs when renaming individual files
const auto tmp_dir_path = make_tmp_dir() / get_uuid();
make_dirs(tmp_dir_path);
// Compile into a temporary CUBIN
const auto tmp_cubin_path = get_tmp_file_path();
// Compile into the temporary directory
const auto tmp_cubin_path = tmp_dir_path / "kernel.cubin";
if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_PTX")) {
// Dump PTX if needed
const auto tmp_ptx_path = get_tmp_file_path();
compile(code, dir_path, tmp_cubin_path, tmp_ptx_path);
// Replace into the cache directory
std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx");
const auto tmp_ptx_path = tmp_dir_path / "kernel.ptx";
compile(code, tmp_dir_path, tmp_cubin_path, tmp_ptx_path);
} else {
compile(code, dir_path, tmp_cubin_path);
compile(code, tmp_dir_path, tmp_cubin_path);
}
// Replace into the cache directory
const auto cubin_path = dir_path / "kernel.cubin";
std::filesystem::rename(tmp_cubin_path, cubin_path);
// Disassemble if needed
if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_SASS")) {
// Dump into a temporary SASS
const auto tmp_sass_path = get_tmp_file_path();
disassemble(cubin_path, tmp_sass_path);
const auto tmp_sass_path = tmp_dir_path / "kernel.sass";
disassemble(tmp_cubin_path, tmp_sass_path);
}
// Replace into the current directory
std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass");
// Fsync before rename to ensure visibility on distributed filesystems
fsync_dir(tmp_dir_path);
// Atomically rename the temporary directory to the final cache path
// NOTES: if another rank already created dir_path, rename will fail — that's fine
make_dirs(dir_path.parent_path());
std::error_code error_code;
std::filesystem::rename(tmp_dir_path, dir_path, error_code);
if (error_code) {
// Another rank beat us, then clean up our dir and use the existing one
// NOTES: avoid `std::filesystem::remove_all` here — it can segfault on
// distributed filesystems, when concurrent processes operate
// on the same parent directory, causing stale directory entries
safe_remove_all(tmp_dir_path);
}
// Put into the runtime cache
@@ -160,7 +166,6 @@ public:
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_version);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path);
class NVCCCompiler final: public Compiler {
@@ -170,8 +175,8 @@ class NVCCCompiler final: public Compiler {
DG_HOST_ASSERT(std::filesystem::exists(nvcc_path));
// Call the version command
const auto& command = std::string(nvcc_path) + " --version";
const auto& [return_code, output] = call_external_command(command);
const auto command = std::string(nvcc_path) + " --version";
const auto [return_code, output] = call_external_command(command);
DG_HOST_ASSERT(return_code == 0);
// The version should be at least 12.3, for the best performance with 12.9
@@ -189,14 +194,14 @@ public:
NVCCCompiler() {
// Override the compiler signature
nvcc_path = cuda_home / "bin" / "nvcc";
if (const auto& env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty())
if (const auto env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty())
nvcc_path = env_nvcc_path;
const auto& [nvcc_major, nvcc_minor] = get_nvcc_version();
const auto [nvcc_major, nvcc_minor] = get_nvcc_version();
signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);
// The override the compiler flags
// Only NVCC >= 12.9 supports arch-specific family suffix
const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9);
const auto arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9);
flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
"--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
"-O3 --expt-relaxed-constexpr --expt-extended-lambda",
@@ -207,14 +212,17 @@ public:
const std::filesystem::path &cubin_path,
const std::optional<std::filesystem::path> &ptx_path) const override {
// Write the code into the cache directory
const auto& code_path = dir_path / "kernel.cu";
const auto code_path = dir_path / "kernel.cu";
put(code_path, code);
// Compile
const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
// Avoid cwd files shadowing C++ standard library headers
const auto compile_dir = make_tmp_dir();
const auto command = fmt::format("cd {} && {} {} -cubin -o {} {}",
compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
printf("Running NVCC command: %s\n", command.c_str());
const auto& [return_code, output] = call_external_command(command);
const auto [return_code, output] = call_external_command(command);
if (return_code != 0) {
printf("NVCC compilation failed: %s\n", output.c_str());
DG_HOST_ASSERT(false and "NVCC compilation failed");
@@ -222,7 +230,8 @@ public:
// Compile to PTX if needed
if (ptx_path.has_value()) {
const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags);
const auto ptx_command = fmt::format("cd {} && {} {} -ptx -o {} {}",
compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags);
if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
printf("Running NVCC PTX command: %s\n", ptx_command.c_str());
const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command);
@@ -267,7 +276,7 @@ public:
// Override the compiler flags
// Only NVRTC >= 12.9 supports arch-specific family suffix
const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9);
const auto arch = device_runtime->get_arch(false, major > 12 or minor >= 9);
flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128",
flags, include_dirs, arch, pch_flags);
}
@@ -276,7 +285,7 @@ public:
const std::filesystem::path &cubin_path,
const std::optional<std::filesystem::path> &ptx_path) const override {
// Write the code into the cache directory
const auto& code_path = dir_path / "kernel.cu";
const auto code_path = dir_path / "kernel.cu";
put(code_path, code);
// Parse compilation options
@@ -302,7 +311,7 @@ public:
// Create NVRTC program and compile
nvrtcProgram program;
DG_NVRTC_CHECK(nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr));
const auto& compile_result = nvrtcCompileProgram(program, static_cast<int>(option_cstrs.size()), option_cstrs.data());
const auto compile_result = nvrtcCompileProgram(program, static_cast<int>(option_cstrs.size()), option_cstrs.data());
// Get and print compiler log
size_t log_size;

View File

@@ -7,10 +7,13 @@
#include "../utils/exception.hpp"
#include "../utils/lazy_init.hpp"
#define PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3))
namespace deep_gemm {
class DeviceRuntime {
int num_sms = 0, tc_util = 0;
bool enable_pdl = false;
std::shared_ptr<cudaDeviceProp> cached_prop;
// cuBLASLt utils
@@ -18,24 +21,52 @@ class DeviceRuntime {
public:
// Create the cuBLASLt handle ourselves
cublasLtHandle_t cublaslt_handle{};
std::shared_ptr<torch::Tensor> cublaslt_workspace;
cublasLtHandle_t cublaslt_handle;
torch::Tensor cublaslt_workspace;
bool use_pytorch_managed_cublaslt_handle;
bool use_temp_cublaslt_workspace;
explicit DeviceRuntime() {
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
// Whether to use PyTorch cuBLASLt
// By default, we don't use it,
// as `at::cuda::getCurrentCUDABlasLtHandle` has large CPU overhead with some PyTorch versions
use_pytorch_managed_cublaslt_handle = get_env<int>("DG_USE_PYTORCH_CUBLASLT_HANDLE", 0) > 0;
#if not PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE
DG_HOST_ASSERT(not use_pytorch_managed_cublaslt_handle and "PyTorch does not support to get cuBLASLt handle");
#endif
// Whether to create workspace tensor on each call instead of holding one.
// Enabled by compute-sanitizer tests, which trigger `cudaErrorCudartUnloading`
// when the workspace tensor is destructed after CUDA driver shutdown.
use_temp_cublaslt_workspace = get_env<int>("DG_USE_TEMP_CUBLASLT_WORKSPACE", 0) > 0;
if (not use_pytorch_managed_cublaslt_handle)
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
if (not use_temp_cublaslt_workspace)
cublaslt_workspace = torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA));
}
~DeviceRuntime() noexcept(false) {
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
if (not use_pytorch_managed_cublaslt_handle)
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
}
cublasLtHandle_t get_cublaslt_handle() const {
#if PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE
if (use_pytorch_managed_cublaslt_handle)
return at::cuda::getCurrentCUDABlasLtHandle();
#endif
// Self-managed handle
return cublaslt_handle;
}
torch::Tensor get_cublaslt_workspace() const {
return *cublaslt_workspace;
if (use_temp_cublaslt_workspace)
return torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA));
return cublaslt_workspace;
}
std::shared_ptr<cudaDeviceProp> get_prop() {
@@ -56,7 +87,7 @@ public:
std::string get_arch(const bool& number_only = false,
const bool& support_arch_family = false) {
const auto& [major, minor] = get_arch_pair();
const auto [major, minor] = get_arch_pair();
if (major == 10 and minor != 1) {
if (number_only)
return "100";
@@ -92,6 +123,14 @@ public:
int get_tc_util() const {
return tc_util == 0 ? 100 : tc_util;
}
void set_pdl(const bool& new_enable_pdl) {
enable_pdl = new_enable_pdl;
}
bool get_pdl() const {
return enable_pdl;
}
};
static auto device_runtime = LazyInit<DeviceRuntime>([](){ return std::make_shared<DeviceRuntime>(); });

View File

@@ -24,7 +24,7 @@ static void* get_driver_handle() {
#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \
template <typename... Args> \
static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \
using FuncType = decltype(&name); \
using FuncType = decltype(&(name)); \
static FuncType func = nullptr; \
if (func == nullptr) { \
func = reinterpret_cast<FuncType>(dlsym(get_driver_handle(), #name)); \
@@ -39,6 +39,9 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryLoadFromFile);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryUnload);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuKernelGetFunction);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled);
@@ -65,13 +68,13 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s
}
static void unload_library(const LibraryHandle& library) {
const auto& error = cudaLibraryUnload(library);
const auto error = cudaLibraryUnload(library);
DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading);
}
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
const cudaStream_t& stream, const int& smem_size,
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) {
if (smem_size > 0)
DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
@@ -80,17 +83,27 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
config.blockDim = block_dim;
config.dynamicSmemBytes = smem_size;
config.stream = stream;
config.numAttrs = 0;
config.attrs = nullptr;
// Create attributes
// NOTES: must use `static` or the `attr` will be deconstructed
static LaunchAttrHandle attr;
static LaunchAttrHandle attrs[2];
config.numAttrs = 0;
config.attrs = attrs;
// Cluster size
if (cluster_dim > 1) {
auto& attr = attrs[config.numAttrs ++];
attr.id = cudaLaunchAttributeClusterDimension;
attr.val.clusterDim = {static_cast<unsigned>(cluster_dim), 1, 1};
config.attrs = &attr;
config.numAttrs = 1;
}
// Dependent kernel launch
if (enable_pdl) {
auto& attr = attrs[config.numAttrs ++];
attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
attr.val.programmaticStreamSerializationAllowed = 1;
}
return config;
}
@@ -103,19 +116,46 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle&
#else
// Use CUDA driver API
using LibraryHandle = CUmodule;
using KernelHandle = CUfunction;
using LaunchConfigHandle = CUlaunchConfig;
using LaunchAttrHandle = CUlaunchAttribute;
// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4
#if CUDA_VERSION >= 12040
#define DG_JIT_USE_LIBRARY_ENUM_KERNELS
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels);
using LibraryHandle = CUlibrary;
#else
using LibraryHandle = CUmodule;
#endif
#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK
static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name,
LibraryHandle *library_opt = nullptr) {
LibraryHandle *library_opt = nullptr) {
LibraryHandle library;
KernelHandle kernel;
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
DG_CUDA_DRIVER_CHECK(lazy_cuLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
unsigned int num_kernels;
DG_CUDA_DRIVER_CHECK(lazy_cuLibraryGetKernelCount(&num_kernels, library));
if (num_kernels != 1) {
const auto dir_path = cubin_path.parent_path();
printf("Corrupted JIT cache directory (expected 1 kernel, found %u): %s, "
"please run `rm -rf %s` and restart your task.\n",
num_kernels, dir_path.c_str(), dir_path.c_str());
DG_HOST_ASSERT(false and "Corrupted JIT cache directory");
}
CUkernel cu_kernel;
DG_CUDA_DRIVER_CHECK(lazy_cuLibraryEnumerateKernels(&cu_kernel, 1, library));
DG_CUDA_DRIVER_CHECK(lazy_cuKernelGetFunction(&kernel, cu_kernel));
#else
DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str()));
DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str()));
#endif
if (library_opt != nullptr)
*library_opt = library;
@@ -123,13 +163,17 @@ static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const s
}
static void unload_library(const LibraryHandle& library) {
const auto& error = lazy_cuModuleUnload(library);
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
const auto error = lazy_cuLibraryUnload(library);
#else
const auto error = lazy_cuModuleUnload(library);
#endif
DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED);
}
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
const cudaStream_t& stream, const int& smem_size,
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim) {
const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) {
if (smem_size > 0)
DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size));
@@ -142,19 +186,29 @@ static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
config.blockDimZ = block_dim.z;
config.sharedMemBytes = smem_size;
config.hStream = stream;
config.numAttrs = 0;
config.attrs = nullptr;
// Create attributes
// NOTES: must use `static` or the `attr` will be deconstructed
static LaunchAttrHandle attr;
static LaunchAttrHandle attrs[2];
config.numAttrs = 0;
config.attrs = attrs;
// Cluster size
if (cluster_dim > 1) {
auto& attr = attrs[config.numAttrs ++];
attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
attr.value.clusterDim.x = cluster_dim;
attr.value.clusterDim.x = static_cast<unsigned>(cluster_dim);
attr.value.clusterDim.y = 1;
attr.value.clusterDim.z = 1;
config.attrs = &attr;
config.numAttrs = 1;
}
// Dependent kernel launch
if (enable_pdl) {
auto& attr = attrs[config.numAttrs ++];
attr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
attr.value.programmaticStreamSerializationAllowed = 1;
}
return config;
}

View File

@@ -0,0 +1,80 @@
#pragma once
#include <filesystem>
#include <regex>
#include <string>
#include <vector>
#include "../utils/format.hpp"
#include "../utils/system.hpp"
namespace deep_gemm {
class IncludeParser {
std::unordered_map<std::string, std::optional<std::string>> cache;
static std::vector<std::string> get_includes(const std::string& code, const std::filesystem::path& file_path = "") {
std::vector<std::string> includes;
const std::regex pattern(R"(#\s*include\s*[<"][^>"]+[>"])");
std::sregex_iterator iter(code.begin(), code.end(), pattern);
const std::sregex_iterator end;
// TODO: parse relative paths as well
for (; iter != end; ++ iter) {
const auto include_str = iter->str();
const int len = include_str.length();
if (include_str.substr(0, 10) == "#include <" and include_str[len - 1] == '>' and include_str[10] != ' ' and include_str[len - 2] != ' ') {
std::string filename = include_str.substr(10, len - 11);
if (filename.substr(0, 9) == "deep_gemm") // We only parse `<deep_gemm/*>`
includes.push_back(filename);
} else {
std::string error_info = fmt::format("Non-standard include: {}", include_str);
if (file_path != "")
error_info += fmt::format(" ({})", file_path.string());
DG_HOST_UNREACHABLE(error_info);
}
}
return includes;
}
public:
static std::filesystem::path library_include_path;
static void prepare_init(const std::string& library_root_path) {
library_include_path = std::filesystem::path(library_root_path) / "include";
}
std::string get_hash_value(const std::string& code, const bool& exclude_code = true) {
std::stringstream ss;
for (const auto& i: get_includes(code))
ss << get_hash_value_by_path(library_include_path / i) << "$";
if (not exclude_code)
ss << "#" << get_hex_digest(code);
return get_hex_digest(ss.str());
}
std::string get_hash_value_by_path(const std::filesystem::path& path) {
// Check whether hit in cache
// ReSharper disable once CppUseAssociativeContains
if (cache.count(path) > 0) {
const auto opt = cache[path];
if (not opt.has_value())
DG_HOST_UNREACHABLE(fmt::format("Circular include may occur: {}", path.string()));
return opt.value();
}
// Read file and calculate hash recursively
std::ifstream in(path);
if (not in.is_open())
DG_HOST_UNREACHABLE(fmt::format("Failed to open: {}", path.string()));
std::string code((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
cache[path] = std::nullopt;
return (cache[path] = get_hash_value(code, false)).value();
}
};
DG_DECLARE_STATIC_VAR_IN_CLASS(IncludeParser, library_include_path);
static auto include_parser = std::make_shared<IncludeParser>();
} // namespace deep_gemm

View File

@@ -1,10 +1,13 @@
#pragma once
#include <chrono>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/system.hpp"
#include "device_runtime.hpp"
#include "handle.hpp"
#include "include_parser.hpp"
namespace deep_gemm {
@@ -13,12 +16,13 @@ struct LaunchArgs {
int num_threads;
int smem_size;
int cluster_dim;
bool enable_pdl;
LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true):
grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {}
LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1):
grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim) {}
LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true):
grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {}
};
class KernelRuntime final {
@@ -33,36 +37,56 @@ public:
DG_HOST_ASSERT(not cuda_home.empty());
// NOLINT(*-pro-type-member-init)
const auto& cuobjdump_path = cuda_home / "bin" / "cuobjdump";
const auto& cubin_path = dir_path / "kernel.cubin";
const auto cuobjdump_path = cuda_home / "bin" / "cuobjdump";
const auto cubin_path = dir_path / "kernel.cubin";
if (get_env<int>("DG_JIT_DEBUG"))
printf("Loading CUBIN: %s\n", cubin_path.c_str());
// Record start time
std::chrono::high_resolution_clock::time_point start_time;
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_JIT_PRINT_LOAD_TIME"))
start_time = std::chrono::high_resolution_clock::now();
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
// Load from the library
kernel = load_kernel(cubin_path, {}, &library);
#else
// Find the only symbol
// TODO: use kernel enumeration for newer drivers
const std::vector<std::string> illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"};
const auto& [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str()));
const auto [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str()));
DG_HOST_ASSERT(exit_code == 0);
std::istringstream iss(symbols);
std::vector<std::string> symbol_names;
for (std::string line; std::getline(iss, line); ) {
if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and
std::none_of(illegal_names.begin(), illegal_names.end(),
[&](const auto& name) { return line.find(name) != std::string::npos; })) {
const auto& last_space = line.rfind(' ');
[&](const auto name) { return line.find(name) != std::string::npos; })) {
const auto last_space = line.rfind(' ');
symbol_names.push_back(line.substr(last_space + 1));
}
}
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Symbol names: ");
// Print symbols
if (symbol_names.size() != 1 or get_env<int>("DG_JIT_DEBUG")) {
printf("Symbols: ");
printf(" > CUBIN: %s\n", cubin_path.c_str());
printf(" > Raw symbols: %s\n", symbols.c_str());
printf(" > Parsed symbols:\n");
for (const auto& symbol: symbol_names)
printf("%s, ", symbol.c_str());
printf("\n");
printf(" > %s, ", symbol.c_str());
}
DG_HOST_ASSERT(symbol_names.size() == 1);
// Load from the library
DG_HOST_ASSERT(symbol_names.size() == 1);
kernel = load_kernel(cubin_path, symbol_names[0], &library);
#endif
// Print load time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_JIT_PRINT_LOAD_TIME")) {
std::chrono::duration<double, std::milli> load_time = std::chrono::high_resolution_clock::now() - start_time;
printf("Load time (%s): %.2lf ms\n", dir_path.c_str(), load_time.count());
}
}
static void prepare_init(const std::string& cuda_home_path_by_python) {
@@ -70,8 +94,19 @@ public:
}
static bool check_validity(const std::filesystem::path& dir_path) {
return std::filesystem::exists(dir_path / "kernel.cu") and
std::filesystem::exists(dir_path / "kernel.cubin");
if (not std::filesystem::exists(dir_path))
return false;
// NOTES: if the directory exists, `kernel.cu` and `kernel.cubin` must both exist,
// because the directory is created atomically via rename
if (not std::filesystem::exists(dir_path / "kernel.cu") or
not std::filesystem::exists(dir_path / "kernel.cubin")) {
printf("Corrupted JIT cache directory (missing kernel.cu or kernel.cubin): %s, "
"please run `rm -rf %s` and restart your task.\n",
dir_path.c_str(), dir_path.c_str());
DG_HOST_ASSERT(false and "Corrupted JIT cache directory");
}
return true;
}
~KernelRuntime() noexcept(false) {
@@ -86,30 +121,42 @@ class LaunchRuntime {
public:
template <typename Args>
static std::string generate(const Args& args) {
const auto& code = Derived::generate_impl(args);
if (get_env<int>("DG_JIT_DEBUG", 0))
printf("Generated kernel code: %s\n", code.c_str());
auto code = Derived::generate_impl(args);
// NOTES: we require that `generate_impl`'s includes never change
static std::string include_hash;
if (include_hash.empty())
include_hash = include_parser->get_hash_value(code);
// TODO: optimize string concat performance
code = fmt::format("// Includes' hash value: {}\n{}", include_hash, code);
if (get_env<int>("DG_JIT_DEBUG"))
printf("Generated kernel code:\n%s\n", code.c_str());
return code;
}
template <typename Args>
static void launch(const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) {
const auto& kernel = kernel_runtime->kernel;
const auto& stream = at::cuda::getCurrentCUDAStream();
const LaunchArgs& launch_args = args.launch_args;
const auto kernel = kernel_runtime->kernel;
const auto stream = at::cuda::getCurrentCUDAStream();
LaunchArgs launch_args = args.launch_args;
const dim3& grid_dim = {static_cast<unsigned>(launch_args.grid_dim.first),
static_cast<unsigned>(launch_args.grid_dim.second),
1};
const dim3& block_dim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
// Allow runtime override from Python.
// NOTES: the default is enabled.
launch_args.enable_pdl = device_runtime->get_pdl();
const dim3 grid_dim = {static_cast<unsigned>(launch_args.grid_dim.first),
static_cast<unsigned>(launch_args.grid_dim.second),
1};
const dim3 block_dim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
auto config = construct_launch_config(kernel, stream, launch_args.smem_size,
grid_dim, block_dim, launch_args.cluster_dim);
grid_dim, block_dim, launch_args.cluster_dim, launch_args.enable_pdl);
// Launch in the derived class
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, stream: %ld\n",
printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, pdl: %d, stream: %ld\n",
launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads,
launch_args.smem_size, launch_args.cluster_dim, stream.id());
launch_args.smem_size, launch_args.cluster_dim, launch_args.enable_pdl, stream.id());
}
Derived::launch_impl(kernel, config, args);
}

View File

@@ -1,339 +1,54 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include <unordered_set>
#include <deep_gemm/common/types.cuh>
#include "../../utils/math.hpp"
#include "config.hpp"
#include "runtime.hpp"
#include "../../utils/layout.hpp"
#include "../../utils/system.hpp"
namespace deep_gemm {
struct MulticastConfig {
int num_multicast;
bool is_multicast_on_a;
MulticastConfig(const int& num_multicast, const bool& is_multicast_on_a):
num_multicast(num_multicast), is_multicast_on_a(is_multicast_on_a) {
DG_HOST_ASSERT(1 <= num_multicast and num_multicast <= 2);
}
};
struct SharedMemoryConfig {
int smem_size;
int swizzle_a_mode;
int swizzle_b_mode;
int swizzle_cd_mode;
};
struct ThreadConfig {
int num_threads;
// SM90
int num_tma_threads;
int num_math_threads;
// SM100
int num_non_epilogue_threads;
int num_epilogue_threads;
static ThreadConfig sm90(const int& num_tma_threads,
const int& num_math_threads) {
auto config = ThreadConfig();
config.num_threads = num_tma_threads + num_math_threads;
config.num_tma_threads = num_tma_threads;
config.num_math_threads = num_math_threads;
return config;
}
static ThreadConfig sm100(const int& num_non_epilogue_threads,
const int& num_epilogue_threads) {
auto config = ThreadConfig();
config.num_threads = num_non_epilogue_threads + num_epilogue_threads;
config.num_non_epilogue_threads = num_non_epilogue_threads;
config.num_epilogue_threads = num_epilogue_threads;
return config;
}
};
struct GemmConfig {
// Templated configs
GemmType gemm_type;
KernelType kernel_type;
MmaKind mma_kind;
at::ScalarType a_dtype, b_dtype, cd_dtype;
cute::UMMA::Major major_a;
cute::UMMA::Major major_b;
bool with_accumulation;
int block_m, block_n, block_k;
int num_stages, num_last_stages;
// Templated device configs
int num_sms;
int tc_util;
// Structured configs
MulticastConfig multicast_config;
SharedMemoryConfig smem_config;
ThreadConfig thread_config;
};
static bool is_multicast_legal(const int& shape_dim, const int& block_dim,
const int& num_multicast, const int& num_sms,
const bool& require_divisible) {
const bool& divisible = ceil_div(shape_dim, block_dim) % num_multicast == 0 or not require_divisible;
return divisible and num_sms % num_multicast == 0;
}
template <typename size_type_t>
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
// `> 0` means interleaving
// 16B actually means non-swizzling (but interleaving)
for (const int& mode: {128, 64, 32, 16}) {
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
return mode;
}
DG_HOST_UNREACHABLE("Unreachable");
}
template <typename ArchSpec>
static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& num_stages, const MulticastConfig& multicast_config) {
const int& ab_elem_size = static_cast<int>(get_element_size(mma_kind));
const int& cd_elem_size = static_cast<int>(c10::elementSize(cd_dtype));
static GemmConfig get_best_config(const GemmDesc& desc) {
desc.check_validity();
const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m);
const int& load_block_n = ArchSpec::get_ab_load_block_n(multicast_config, block_n);
const int& swizzle_a_mode = get_swizzle_mode(major_a == cute::UMMA::Major::K ? block_k : load_block_m, ab_elem_size);
const int& swizzle_b_mode = get_swizzle_mode(major_b == cute::UMMA::Major::K ? block_k : load_block_n, ab_elem_size);
const int& swizzle_cd_mode = ArchSpec::enable_cd_swizzle(cd_dtype) ? get_swizzle_mode(block_n, cd_elem_size) : 0;
// Different archs have different epilogue pipelines
const int& smem_cd = ArchSpec::get_smem_cd_size(kernel_type, block_m, block_n, swizzle_cd_mode, cd_dtype);
// A/B shared memory
const int& smem_a_per_stage = load_block_m * block_k * ab_elem_size;
const int& smem_b_per_stage = load_block_n * block_k * ab_elem_size;
// SF shared memory
const auto& [smem_sfa_per_stage, smem_sfb_per_stage] =
ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype);
const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k);
// M-barriers and tensor memory pointers
const int& smem_barrier = ArchSpec::get_barrier_smem_size(num_stages);
const int& smem_tmem_ptr = ArchSpec::get_tmem_ptr_smem_size();
const int& smem_tensor_map = ArchSpec::get_tensormap_smem_size(gemm_type);
// Sum them up
int smem_size = 0;
smem_size += smem_tensor_map;
smem_size += smem_cd;
smem_size += num_stages * smem_a_per_stage;
smem_size += num_stages * smem_b_per_stage;
smem_size += num_stages * smem_sfa_per_stage;
smem_size += num_stages * smem_sfb_per_stage;
smem_size += smem_extra_sfb;
smem_size += smem_barrier;
smem_size += smem_tmem_ptr;
return SharedMemoryConfig {
.smem_size = smem_size,
.swizzle_a_mode = swizzle_a_mode,
.swizzle_b_mode = swizzle_b_mode,
.swizzle_cd_mode = swizzle_cd_mode,
};
}
template <typename ArchSpec>
static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type,
const int& m, const int& n, const int& k, const int& num_groups,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const at::ScalarType& a_dtype, const at::ScalarType& b_dtype,
const at::ScalarType& cd_dtype,
const bool& with_accumulation, const int& num_sms) {
const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4);
if (mma_kind == MmaKind::BF16) {
DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16);
} else {
DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4);
DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4);
}
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
// Select M/N block sizes
auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m);
if (gemm_type == GemmType::MGroupedContiguous)
block_ms = std::vector{get_mk_alignment_for_contiguous_layout()};
if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout)
block_ms = std::vector{64, 128}; // Exclude 256 for performance
auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype);
// NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B
// TODO: Optimize it
if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN)
block_ms = std::vector{128};
if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN)
block_ns = std::vector{128};
// K block size is selected in a fixed manner
const auto& block_k = (mma_kind == MmaKind::BF16 ? 64 : 128);
// Some util functions
const auto& get_num_blocks = [=](const int& block_m, const int& block_n) {
return ceil_div(m, block_m) * ceil_div(n, block_n) * num_groups;
};
const auto& get_num_waves = [=](const int& block_m, const int& block_n) {
return ceil_div(get_num_blocks(block_m, block_n), num_sms);
};
const auto& get_last_wave_util = [=](const int& block_m, const int& block_n) {
const auto& num_last_blocks = get_num_blocks(block_m, block_n) % num_sms;
return num_last_blocks == 0 ? num_sms : num_last_blocks;
};
// Decide block sizes by waves
int best_block_m = 0, best_block_n = 0;
int best_num_waves = 0, best_last_util = 0;
for (const auto& block_m: block_ms) {
for (const auto& block_n: block_ns) {
const int& num_waves = get_num_waves(block_m, block_n);
const auto& last_util = get_last_wave_util(block_m, block_n);
if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, mma_kind, cd_dtype, m, n, k, block_m, block_n, block_k))
continue;
bool success = false;
if (best_block_m == 0 or best_block_n == 0 or num_waves < best_num_waves) {
success = true;
} else if (num_waves == best_num_waves) {
// Check last wave utilization
success = last_util > best_last_util;
if (last_util == best_last_util) {
// Case 1: same `block_m`, smaller `block_n` (wasted)
success |= block_m == best_block_m and block_n < best_block_n;
// Case 2: same `block_n`, smaller `block_m` (wasted)
success |= block_n == best_block_n and block_m < best_block_m;
// Case 3: different for both `block_m` and `block_n`, larger `block_n` is better
// NOTES: don't pick `block_m/block_n` larger than shape `m/n` in this case
success |= block_m != best_block_m and block_n > best_block_n
and block_n <= n and block_m <= m;
}
}
// Replace with the new config if successful
if (success) {
best_block_m = block_m, best_block_n = block_n;
best_num_waves = num_waves, best_last_util = last_util;
}
}
}
DG_HOST_ASSERT(best_block_m > 0 and best_block_n > 0);
// Decide the number of TMA multicasts and whether broadcast on A
MulticastConfig best_multicast_config = {1, false};
auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality(
gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms);
// NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B
// TODO: Optimize it
if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN)
is_legal_on_a = false;
if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN)
is_legal_on_b = false;
const bool is_legal[2] = {is_legal_on_b, is_legal_on_a};
bool order[2] = {false, true};
if (best_block_m > best_block_n)
std::swap(order[0], order[1]);
for (const bool& is_multicast_on_a: order) {
if (m >= 512 and is_legal[static_cast<int>(is_multicast_on_a)]) {
best_multicast_config = {2, is_multicast_on_a};
break;
}
// Choose the best layout
const auto layout_candidates = ArchSpec::get_layout_candidates(desc);
DG_HOST_ASSERT(not layout_candidates.empty());
auto layout = layout_candidates[0];
auto layout_info = ArchSpec::get_layout_info(desc, layout);
for (int i = 1; i < static_cast<int>(layout_candidates.size()); ++ i) {
const auto candidate_info = ArchSpec::get_layout_info(desc, layout_candidates[i]);
if (ArchSpec::compare(candidate_info, layout_info))
layout = layout_candidates[i], layout_info = candidate_info;
}
// Always pick the largest number of stage
constexpr int smem_capacity = ArchSpec::smem_capacity;
int best_num_stages = 0;
SharedMemoryConfig best_smem_config;
for (int num_stages = 32; num_stages > 0; -- num_stages) {
if (not ArchSpec::is_num_stages_legal(mma_kind, cd_dtype, num_stages, best_block_m, best_block_n, block_k))
continue;
best_smem_config = get_smem_config<ArchSpec>(gemm_type, kernel_type,
m, n, k,
best_block_m, best_block_n, block_k,
major_a, major_b,
mma_kind, cd_dtype,
num_stages, best_multicast_config);
if (best_smem_config.smem_size <= smem_capacity) {
best_num_stages = num_stages;
break;
}
}
DG_HOST_ASSERT(best_num_stages != 0);
// Recompute the minimal number of SMs required
// NOTES: less L2 cache usage and less GPU frequency drop
int num_min_sms = num_sms;
if (get_env<int>("DG_JIT_MINIMIZE_NUM_SMS", 0)) {
num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves);
num_min_sms = align(num_min_sms, best_multicast_config.num_multicast);
DG_HOST_ASSERT(num_min_sms <= num_sms);
}
const auto& config = GemmConfig {
.gemm_type = gemm_type,
.kernel_type = kernel_type,
.mma_kind = mma_kind,
.a_dtype = a_dtype,
.b_dtype = b_dtype,
.cd_dtype = cd_dtype,
.major_a = major_a,
.major_b = major_b,
.with_accumulation = with_accumulation,
.block_m = best_block_m,
.block_n = best_block_n,
.block_k = block_k,
.num_stages = best_num_stages,
.num_last_stages = ceil_div(k, block_k) % best_num_stages,
.num_sms = num_min_sms,
.tc_util = device_runtime->get_tc_util(),
.multicast_config = best_multicast_config,
// ReSharper disable once CppLocalVariableMightNotBeInitialized
.smem_config = best_smem_config,
.thread_config = ArchSpec::get_thread_config(kernel_type, best_block_m, best_block_n)
// Infer other configs
const auto storage_config = ArchSpec::get_storage_config(desc, layout);
const auto pipeline_config = ArchSpec::get_pipeline_config(desc, layout, storage_config);
const auto launch_config = ArchSpec::get_launch_config(desc, layout);
const auto gemm_config = GemmConfig {
.layout = layout,
.storage_config = storage_config,
.pipeline_config = pipeline_config,
.launch_config = launch_config
};
// Only SM100 BF16 kernels support tensor core control
if (config.tc_util < 100)
DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and mma_kind == MmaKind::BF16);
// Print configs for the first time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b,
mma_kind, a_dtype, b_dtype, cd_dtype, with_accumulation, num_sms);
static std::set<decltype(key)> printed;
std::stringstream ss;
ss << desc;
const auto key = ss.str();
static std::unordered_set<std::string> printed;
if (printed.count(key) == 0) {
printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, "
"A major: %d, B major: %d, MMA kind: %d, A dtype: %s, B dtype: %s, CD dtype: %s, accumulation: %d, "
"SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, "
"SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, "
"swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n",
static_cast<int>(gemm_type), static_cast<int>(kernel_type), m, n, k, num_groups,
static_cast<int>(major_a), static_cast<int>(major_b), static_cast<int>(mma_kind),
c10::toString(a_dtype), c10::toString(b_dtype), c10::toString(cd_dtype),
static_cast<int>(with_accumulation), num_sms, best_block_m, best_block_n, block_k,
best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast,
static_cast<int>(best_multicast_config.is_multicast_on_a),
best_smem_config.smem_size, best_smem_config.swizzle_a_mode, best_smem_config.swizzle_b_mode,
best_smem_config.swizzle_cd_mode, config.num_sms, config.thread_config.num_threads, config.tc_util);
std::cout << desc << ": " << gemm_config << ", " << layout_info << std::endl;
printed.insert(key);
}
}
return config;
return gemm_config;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,171 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
#include <c10/core/ScalarType.h>
#include <deep_gemm/common/types.cuh>
#include "../../utils/math.hpp"
namespace deep_gemm {
/// GEMM descriptors
struct GemmDesc {
GemmType gemm_type;
KernelType kernel_type;
int m, n, k, num_groups;
at::ScalarType a_dtype, b_dtype, cd_dtype;
cute::UMMA::Major major_a;
cute::UMMA::Major major_b;
bool with_accumulation;
// Requirements from users
int num_sms, tc_util;
std::string compiled_dims;
// Shape for heuristic generation
int expected_m = 0, expected_n = 0, expected_k = 0, expected_num_groups = 0;
int get_expected_m() const { return expected_m > 0 ? expected_m : m; }
int get_expected_n() const { return expected_n > 0 ? expected_n : n; }
int get_expected_k() const { return expected_k > 0 ? expected_k : k; }
int get_expected_num_groups() const { return expected_num_groups > 0 ? expected_num_groups : num_groups; }
MmaKind get_mma_kind() const {
return a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4;
}
void check_validity() const {
if (get_mma_kind() == MmaKind::BF16) {
DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16);
} else {
DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4);
DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4);
}
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
DG_HOST_ASSERT(num_sms % 2 == 0);
}
friend std::ostream& operator << (std::ostream& os, const GemmDesc& desc) {
MmaKind mma_kind = desc.get_mma_kind();
os << "GemmDesc(gemm_type=" << static_cast<int>(desc.gemm_type)
<< ", kernel_type=" << static_cast<int>(desc.kernel_type)
<< ", m=" << desc.m << ", n=" << desc.n << ", k=" << desc.k
<< ", num_groups=" << desc.num_groups
<< ", major_a=" << static_cast<int>(desc.major_a)
<< ", major_b=" << static_cast<int>(desc.major_b)
<< ", mma_kind=" << static_cast<int>(mma_kind)
<< ", a_dtype=" << c10::toString(desc.a_dtype)
<< ", b_dtype=" << c10::toString(desc.b_dtype)
<< ", cd_dtype=" << c10::toString(desc.cd_dtype)
<< ", with_accumulation=" << static_cast<int>(desc.with_accumulation)
<< ", num_sms=" << desc.num_sms
<< ", tc_util=" << desc.tc_util
<< ", compiled_dims=" << desc.compiled_dims
<< ", expected_m=" << desc.expected_m
<< ", expected_n=" << desc.expected_n
<< ", expected_k=" << desc.expected_k
<< ", expected_num_groups=" << desc.expected_num_groups << ")";
return os;
}
};
/// GEMM configs
struct Layout {
int swap_ab;
int block_m, block_n, block_k;
int cluster_m, cluster_n;
int get_cluster_size() const {
return cluster_m * cluster_n;
}
friend std::ostream& operator << (std::ostream& os, const Layout& layout) {
os << "Layout(swap_ab=" << layout.swap_ab
<< ", block_m=" << layout.block_m << ", block_n=" << layout.block_n << ", block_k=" << layout.block_k
<< ", cluster_m=" << layout.cluster_m << ", cluster_n=" << layout.cluster_n << ")";
return os;
}
};
struct StorageConfig {
int load_block_m, load_block_n;
int store_block_m, store_block_n;
int swizzle_a_mode, swizzle_b_mode;
int swizzle_cd_mode;
friend std::ostream& operator << (std::ostream& os, const StorageConfig& config) {
os << "StorageConfig("
<< "load_block_m=" << config.load_block_m << ", load_block_n=" << config.load_block_n
<< ", store_block_m=" << config.store_block_m << ", store_block_n=" << config.store_block_n
<< ", swizzle_a_mode=" << config.swizzle_a_mode << ", swizzle_b_mode=" << config.swizzle_b_mode
<< ", swizzle_cd_mode=" << config.swizzle_cd_mode << ")";
return os;
}
};
struct PipelineConfig {
int smem_size;
int num_stages;
friend std::ostream& operator << (std::ostream& os, const PipelineConfig& config) {
os << "PipelineConfig("
<< "smem_size=" << config.smem_size
<< ", num_stages=" << config.num_stages << ")";
return os;
}
};
struct LaunchConfig {
int num_sms;
int num_sms_per_cluster;
int num_threads;
int num_tma_threads;
int num_math_threads;
int num_non_epilogue_threads;
int num_epilogue_threads;
friend std::ostream& operator << (std::ostream& os, const LaunchConfig& config) {
os << "LaunchConfig("
<< "num_sms=" << config.num_sms << ", num_sms_per_cluster=" << config.num_sms_per_cluster
<< ", num_threads=" << config.num_threads
<< ", num_tma_threads=" << config.num_tma_threads << ", num_math_threads=" << config.num_math_threads
<< ", num_non_epilogue_threads=" << config.num_non_epilogue_threads
<< ", num_epilogue_threads=" << config.num_epilogue_threads << ")";
return os;
}
};
struct GemmConfig {
Layout layout;
StorageConfig storage_config;
PipelineConfig pipeline_config;
LaunchConfig launch_config;
friend std::ostream& operator << (std::ostream& os, const GemmConfig& config) {
os << "GemmConfig("
<< "layout=" << config.layout
<< ", storage_config=" << config.storage_config
<< ", pipeline_config=" << config.pipeline_config
<< ", launch_config=" << config.launch_config << ")";
return os;
}
};
/// Config comparators
struct LayoutInfo {
int num_waves;
int last_wave_util;
int64_t num_cycles;
Layout layout;
friend std::ostream& operator << (std::ostream& os, const LayoutInfo& config) {
os << "LayoutInfo("
<< "num_waves=" << config.num_waves
<< ", last_wave_util=" << config.last_wave_util
<< ", num_cycles=" << config.num_cycles << ")";
return os;
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,211 @@
#pragma once
#include <algorithm>
#include <unordered_set>
#include <deep_gemm/layout/mega_moe.cuh>
#include "../../utils/exception.hpp"
#include "../../utils/math.hpp"
#include "../../utils/system.hpp"
#include "sm100.hpp"
namespace deep_gemm {
struct MegaMoEConfig {
// Block tiling
int block_m, block_n, block_k;
int load_block_m, load_block_n;
int store_block_m;
// SF block sizes (UTCCP 128-aligned)
int sf_block_m, sf_block_n;
// Pool capacity and SF-padded token count
int num_max_pool_tokens;
int num_padded_sf_pool_tokens;
// Swizzle modes for TMA descriptors
int swizzle_acts_mode, swizzle_weights_mode;
// Number of experts to process per wave
int num_experts_per_wave;
// Pipeline stages and shared memory
int num_stages, smem_size;
// Thread layout
int num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads;
friend std::ostream& operator << (std::ostream& os, const MegaMoEConfig& config) {
os << "MegaMoEConfig("
<< "block_m=" << config.block_m << ", block_n=" << config.block_n << ", block_k=" << config.block_k
<< ", load_block_m=" << config.load_block_m << ", load_block_n=" << config.load_block_n
<< ", store_block_m=" << config.store_block_m
<< ", sf_block_m=" << config.sf_block_m << ", sf_block_n=" << config.sf_block_n
<< ", num_max_pool_tokens=" << config.num_max_pool_tokens
<< ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens
<< ", swizzle_acts_mode=" << config.swizzle_acts_mode << ", swizzle_weights_mode=" << config.swizzle_weights_mode
<< ", num_experts_per_wave=" << config.num_experts_per_wave
<< ", num_stages=" << config.num_stages << ", smem_size=" << config.smem_size
<< ", num_dispatch_threads=" << config.num_dispatch_threads
<< ", num_non_epilogue_threads=" << config.num_non_epilogue_threads
<< ", num_epilogue_threads=" << config.num_epilogue_threads << ")";
return os;
}
};
static int get_block_m_for_mega_moe(const int& num_ranks, const int& num_experts,
const int& num_max_tokens_per_rank, const int& num_topk) {
// TODO: compute based on configs
return 192;
}
static int get_num_experts_per_wave_for_mega_moe(
const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
// Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens
constexpr int kImbalanceFactor = 2;
// TODO: support num_experts_per_rank > 32
// Find the largest divisor of num_experts_per_rank that fits in 32 as the upper bound
int max_num_experts_per_wave = std::min(32, num_experts_per_rank);
while (max_num_experts_per_wave > 1 and num_experts_per_rank % max_num_experts_per_wave != 0)
-- max_num_experts_per_wave;
// Count L1 blocks per expert assuming tokens are evenly spread across experts
const int expected_tokens_per_expert =
num_tokens * num_topk / num_experts_per_rank + 1;
const int num_m_blocks = ceil_div(expected_tokens_per_expert, block_m);
const int num_n_blocks = intermediate_hidden / block_n;
const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks;
// Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy
int num_experts_per_wave = num_l1_blocks_per_expert > 0
? ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert) : 1;
num_experts_per_wave = std::min(num_experts_per_wave, max_num_experts_per_wave);
// Round up to the nearest divisor of num_experts_per_rank so every wave processes the same count
while (num_experts_per_wave < max_num_experts_per_wave and num_experts_per_rank % num_experts_per_wave != 0)
++ num_experts_per_wave;
return num_experts_per_wave;
}
static std::pair<int, int> get_pipeline_config_for_mega_moe(
const int& smem_capacity,
const int& num_experts, const int& hidden,
const int& block_m, const int& block_n, const int& block_k, const int& store_block_m,
const int& sf_block_m, const int& sf_block_n,
const int& num_dispatch_warps, const int& num_epilogue_warps) {
constexpr int kSmemAlignment = 1024;
constexpr int kNumEpilogueStages = 2;
constexpr int kNumTMAStoreStages = 2;
// Always multicast on A
const int load_block_m = block_m / 2;
// Dispatch region
const int smem_expert_count_size = align(
num_experts * static_cast<int>(sizeof(uint32_t)), kSmemAlignment);
const int smem_send_buffers_size = align(
static_cast<int>(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()),
kSmemAlignment);
const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size;
// C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage)
const auto num_epilogue_warpgroups = num_epilogue_warps / 4;
const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages;
const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast<int>(sizeof(nv_bfloat16));
const int smem_cd = std::max(smem_cd_l1, smem_cd_l2);
// Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp)
const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8;
// Amax reduction
const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast<int>(sizeof(float));
// Tensor memory pointer
const int smem_tmem_ptr = 4;
// SF is aligned to UTCCP 128-element granularity
const int smem_sfa_per_stage = sf_block_m * 4;
const int smem_sfb_per_stage = sf_block_n * 4;
// Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers
const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;
// Fixed total
const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr;
// Select maximum num_stages
const int num_stages = (smem_capacity - smem_fixed) / smem_per_stage;
DG_HOST_ASSERT(num_stages >= 2);
return {num_stages, smem_fixed + num_stages * smem_per_stage};
}
static MegaMoEConfig get_mega_moe_config(
const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
const int& hidden, const int& intermediate_hidden) {
// Block tiling
const int block_m = get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk);
const int block_n = 128;
const int block_k = 128;
const int load_block_m = block_m / 2;
const int load_block_n = block_n;
const int store_block_m = 32;
const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4);
const int num_max_pool_tokens = layout::get_num_max_pool_tokens(
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
const int num_padded_sf_pool_tokens = layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m);
// NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
const int swizzle_acts_mode = 128;
const int swizzle_weights_mode = 128;
// Waves
const int num_sms = device_runtime->get_num_sms();
const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe(
num_experts_per_rank, num_tokens, num_topk,
intermediate_hidden, block_m, block_n, num_sms);
// Thread layout
const int num_dispatch_threads = 128;
const int num_non_epilogue_threads = 128;
const int num_epilogue_threads = 256;
// Pipeline
const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe(
SM100ArchSpec::smem_capacity,
num_experts, hidden,
block_m, block_n, block_k, store_block_m,
sf_block_m, sf_block_n,
num_dispatch_threads / 32, num_epilogue_threads / 32);
const auto config = MegaMoEConfig {
block_m, block_n, block_k,
load_block_m, load_block_n, store_block_m,
sf_block_m, sf_block_n,
num_max_pool_tokens, num_padded_sf_pool_tokens,
swizzle_acts_mode, swizzle_weights_mode,
num_experts_per_wave,
num_stages, smem_size,
num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads
};
// Print configs for the first time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_PRINT_CONFIGS")) {
const auto key = fmt::format(
"MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
static std::unordered_set<std::string> printed;
if (printed.count(key) == 0) {
std::cout << key << ": " << config << std::endl;
printed.insert(key);
}
}
return config;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,62 @@
#pragma once
#include "../../jit/device_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/lazy_init.hpp"
namespace deep_gemm {
class HeuristicsRuntime {
static constexpr int kLegacyMKAlignmentForContiguousLayout = 128;
bool ignore_compile_dims = false;
int block_m_multiple_of = 1;
int block_n_multiple_of = 1;
int mk_alignment_for_contiguous_layout = kLegacyMKAlignmentForContiguousLayout;
public:
void set_ignore_compile_dims(const bool& new_value) {
ignore_compile_dims = new_value;
}
bool get_ignore_compile_dims() const {
return ignore_compile_dims;
}
void set_block_size_multiple_of(const int& new_block_m_multiple_of, const int& new_block_n_multiple_of) {
block_m_multiple_of = new_block_m_multiple_of;
block_n_multiple_of = new_block_n_multiple_of;
}
int get_block_m_multiple_of() const {
return block_m_multiple_of;
}
int get_block_n_multiple_of() const {
return block_n_multiple_of;
}
void set_mk_alignment_for_contiguous_layout(const int& new_value) {
mk_alignment_for_contiguous_layout = new_value;
}
int get_mk_alignment_for_contiguous_layout() const {
return mk_alignment_for_contiguous_layout;
}
static int get_theoretical_mk_alignment_for_contiguous_layout(const std::optional<int>& expected_m) {
if (device_runtime->get_arch_major() != 10)
return kLegacyMKAlignmentForContiguousLayout;
int block_m = 240, mma_step = 16;
if (expected_m.has_value()) {
// Reduce `block_m` while ensuring it covers `m`
for (; block_m > 32 and block_m - mma_step >= expected_m.value(); block_m -= mma_step);
}
return block_m;
}
};
static auto heuristics_runtime = LazyInit<HeuristicsRuntime>([](){ return std::make_shared<HeuristicsRuntime>(); });
} // namespace deep_gemm

View File

@@ -2,9 +2,11 @@
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/types.cuh>
#include "common.hpp"
#include "runtime.hpp"
#include "utils.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
@@ -12,155 +14,255 @@ namespace deep_gemm {
struct SM100ArchSpec {
static constexpr int smem_capacity = 232448;
static std::vector<int> get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) {
std::vector<int> candidates{128, 256};
if ((kernel_type == KernelType::Kernel1D1D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) {
// NOTES: `block_m = 32/64` is smaller than `LAYOUT_AD_M`, should be careful in handling this
if (m <= 32) candidates.push_back(32);
if (m <= 64) candidates.push_back(64);
}
return candidates;
}
static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) {
// 16 is for better SM usage
// Stride 32 is due to low-performance swizzle-16/32B
std::vector<int> candidates = {16};
for (int i = 32; i <= 256; i += 32)
candidates.push_back(i);
return candidates;
}
static int get_ab_load_block_m(const MulticastConfig& config, const int& block_m) {
return block_m / (config.is_multicast_on_a ? config.num_multicast : 1);
}
static int get_ab_load_block_n(const MulticastConfig& config, const int& block_n) {
return block_n / (config.is_multicast_on_a ? 1 : config.num_multicast);
}
static int get_cd_store_block_m(const int& block_m) {
constexpr int layout_ad_m = 128;
return std::min(block_m, layout_ad_m);
}
static int get_cd_store_block_n(const int& block_n) {
return block_n;
}
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
return true;
}
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
const int& block_m, const int& block_n, const MmaKind& mma_kind) {
constexpr int num_utccp_aligned_elems = 128;
switch (mma_kind) {
case MmaKind::BF16: return {0, 0};
case MmaKind::BF16: return {0, 0};
case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)};
default: DG_HOST_UNREACHABLE("Unknown dtype");
}
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
// Layout A/D does not support `block_n % 16 != 0`
if (block_n % 16 != 0)
return false;
static std::vector<Layout> get_layout_candidates(const GemmDesc& desc) {
// Block K is always in a fixed manner
const int block_k = 128 / get_element_size(desc.get_mma_kind());
// Performance is lower with 1D1D and `block_m == 256`
if (kernel_type == KernelType::Kernel1D1D and major_b == cute::UMMA::Major::K and block_m > 128)
return false;
// For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck
if (k <= 256 and (block_n > 128 or block_m > 128))
return false;
// Check tensor memory validity
int sf_block_m = 0, sf_block_n = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
sf_block_m = sf_block_m_, sf_block_n = sf_block_n_;
// Always enable swap A/B (and multicasting if possible) for m-grouped GEMMs
if (desc.gemm_type == GemmType::MGroupedContiguous or
desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout or
desc.gemm_type == GemmType::MGroupedMasked) {
const bool swap_ab = true;
const auto block_n = 128;
const auto block_m = heuristics_runtime->get_mk_alignment_for_contiguous_layout();
const auto cluster_m = 1;
const auto cluster_n = ceil_div(desc.n, block_n) % 2 == 0 and desc.num_sms % 2 == 0 ? 2 : 1;
const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n};
std::vector<Layout> candidates = {layout};
return candidates;
}
if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512)
return false;
// NOTES: when B is MN-major, we restrict `block_n` to multiples of 64,
// since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA
return major_b == cute::UMMA::Major::K or (block_n * get_element_size(mma_kind)) % 64 == 0;
// Enumerate all candidates
std::vector<Layout> candidates;
for (int swap_ab = 0; swap_ab < 2; ++ swap_ab) {
// Block M/N candidates
std::vector<int> block_m_candidates;
std::vector<int> block_n_candidates;
if (swap_ab) {
int step = std::lcm(16, heuristics_runtime->get_block_m_multiple_of());
int end = 256;
for (int i = step; i <= end; i += step)
block_m_candidates.push_back(i);
// TODO: consider other block N
block_n_candidates = {128};
} else {
// NOTES: smaller block M can avoid TMA L2 OOB bound
// TODO: consider block M = 256
if (desc.m <= 32) block_m_candidates = {32};
else if (desc.m <= 64) block_m_candidates = {64};
else block_m_candidates = {128};
// Small block size for small shape
if (16 % heuristics_runtime->get_block_n_multiple_of() == 0)
block_n_candidates.push_back(16);
int step = std::lcm(32, heuristics_runtime->get_block_n_multiple_of());
// For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck
int end = desc.k <= 256 ? 128 : 256;
for (int i = step; i <= end; i += step)
block_n_candidates.push_back(i);
}
for (int cluster_m = 1; cluster_m <= 2; ++ cluster_m) {
// After swapping, layout A/D can only do on cluster N
if (swap_ab == 1 and cluster_m > 1)
continue;
for (int cluster_n = 1; cluster_n <= 2; ++ cluster_n) {
// We only support cluster 2
if (cluster_m * cluster_n > 2)
continue;
// Only support layout A/D
if (swap_ab == 0 and cluster_n > 1)
continue;
// SM count must be divisible
if (desc.num_sms % (cluster_m * cluster_n) != 0)
continue;
for (int block_m: block_m_candidates) {
// Ensure large swizzle sizes (32B swizzle yields poor performance)
const auto swizzle_a_requirement = desc.a_dtype == kPackedFP4 ? 128 : 64;
// Enforce swizzle alignment for MN major; otherwise check base MMA shape
const auto load_block_m_requirement = desc.major_a == cute::UMMA::Major::MN ? swizzle_a_requirement : 8;
if ((block_m / cluster_n) % load_block_m_requirement != 0)
continue;
// Shape must be divisible for multicast
if (ceil_div(desc.m, block_m) % cluster_m != 0)
continue;
for (int block_n: block_n_candidates) {
// Ensure large swizzle sizes (32B swizzle yields poor performance)
const auto swizzle_b_requirement = desc.b_dtype == kPackedFP4 ? 128 : 64;
// Enforce swizzle alignment for MN major; otherwise check base MMA shape
const auto load_block_n_requirement = desc.major_b == cute::UMMA::Major::MN ? swizzle_b_requirement : 8;
if ((block_n / cluster_m) % load_block_n_requirement != 0)
continue;
// Shape must be divisible for multicast
if (ceil_div(desc.n, block_n) % cluster_n != 0)
continue;
// SwapAB requires block N is layout A/D' UMMA M
constexpr int layout_ad_m = 128;
if (swap_ab and block_n != layout_ad_m)
continue;
// Check tensor memory capacity
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, desc.get_mma_kind());
const auto tmem_sf_cols = desc.get_mma_kind() == MmaKind::MXFP8FP4 ? sf_block_m / 32 + sf_block_n / 32 : 0;
const auto umma_n = swap_ab ? block_m : block_n;
if (2 * umma_n + tmem_sf_cols > 512)
continue;
const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n};
// When neither A nor B is MN major, 128B swizzle is always feasible
if (desc.major_a == cute::UMMA::Major::K or desc.major_b == cute::UMMA::Major::K) {
const auto storage_config = get_storage_config(desc, layout);
if (storage_config.swizzle_a_mode != 128 or storage_config.swizzle_b_mode != 128)
continue;
}
candidates.push_back(layout);
}
}
}
}
}
DG_HOST_ASSERT(not candidates.empty());
return candidates;
}
static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& num_stages,
const int& block_m, const int& block_n, const int& block_k) {
return true;
}
static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) {
constexpr int layout_ad_m = 128;
constexpr int umma_step_n = 16;
// Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms)
const auto load_block_m = layout.block_m / layout.cluster_n;
const auto load_block_n = layout.block_n / layout.cluster_m;
const auto store_block_m = layout.swap_ab ? umma_step_n : std::min(layout_ad_m, layout.block_m);
const auto store_block_n = layout.block_n;
// Decide swizzling by the inner dim
// TODO: support FP4 sub-byte
const auto swizzle_mode_a = get_swizzle_mode(
desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m, c10::elementSize(desc.a_dtype));
const auto swizzle_mode_b = get_swizzle_mode(
desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n, c10::elementSize(desc.b_dtype));
const auto swizzle_mode_cd = get_swizzle_mode(
store_block_n, c10::elementSize(desc.cd_dtype));
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// TODO: support other layouts
return {
false,
is_multicast_legal(m, block_m, 2, num_sms, true) and (gemm_type == GemmType::Normal or gemm_type == GemmType::KGroupedContiguous
or (gemm_type == GemmType::Batched and num_groups <= 32)),
load_block_m, load_block_n,
store_block_m, store_block_n,
swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd
};
}
static ThreadConfig get_thread_config(const KernelType& kernel_type,
const int& block_m, const int& block_n) {
return ThreadConfig::sm100(128, 128);
}
static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) {
constexpr int kNumMaxStages = 32;
static int get_smem_cd_size(const KernelType& kernel_type,
const int& block_m, const int& block_n,
const int& swizzle_cd_mode,
const at::ScalarType& cd_dtype) {
constexpr static int layout_ad_m = 128;
return std::min(block_m, layout_ad_m) * swizzle_cd_mode * 2;
}
// C/D for TMA stores
const int smem_cd = layout.swap_ab ? storage_config.store_block_m * storage_config.store_block_n * c10::elementSize(desc.cd_dtype) * 2
: storage_config.store_block_m * storage_config.swizzle_cd_mode * 2;
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
const int& block_m, const int& block_n, const int& block_k,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
if (mma_kind == MmaKind::BF16)
return {0, 0};
int smem_sfa_per_stage = 0;
int smem_sfb_per_stage = 0;
if (kernel_type == KernelType::Kernel1D1D) {
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind);
smem_sfa_per_stage = sf_block_m * 4;
smem_sfb_per_stage = sf_block_n * 4;
} else {
smem_sfa_per_stage = block_m * 4;
smem_sfb_per_stage = 0;
}
return {smem_sfa_per_stage, smem_sfb_per_stage};
}
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
return 0;
}
static int get_barrier_smem_size(const int& num_stages) {
// TODO: remove SF barriers for BF16 GEMMs
// TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
// NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
// NOTES: the last barrier is for tensor core utilization control
return num_stages * 8 * 3 + 2 * 8 * 2 + 8;
const int smem_barriers = kNumMaxStages * 8 * 3 + 2 * 8 * 2 + 8;
// Tensor memory pointer
const int smem_tmem_ptr = 4;
// Calculate A/B per stages
// TODO: consider FP4
const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype);
const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype);
// Calculate SF A/B per stages
int smem_sfa_per_stage = 0;
int smem_sfb_per_stage = 0;
if (desc.kernel_type == KernelType::Kernel1D1D) {
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(
layout.block_m, layout.block_n, desc.get_mma_kind());
smem_sfa_per_stage = sf_block_m * 4;
smem_sfb_per_stage = sf_block_n * 4;
}
// Calculate stages
int smem_extra = smem_cd + smem_barriers + smem_tmem_ptr;
int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage;
int num_stages = std::min(
(smem_capacity - smem_extra) / smem_per_stage,
kNumMaxStages);
return {
smem_extra + num_stages * smem_per_stage,
num_stages
};
}
static int get_tmem_ptr_smem_size() {
return 4;
static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) {
return {
desc.num_sms,
layout.get_cluster_size(),
256,
32, 128, 128, 128
};
}
static int get_tensormap_smem_size(const GemmType& gemm_type) {
return 0;
static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) {
const auto num_blocks =
ceil_div(desc.get_expected_m(), layout.block_m) *
ceil_div(desc.get_expected_n(), layout.block_n) *
desc.get_expected_num_groups();
const auto num_waves = ceil_div(num_blocks, desc.num_sms);
const auto num_last_blocks = num_blocks % desc.num_sms;
const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks;
// TODO: calculate expected cycles
return {num_waves, last_wave_util, 0, layout};
}
// A regular comparator
static bool compare(const LayoutInfo& a, const LayoutInfo& b) {
// Single wave is always better
if ((a.num_waves == 1 or b.num_waves == 1) and a.num_waves != b.num_waves)
return a.num_waves < b.num_waves;
// Doing multicast is better
if (a.layout.get_cluster_size() != b.layout.get_cluster_size())
return a.layout.get_cluster_size() > b.layout.get_cluster_size();
// Smaller number of waves is better
if (a.num_waves != b.num_waves)
return a.num_waves < b.num_waves;
// Larger last wave utilization is better
if (a.last_wave_util != b.last_wave_util)
return a.last_wave_util > b.last_wave_util;
// More stages is better
// Same block M, smaller block N is better
// Same block N, smaller block M is better
if (a.layout.block_m + a.layout.block_n != b.layout.block_m + b.layout.block_n)
return a.layout.block_m + a.layout.block_n < b.layout.block_m + b.layout.block_n;
// Less shared memory C/D, more stages is better
return a.layout.block_m * a.layout.block_n < b.layout.block_m * b.layout.block_n;
}
};

View File

@@ -2,162 +2,244 @@
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/types.cuh>
#include "common.hpp"
#include "utils.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
struct SM90ArchSpec {
static constexpr int smem_capacity = 232448;
static std::vector<int> get_block_m_candidates(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const int& m) {
std::vector<int> candidates{64, 128, 256};
if ((kernel_type == KernelType::Kernel1D2D or kernel_type == KernelType::KernelNoSF) and major_a == cute::UMMA::Major::K) {
// NOTES: `block_m = 16/32` is smaller than MMA M size, should be careful in handling this
if (m <= 16) candidates.push_back(16);
if (m <= 32) candidates.push_back(32);
static std::vector<Layout> get_layout_candidates(const GemmDesc& desc) {
// Block M candidates
std::vector<int> block_m_candidates;
if (desc.gemm_type == GemmType::Normal or
desc.gemm_type == GemmType::Batched or
desc.gemm_type == GemmType::KGroupedContiguous) {
// TODO: check 256's performance
block_m_candidates = {64, 128};
// NOTES: smaller block M can avoid TMA L2 OOB bound
if (desc.m <= 16) block_m_candidates.push_back(16);
if (desc.m <= 32) block_m_candidates.push_back(32);
// BF16 output GEMM supports 256
if (desc.cd_dtype != torch::kFloat)
block_m_candidates.push_back(256);
} else if (desc.gemm_type == GemmType::MGroupedContiguous or
desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) {
block_m_candidates = std::vector{heuristics_runtime->get_mk_alignment_for_contiguous_layout()};
} else if (desc.gemm_type == GemmType::MGroupedMasked) {
block_m_candidates = {64, 128};
}
return candidates;
}
static std::vector<int> get_block_n_candidates(const KernelType& kernel_type, const at::ScalarType& cd_dtype) {
int start = 16;
// Block N candidates
std::vector<int> block_n_candidates;
int step = std::lcm(16, heuristics_runtime->get_block_n_multiple_of());
int start = step;
// Avoid bank conflicts for 1D1D kernel FP32 output
std::vector<int> candidates;
if (kernel_type == KernelType::Kernel1D1D and cd_dtype == torch::kFloat) {
candidates.push_back(16);
if (desc.kernel_type == KernelType::Kernel1D1D and desc.cd_dtype == torch::kFloat) {
DG_HOST_ASSERT(desc.major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(desc.major_b == cute::UMMA::Major::K);
start = 24;
block_n_candidates.push_back(16);
}
// Register spills
int end = 256;
if (desc.kernel_type == KernelType::Kernel1D2D)
end = 192;
if (desc.kernel_type == KernelType::Kernel1D1D)
end = 160;
// Enumerate
for (int i = start; i <= end; i += step)
block_n_candidates.push_back(i);
// Block K is always in a fixed manner
const int block_k = 128 / get_element_size(desc.get_mma_kind());
// Disable multicast for performance
const bool disable_multicast =
// The number of k-groups is large (a heuristic)
(desc.gemm_type == GemmType::KGroupedContiguous and desc.num_groups > 4) or
// Not supported
(desc.gemm_type == GemmType::Batched);
// Enumerate all candidates
std::vector<Layout> candidates;
for (int cluster_m = 1; cluster_m <= (disable_multicast ? 1 : 2); ++ cluster_m) {
for (int cluster_n = 1; cluster_n <= (disable_multicast ? 1 : 2); ++ cluster_n) {
// We only support cluster 2
if (cluster_m * cluster_n > 2)
continue;
// SM count must be divisible
if (desc.num_sms % (cluster_m * cluster_n) != 0)
continue;
for (int block_m: block_m_candidates) {
for (int block_n: block_n_candidates) {
// 1D2D kernel unroll requirement
if (desc.kernel_type == KernelType::Kernel1D2D and block_n > block_k and (block_n % (block_n - block_k) != 0 and block_k % (block_n - block_k) != 0))
continue;
// Multicast legality for masked layout
// TODO: add some comments about it
if ((desc.gemm_type == GemmType::MGroupedMasked or desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) and
ceil_div(desc.n, block_n) % (cluster_m * cluster_n) != 0)
continue;
// The block sizes cannot be too large (for enough registers), so at least one dim less than 128
if (block_m > 128 and block_n > 128)
continue;
// Calculate swizzling
const auto layout = Layout{0, block_m, block_n, block_k, cluster_m, cluster_n};
const auto storage_config = get_storage_config(desc, layout);
// Make sure swizzling is large enough (32B's performance is low)
if (storage_config.swizzle_a_mode % 64 != 0 or storage_config.swizzle_b_mode % 64 != 0)
continue;
// To hide TMA latency, the stage count should be at least 3; for small matrices, at least 4
int num_stages = get_pipeline_config(desc, layout, storage_config).num_stages;
if (num_stages < 3 or (block_m * block_n < 128 * 192 and num_stages < 4))
continue;
candidates.push_back(layout);
}
}
}
}
// Push the strided options
for (int i = start; i <= 256; i += 16)
candidates.push_back(i);
DG_HOST_ASSERT(not candidates.empty());
return candidates;
}
static int get_ab_load_block_m(const MulticastConfig& multicast_config, const int& block_m) {
return block_m;
}
static int get_ab_load_block_n(const MulticastConfig& multicast_config, const int& block_n) {
return block_n;
}
static int get_cd_store_block_m(const int& block_m, const bool& single_warpgroup_sync = false) {
static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) {
constexpr int wgmma_m = 64;
return single_warpgroup_sync ? wgmma_m : block_m;
}
static int get_cd_store_block_n(const int& block_n) {
return block_n;
}
// Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms)
// TODO: support swap AB
DG_HOST_ASSERT(layout.swap_ab == 0);
const auto load_block_m = layout.block_m;
const auto load_block_n = layout.block_n;
// 1D1D kernel will do single warp-group stores
const auto store_block_m = desc.kernel_type == KernelType::Kernel1D1D ? wgmma_m : layout.block_m;
const auto store_block_n = layout.block_n;
static bool enable_cd_swizzle(const at::ScalarType& cd_dtype) {
return cd_dtype != torch::kFloat;
}
static bool is_block_size_legal(const KernelType& kernel_type,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
// SM90 FP32 output does not support `block_m == 256`
if (cd_dtype == at::kFloat and block_m == 256)
return false;
// Avoid large C/D shared memory for FP32 output
// Ensure `num_stages >= 4` (for 1D1D Kernel), `num_stages >= 3` (for No SF kernel)
if (block_n > 128 and cd_dtype == torch::kFloat) {
if (kernel_type == KernelType::Kernel1D1D and block_n > 152)
return false;
if (kernel_type == KernelType::KernelNoSF and block_n > 200)
return false;
}
// When B is N Major, use swizzle 128B for better performance; only affects SM90 BF16 GEMM
if (major_b == cute::UMMA::Major::MN and block_n >= 128 and block_n % 64 != 0)
return false;
// Too many scaling factors in a single block: `block_n > block_k and std::gcd(block_n, block_k) != block_n - block_k`
// Or too many register spills
if (block_n > 128 and kernel_type == KernelType::Kernel1D2D and (block_n != 144 and block_n != 160 and block_n != 192))
return false;
// The block sizes cannot be too large (for enough registers), so at least one dim less than 128
return block_m <= 128 or block_n <= 128;
}
static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype,
const int& num_stages,
const int& block_m, const int& block_n, const int& block_k) {
// Unrolling both stages and `num_former_iters` will cause large code size
if (mma_kind == MmaKind::MXFP8FP4 and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4)
return num_stages <= 4;
return true;
}
static std::pair<bool, bool> get_multicast_legality(const GemmType& gemm_type, const int& num_groups,
const int& m, const int& n, const int& block_m, const int& block_n,
const int& num_sms) {
// Disable multicast when the number of k-groups is large (a heuristic)
if (gemm_type == GemmType::KGroupedContiguous and num_groups > 4)
return {false, false};
if (gemm_type == GemmType::Batched)
return {false, false};
// Decide swizzling by the inner dim
const auto swizzle_mode_a = get_swizzle_mode(
desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m, c10::elementSize(desc.a_dtype));
const auto swizzle_mode_b = get_swizzle_mode(
desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n, c10::elementSize(desc.b_dtype));
// We only enable swizzling for non-FP32 outputs
const auto swizzle_mode_cd = desc.cd_dtype != torch::kFloat ?
get_swizzle_mode(store_block_n, c10::elementSize(desc.cd_dtype)) : 0;
return {
is_multicast_legal(n, block_n, 2, num_sms, gemm_type == GemmType::MGroupedMasked),
// For masked GEMM layout, divisibility on N is also required as we must ensure the total number of blocks is even
is_multicast_legal(m, block_m, 2, num_sms, false)
and (gemm_type != GemmType::MGroupedMasked or is_multicast_legal(n, block_n, 2, num_sms, true))
load_block_m, load_block_n,
store_block_m, store_block_n,
swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd
};
}
static ThreadConfig get_thread_config(const KernelType& kernel_type,
const int& block_m, const int& block_n) {
return ThreadConfig::sm90(128, (block_m <= 64 ? 1 : 2) * 128);
}
static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) {
constexpr int kNumMaxStages = 16;
static int get_smem_cd_size(const KernelType& kernel_type,
const int& block_m, const int& block_n,
const int& swizzle_cd_mode, const at::ScalarType& cd_dtype) {
// TODO: consider swap AB
// C/D for TMA stores
// NOTES: 1024 is for TMA swizzling alignment requirement
return align(block_m * block_n * static_cast<int>(c10::elementSize(cd_dtype)), 1024);
const int smem_cd =
align(layout.block_m * layout.block_n * static_cast<int>(c10::elementSize(desc.cd_dtype)), 1024);
const int smem_barriers = kNumMaxStages * 8 * 2;
// Calculate A/B per stages
const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype);
const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype);
// Calculate SF A/B per stages
const int smem_sfa_per_stage = desc.kernel_type == KernelType::KernelNoSF ?
0 : align(layout.block_m * static_cast<int>(sizeof(float)), 128);
const int smem_sfb_per_stage = desc.kernel_type != KernelType::Kernel1D1D ?
0 : align(layout.block_n * static_cast<int>(sizeof(float)), 128);
// Extra SFB sizes for 1D2D kernels
const int use_uniform_sfb = layout.block_k % layout.block_n == 0 ? 1 : 2;
const int smem_extra_sfb = desc.kernel_type != KernelType::Kernel1D2D ?
0 : align<int>(ceil_div(desc.k, layout.block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
// Extra tensormap for 1D1D kernels
const int smem_tensormap =
desc.gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
// Calculate stages
const int smem_extra = smem_cd + smem_barriers + smem_extra_sfb + smem_tensormap;
const int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage;
const int num_stages = std::min(
(smem_capacity - smem_extra) / smem_per_stage,
kNumMaxStages);
return {
smem_extra + num_stages * smem_per_stage,
num_stages
};
}
static std::pair<int, int> get_sf_smem_size_per_stage(const KernelType& kernel_type,
const int& block_m, const int& block_n, const int& block_k,
const MmaKind& mma_kind, const at::ScalarType& cd_dtype) {
if (mma_kind == MmaKind::BF16)
return {0, 0};
// NOTES: 128 is for 2D TMA alignment requirement
int smem_sfa_per_stage = align(block_m * static_cast<int>(sizeof(float)), 128);
int smem_sfb_per_stage = 0;
if (kernel_type == KernelType::Kernel1D1D)
smem_sfb_per_stage = align(block_n * 4, 128);
return {smem_sfa_per_stage, smem_sfb_per_stage};
static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) {
const int num_tma_threads = 128;
const int num_math_threads = layout.block_m <= 64 ? 128 : 256;
return {
desc.num_sms,
layout.get_cluster_size(),
num_tma_threads + num_math_threads,
num_tma_threads, num_math_threads,
0, 0 // Meaningless for SM90
};
}
static int get_extra_sfb_smem_size(const int& m, const int& n, const int& k,
const int& block_m, const int& block_n, const int& block_k) {
const auto& use_uniform_sfb = block_k % block_n == 0 ? 1 : 2;
return align<int>(ceil_div(k, block_k) * static_cast<int>(sizeof(float)) * use_uniform_sfb, 8);
static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) {
const auto num_blocks =
ceil_div(desc.get_expected_m(), layout.block_m) *
ceil_div(desc.get_expected_n(), layout.block_n) *
desc.get_expected_num_groups();
const auto num_waves = ceil_div(num_blocks, desc.num_sms);
const auto num_last_blocks = num_blocks % desc.num_sms;
const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks;
// Utils
const int l2_bandwidth_per_cycle = std::min(64. * desc.num_sms, 8e6 / (1.3e3)); // B/cycle
const int l1_bandwidth_per_cycle = 128 * desc.num_sms; // B/cycle
const int wgmma_m = 64;
const int elem_size_ab = c10::elementSize(desc.a_dtype);
const int elem_size_cd = c10::elementSize(desc.cd_dtype);
DG_HOST_ASSERT(desc.a_dtype == desc.b_dtype);
// Data movement per block
int64_t expected_k = desc.get_expected_k();
int64_t num_bytes_l2_ab = expected_k * (layout.block_m / layout.cluster_n + layout.block_n / layout.cluster_m) * elem_size_ab;
int64_t num_bytes_l1_ab = expected_k * (layout.block_m + layout.block_n) * elem_size_ab;
int64_t num_bytes_l1_tc = expected_k * (std::max(wgmma_m, layout.block_m) + layout.block_n) * elem_size_ab
+ layout.block_m * layout.block_n * elem_size_cd;
int64_t num_bytes_l1_l2_cd = layout.block_m * layout.block_n * elem_size_cd * (desc.with_accumulation ? 2 : 1);
// HBM bandwidth and total compute (Tensor/CUDA cores) are constant across configs
// We only model L1/L2 cycles as they are the primary variables between configs
int64_t num_l2_cycles = (num_bytes_l2_ab + num_bytes_l1_l2_cd) * num_blocks / l2_bandwidth_per_cycle;
int64_t num_l1_cycles = (num_bytes_l1_ab + num_bytes_l1_tc + num_bytes_l1_l2_cd) * num_blocks / l1_bandwidth_per_cycle;
float wave_efficiency = static_cast<float>(num_blocks) / (num_waves * desc.num_sms);
int64_t num_cycles = std::max(num_l1_cycles, num_l2_cycles) / wave_efficiency;
// Disable multicasting if only one wave exists
if (layout.cluster_n * layout.cluster_m > 1 and num_waves <= 1)
num_cycles = std::numeric_limits<int64_t>::max();
return {num_waves, last_wave_util, num_cycles, layout};
}
static int get_barrier_smem_size(const int& num_stages) {
return num_stages * 8 * 2;
}
static int get_tmem_ptr_smem_size() {
return 0;
}
static int get_tensormap_smem_size(const GemmType& gemm_type) {
return gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;
// A regular comparator
static bool compare(const LayoutInfo& a, const LayoutInfo& b) {
return a.num_cycles < b.num_cycles;
}
};

View File

@@ -0,0 +1,23 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.cuh>
#include "common.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
template <typename size_type_t>
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
// `> 0` means interleaving
// 16B actually means non-swizzling (but interleaving)
for (const int& mode: {128, 64, 32, 16}) {
if ((block_size * static_cast<int>(elem_size)) % mode == 0)
return mode;
}
DG_HOST_UNREACHABLE("Unreachable");
}
} // namespace deep_gemm

View File

@@ -6,7 +6,7 @@
namespace deep_gemm {
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
return epilogue_type.value_or("EpilogueIdentity");
return epilogue_type.value_or("epilogue::transform::EpilogueIdentity");
}
} // namespace deep_gemm

View File

@@ -20,6 +20,9 @@ static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
}
static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
if (heuristics_runtime->get_ignore_compile_dims())
return 0;
for (const char& c: compiled_dims) {
if (name == c)
return dim;
@@ -58,8 +61,19 @@ static std::string to_string(const at::ScalarType& dtype) {
}
}
static std::string to_string(const float& v) {
if (std::isfinite(v)) {
return fmt::format(R"({:a}f)", v);
} else if (std::isinf(v)) {
return v > 0 ? "cute::numeric_limits<float>::infinity()"
: "-cute::numeric_limits<float>::infinity()";
}
DG_HOST_UNREACHABLE("NaN input is not supported");
}
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
const bool& allow_tf32) {
const bool& allow_tf32,
const bool& fp4_unpacked_smem) {
if (allow_tf32 and dtype == torch::kFloat)
return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32;
@@ -68,13 +82,16 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8;
case kPackedFP4: return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B;
#if CUDA_VERSION >= 12080
case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
: CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
#endif
default: DG_HOST_UNREACHABLE("Unsupported dtype");
}
}
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
#if CUDART_VERSION >= 12080
#if CUDA_VERSION >= 12080
if (base != 0) {
DG_HOST_ASSERT(base == 32 and mode == 128);
return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
@@ -97,14 +114,20 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
int smem_inner_dim, int smem_outer_dim,
const int& gmem_outer_stride,
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& elem_size = static_cast<int>(t.element_size());
const bool& allow_tf32 = false,
const bool& fp4_unpacked_smem = true) {
const auto elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
smem_inner_dim = swizzle_mode / elem_size;
// Inner dim must be a multiple of 64B for .b4x16_p64
if (t.scalar_type() == kPackedFP4)
DG_HOST_ASSERT(gmem_inner_dim % 128 == 0);
if (t.scalar_type() == kPackedFP4) {
// Inner dim must be a multiple of 64B for .b4x16_p64
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
// Fix FP4 packed smem
if (not fp4_unpacked_smem and swizzle_mode != 0)
smem_inner_dim = swizzle_mode * 2;
}
CUtensorMap tensor_map;
const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
@@ -112,12 +135,13 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride * elem_size), };
const cuuint32_t elem_strides[2] = {1, 1};
if (get_env<int>("DG_JIT_DEBUG")) {
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d\n",
printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d, pointer: %llu\n",
gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
gmem_outer_stride, swizzle_mode, swizzle_base, elem_size);
gmem_outer_stride, swizzle_mode, swizzle_base, elem_size,
reinterpret_cast<unsigned long long>(t.data_ptr()));
}
DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled(
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem),
2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
@@ -129,14 +153,20 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
int smem_dim_0, int smem_dim_1, int smem_dim_2,
const int& gmem_stride_0, const int& gmem_stride_1,
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& elem_size = static_cast<int>(t.element_size());
const bool& allow_tf32 = false,
const bool& fp4_unpacked_smem = true) {
const auto elem_size = static_cast<int>(t.element_size());
if (swizzle_mode != 0)
smem_dim_0 = swizzle_mode / elem_size;
// Inner dim must be a multiple of 64B for .b4x16_p64
if (t.scalar_type() == kPackedFP4)
DG_HOST_ASSERT(gmem_dim_0 % 128 == 0);
if (t.scalar_type() == kPackedFP4) {
// Inner dim must be a multiple of 64B for .b4x16_p64
DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_dim_0 % 128 == 0);
// Fix fp4 packed smem
if (not fp4_unpacked_smem and swizzle_mode != 0)
smem_dim_0 = swizzle_mode * 2;
}
CUtensorMap tensor_map;
const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2),};
@@ -149,7 +179,7 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size);
}
DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled(
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32),
&tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem),
3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
@@ -166,8 +196,8 @@ static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
const bool& allow_tf32 = false) {
if (num_groups > 1)
DG_HOST_ASSERT(major == cute::UMMA::Major::K);
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m);
const auto [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
const auto [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m);
return make_tma_2d_desc(t,
gmem_inner_dim, gmem_outer_dim,
smem_inner_dim, smem_outer_dim,
@@ -184,8 +214,8 @@ static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
const int& num_groups,
const int& swizzle_mode, const int& swizzle_base = 0,
const bool& allow_tf32 = false) {
const auto& [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
const auto& [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
const auto [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n);
const auto [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n);
// `num_groups` is always applied into the outer dimensions
return make_tma_2d_desc(t,

View File

@@ -16,9 +16,7 @@ namespace deep_gemm {
class SM100BF16GemmRuntime final: public LaunchRuntime<SM100BF16GemmRuntime> {
public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -45,28 +43,32 @@ static void __instantiate_kernel() {{
{}, {},
{}, {},
{},
{},
{}, {}, {},
{}
>);
}};
)",
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype),
args.gemm_config.tc_util);
to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b),
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_desc.num_groups,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads,
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms,
args.gemm_config.layout.swap_ab,
to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, to_string(args.gemm_desc.cd_dtype),
args.gemm_desc.tc_util);
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout, args.m, args.n, args.k,
args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_cd));
}
@@ -79,45 +81,49 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Normal,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = 1,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM100ArchSpec>(desc);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_gemm", code);
const auto code = SM100BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm100_bf16_gemm", code);
SM100BF16GemmRuntime::launch(runtime, args);
}
@@ -130,53 +136,61 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
const auto gemm_type = use_psum_layout ?
GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
// Only psum layout can use expected m
if (expected_m_for_psum_layout)
DG_HOST_ASSERT(use_psum_layout);
// NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
// Otherwise, treat the contiguous layout as a whole.
const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m;
const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1;
const auto desc = GemmDesc {
.gemm_type = gemm_type,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m_for_psum_layout.value_or(m),
.expected_n = n, .expected_k = k,
.expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1
};
const auto config = get_best_config<SM100ArchSpec>(desc);
const auto& config = get_best_config<SM100ArchSpec>(
gemm_type, KernelType::KernelNoSF,
// NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole
m_for_config, n, k, num_groups_for_config, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
const SM100BF16GemmRuntime::Args args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = grouped_layout.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code);
const auto code = SM100BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code);
SM100BF16GemmRuntime::launch(runtime, args);
}
@@ -188,45 +202,50 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
const int& expected_m,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::MGroupedMasked, KernelType::KernelNoSF,
expected_m, n, k, num_groups, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::MGroupedMasked,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM100ArchSpec>(desc);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), num_groups,
config.storage_config.swizzle_cd_mode);
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
const SM100BF16GemmRuntime::Args args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code);
const auto code = SM100BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code);
SM100BF16GemmRuntime::launch(runtime, args);
}
@@ -241,54 +260,59 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a,
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
int sum_k = 0;
for (const auto& k: ks) {
for (const auto k: ks) {
sum_k += k;
DG_HOST_ASSERT(k % 128 == 0);
}
const auto& num_groups = static_cast<int>(ks.size());
const auto num_groups = static_cast<int>(ks.size());
// Get config using max K for better performance
const auto& max_k = *std::max_element(ks.begin(), ks.end());
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::KGroupedContiguous, KernelType::KernelNoSF,
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto max_k = *std::max_element(ks.begin(), ks.end());
const auto desc = GemmDesc {
.gemm_type = GemmType::KGroupedContiguous,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = sum_k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM100ArchSpec>(desc);
// Create tensor descriptors
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(0)), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(0)), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM100ArchSpec::get_cd_store_block_m(config.block_m),
SM100ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(1)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(0)), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(0)), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(1)), num_groups,
config.storage_config.swizzle_cd_mode);
// Launch kernel
const SM100BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = ks_tensor.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_k_grouped_gemm", code);
const auto code = SM100BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm100_bf16_k_grouped_gemm", code);
SM100BF16GemmRuntime::launch(runtime, args);
}
@@ -297,46 +321,46 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
const torch::Tensor& tensor_d,
const int& b, const int& h, const int& r, const int& d,
const std::string& compiled_dims = "nk") {
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::KernelNoSF,
.m = b, .n = d, .k = r, .num_groups = h,
.a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
.cd_dtype = tensor_d.scalar_type(),
.major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM100ArchSpec>(desc);
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h,
config.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.smem_config.swizzle_a_mode);
const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
config.block_k, load_block_n, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.smem_config.swizzle_b_mode);
const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m);
const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n);
const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h,
store_block_n, store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.smem_config.swizzle_cd_mode);
const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h,
config.layout.block_k, config.storage_config.load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
config.layout.block_k, config.storage_config.load_block_n, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h,
config.storage_config.store_block_n, config.storage_config.store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.storage_config.swizzle_cd_mode);
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = b, .n = d, .k = r,
.num_groups = h,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code);
const auto code = SM100BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code);
SM100BF16GemmRuntime::launch(runtime, args);
}
@@ -345,46 +369,46 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
const torch::Tensor& tensor_d,
const int& b, const int& h, const int& r, const int& d,
const std::string& compiled_dims = "nk") {
const auto& config = get_best_config<SM100ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::KernelNoSF,
.m = b, .n = r, .k = d, .num_groups = h,
.a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
.cd_dtype = tensor_d.scalar_type(),
.major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM100ArchSpec>(desc);
const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h,
config.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.smem_config.swizzle_a_mode);
const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
load_block_n, config.block_k, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.smem_config.swizzle_b_mode);
const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m);
const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n);
const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h,
store_block_n, store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.smem_config.swizzle_cd_mode);
const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h,
config.layout.block_k, config.storage_config.load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
config.storage_config.load_block_n, config.layout.block_k, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h,
config.storage_config.store_block_n, config.storage_config.store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.storage_config.swizzle_cd_mode);
// Launch
const SM100BF16GemmRuntime::Args& args = {
.m = b, .n = r, .k = d,
.num_groups = h,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd
};
const auto& code = SM100BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code);
const auto code = SM100BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code);
SM100BF16GemmRuntime::launch(runtime, args);
}

View File

@@ -85,11 +85,11 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
// NOTES: we select 4 as start, as it is tested to be faster than values > 4
int num_stages = 4, smem_size = 0;
while (true) {
const int& smem_cd = block_m * swizzle_cd_mode * 2;
const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_barrier = SM100ArchSpec::get_barrier_smem_size(num_stages);
const int& smem_tmem_ptr = SM100ArchSpec::get_tmem_ptr_smem_size();
const int smem_cd = block_m * swizzle_cd_mode * 2;
const int smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
const int smem_barrier = num_stages * 8 * 3 + 2 * 8 * 2 + 8;
const int smem_tmem_ptr = 4;
smem_size = 0;
smem_size += smem_cd;
@@ -112,11 +112,11 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode);
}
const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
const auto& tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);
const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
const auto tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);
const SM100BmkBnkMnRuntime::Args& args = {
const SM100BmkBnkMnRuntime::Args args = {
.s = s, .m = m, .n = n, .k = k,
.block_m = block_m, .block_n = block_n, .block_k = block_k,
.split_factor = split_factor,
@@ -129,8 +129,8 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d
};
const auto& code = SM100BmkBnkMnRuntime::generate(args);
const auto& runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
const auto code = SM100BmkBnkMnRuntime::generate(args);
const auto runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
SM100BmkBnkMnRuntime::launch(runtime, args);
}

View File

@@ -0,0 +1,459 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8FP4Gemm1D1DRuntime> {
public:
struct Args {
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
// TODO: move into descriptor
const std::optional<std::string> epilogue_type;
// TODO: move into descriptor
int gran_k_a, gran_k_b;
void* grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_cd;
};
static std::string generate_impl(const Args& args) {
// TODO: rename files
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_fp4_gemm_1d1d_impl<
{}, {},
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{},
{}, {},
{}, {}, {},
{}
>);
}};
)",
to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b),
args.gran_k_a, args.gran_k_b,
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_desc.num_groups,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads,
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms,
args.gemm_config.layout.swap_ab,
to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation,
to_string(args.gemm_desc.a_dtype), to_string(args.gemm_desc.b_dtype), to_string(args.gemm_desc.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_cd));
}
};
static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n, const int& k,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims,
const std::optional<std::string>& epilogue_type = std::nullopt) {
const auto desc = GemmDesc {
.gemm_type = GemmType::Normal,
.kernel_type = KernelType::Kernel1D1D,
.m = m, .n = n, .k = k, .num_groups = 1,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(),
.compiled_dims = compiled_dims
};
const auto config = get_best_config<SM100ArchSpec>(desc);
const auto cd = c.value_or(d);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, gran_k_a, 1, 0);
const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.layout.block_n, gran_k_b, 1, 0);
// Launch
const SM100FP8FP4Gemm1D1DRuntime::Args args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = epilogue_type,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& grouped_layout,
const int& num_groups, const int& m, const int& n, const int& k,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
const auto gemm_type = use_psum_layout ?
GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
// Only psum layout can use expected m
if (expected_m_for_psum_layout)
DG_HOST_ASSERT(use_psum_layout);
// NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
// Otherwise, treat the contiguous layout as a whole.
const auto desc = GemmDesc {
.gemm_type = gemm_type,
.kernel_type = KernelType::Kernel1D1D,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(),
.compiled_dims = compiled_dims,
.expected_m = expected_m_for_psum_layout.value_or(m),
.expected_n = n, .expected_k = k,
.expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1
};
const auto config = get_best_config<SM100ArchSpec>(desc);
// Create tensor descriptors
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, gran_k_a, 1, 0);
const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.layout.block_n, gran_k_b, num_groups, 0);
// Launch kernel
const SM100FP8FP4Gemm1D1DRuntime::Args args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = std::nullopt,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.grouped_layout = grouped_layout.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const torch::Tensor& d,
const torch::Tensor& masked_m,
const int& num_groups, const int& m, const int& n, const int& k,
const int& expected_m,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto desc = GemmDesc {
.gemm_type = GemmType::MGroupedMasked,
.kernel_type = KernelType::Kernel1D1D,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(),
.compiled_dims = compiled_dims,
.expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM100ArchSpec>(desc);
// Create tensor descriptors
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), num_groups,
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, gran_k_a, num_groups, 0);
const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.layout.block_n, gran_k_b, num_groups, 0);
// Launch kernel
const SM100FP8FP4Gemm1D1DRuntime::Args args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = std::nullopt,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& m, const int& n,
const std::vector<int>& ks, const torch::Tensor& ks_tensor,
const int& gran_k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);
const int gran_k_a = gran_k;
const int gran_k_b = gran_k;
int sum_k = 0, sum_sf_k = 0;
for (const auto k: ks) {
sum_k += k, sum_sf_k += ceil_div(k, gran_k * 4);
DG_HOST_ASSERT(k % gran_k == 0);
}
const auto num_groups = static_cast<int>(ks.size());
// Get config using max K for better performance
const auto max_k = *std::max_element(ks.begin(), ks.end());
const auto desc = GemmDesc {
.gemm_type = GemmType::KGroupedContiguous,
.kernel_type = KernelType::Kernel1D1D,
.m = m, .n = n, .k = sum_k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(),
.compiled_dims = compiled_dims,
.expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM100ArchSpec>(desc);
// Create tensor descriptors
const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(0)), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(0)), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(1)), num_groups,
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * gran_k_a * 4,
config.layout.block_m, gran_k_a, 1, 0);
const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * gran_k_b * 4,
config.layout.block_n, gran_k_b, 1, 0);
// Launch kernel
const SM100FP8FP4Gemm1D1DRuntime::Args args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = std::nullopt,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.grouped_layout = ks_tensor.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
const torch::Tensor& b, const torch::Tensor& sfb,
const std::optional<torch::Tensor>& c,
const torch::Tensor& d,
const int& batch_size, const int& m, const int& n, const int& k,
const int& gran_k_a, const int& gran_k_b,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::Kernel1D1D,
.m = m, .n = n, .k = k, .num_groups = batch_size,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(),
.compiled_dims = compiled_dims
};
const auto config = get_best_config<SM100ArchSpec>(desc);
const int load_block_m = config.storage_config.load_block_m;
const auto [inner_dim_a, outer_dim_a] = get_inner_outer_dims(major_a, k, m);
const auto [inner_block_a, outer_block_a] = get_inner_outer_dims(major_a, config.layout.block_k, load_block_m);
const auto tensor_map_a = make_tma_3d_desc(a, inner_dim_a, outer_dim_a, batch_size,
inner_block_a, outer_block_a, 1,
a.stride(major_a == cute::UMMA::Major::K ? 1 : 2),
a.stride(0),
config.storage_config.swizzle_a_mode);
const int load_block_n = config.storage_config.load_block_n;
const auto [inner_dim_b, outer_dim_b] = get_inner_outer_dims(major_b, k, n);
const auto [inner_block_b, outer_block_b] = get_inner_outer_dims(major_b, config.layout.block_k, load_block_n);
const auto tensor_map_b = make_tma_3d_desc(b, inner_dim_b, outer_dim_b, batch_size,
inner_block_b, outer_block_b, 1,
b.stride(major_b == cute::UMMA::Major::K ? 1 : 2),
b.stride(0),
config.storage_config.swizzle_b_mode);
const int store_block_m = config.storage_config.store_block_m;
const int store_block_n = config.storage_config.store_block_n;
const auto tensor_map_cd = make_tma_3d_desc(d, n, m, batch_size,
store_block_n, store_block_m, 1,
d.stride(1), d.stride(0),
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, gran_k_a, batch_size, 0);
const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.layout.block_n, gran_k_b, batch_size, 0);
// Launch
const SM100FP8FP4Gemm1D1DRuntime::Args args = {
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = std::nullopt,
.gran_k_a = gran_k_a,
.gran_k_b = gran_k_b,
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_sfa = tensor_map_sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd
};
const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
const auto runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,210 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "runtime_utils.hpp"
#include <deep_gemm/layout/mega_moe.cuh>
#include <deep_gemm/layout/sym_buffer.cuh>
#include "../heuristics/mega_moe.hpp"
namespace deep_gemm {
class SM100FP8FP4MegaMoERuntime final : public LaunchRuntime<SM100FP8FP4MegaMoERuntime> {
public:
struct Args {
// Templated arguments
int num_max_tokens_per_rank;
int hidden, intermediate_hidden;
int num_experts, num_topk;
int num_ranks;
float activation_clamp;
bool fast_math;
MegaMoEConfig config;
// Runtime arguments
void* y;
int num_tokens;
layout::SymBuffer<> sym_buffer_ptrs;
// Tensormap
CUtensorMap tensor_map_l1_acts;
CUtensorMap tensor_map_l1_acts_sf;
CUtensorMap tensor_map_l1_weights;
CUtensorMap tensor_map_l1_weights_sf;
CUtensorMap tensor_map_l1_output;
CUtensorMap tensor_map_l2_acts;
CUtensorMap tensor_map_l2_acts_sf;
CUtensorMap tensor_map_l2_weights;
CUtensorMap tensor_map_l2_weights_sf;
// Launch configs
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_fp4_mega_moe_impl<
{},
{}, {},
{}, {},
{},
{}, {}, {},
{},
{}, {},
{},
{},
{},
{}, {}, {},
{}, {},
{},
{}
>);
}};
)", args.num_max_tokens_per_rank,
args.hidden, args.intermediate_hidden,
args.num_experts, args.num_topk,
args.config.num_experts_per_wave,
args.config.block_m, args.config.block_n, args.config.block_k,
args.config.store_block_m,
args.config.sf_block_m, args.config.sf_block_n,
args.config.num_max_pool_tokens,
args.config.num_padded_sf_pool_tokens,
args.config.num_stages,
args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads,
args.launch_args.grid_dim.first, args.num_ranks,
to_string(args.activation_clamp),
args.fast_math ? "true" : "false");
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.y,
args.num_tokens,
args.sym_buffer_ptrs,
args.tensor_map_l1_acts,
args.tensor_map_l1_acts_sf,
args.tensor_map_l1_weights,
args.tensor_map_l1_weights_sf,
args.tensor_map_l1_output,
args.tensor_map_l2_acts,
args.tensor_map_l2_acts_sf,
args.tensor_map_l2_weights,
args.tensor_map_l2_weights_sf
));
}
};
static void sm100_fp8_fp4_mega_moe(
const torch::Tensor& y,
const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf,
const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
const std::vector<int64_t>& sym_buffer_ptrs,
const int& rank_idx, const int& num_max_tokens_per_rank,
const int& num_experts_per_rank,
const int& num_tokens, const int& num_topk,
const int& hidden, const int& intermediate_hidden,
const float& activation_clamp,
const bool& fast_math
) {
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
const auto num_experts = num_experts_per_rank * num_ranks;
// Heuristics
const auto config = get_mega_moe_config(
num_ranks, num_experts, num_experts_per_rank,
num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden);
// Make tensormap
constexpr int kGranK = 32;
const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts,
hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l1_acts.stride(-2)),
config.swizzle_acts_mode);
const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf,
config.num_padded_sf_pool_tokens, hidden,
config.sf_block_m, kGranK,
1, 0);
const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights,
hidden, num_experts_per_rank * intermediate_hidden * 2,
config.block_k, config.load_block_n,
static_cast<int>(l1_weights.stride(-2)),
config.swizzle_weights_mode);
const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf,
intermediate_hidden * 2, hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// NOTES: L1 output and L2 activations are essentially the same tensor.
// Post-SwiGLU output has half the N width (`BLOCK_N / 2` per input tile),
// so the swizzle mode is also halved (128 -> 64).
const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts,
intermediate_hidden, config.num_max_pool_tokens,
config.block_n / 2, config.store_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode / 2);
const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts,
intermediate_hidden, config.num_max_pool_tokens,
config.block_k, config.load_block_m,
static_cast<int>(l2_acts.stride(-2)),
config.swizzle_acts_mode);
const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf,
config.num_padded_sf_pool_tokens, intermediate_hidden,
config.sf_block_m, kGranK,
1, 0);
const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights,
intermediate_hidden, num_experts_per_rank * hidden,
config.block_k, config.load_block_n,
static_cast<int>(l2_weights.stride(-2)),
config.swizzle_weights_mode);
const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf,
hidden, intermediate_hidden,
config.block_n, kGranK,
num_experts_per_rank, 0);
// Launch
const auto num_sms = device_runtime->get_num_sms();
const SM100FP8FP4MegaMoERuntime::Args args = {
.num_max_tokens_per_rank = num_max_tokens_per_rank,
.hidden = hidden, .intermediate_hidden = intermediate_hidden,
.num_experts = num_experts, .num_topk = num_topk,
.num_ranks = num_ranks,
.activation_clamp = activation_clamp,
.fast_math = fast_math,
.config = config,
.y = y.data_ptr(),
.num_tokens = num_tokens,
.sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
.tensor_map_l1_acts = tensor_map_l1_acts,
.tensor_map_l1_acts_sf = tensor_map_l1_acts_sf,
.tensor_map_l1_weights = tensor_map_l1_weights,
.tensor_map_l1_weights_sf = tensor_map_l1_weights_sf,
.tensor_map_l1_output = tensor_map_l1_output,
.tensor_map_l2_acts = tensor_map_l2_acts,
.tensor_map_l2_acts_sf = tensor_map_l2_acts_sf,
.tensor_map_l2_weights = tensor_map_l2_weights,
.tensor_map_l2_weights_sf = tensor_map_l2_weights_sf,
.launch_args = LaunchArgs(num_sms,
config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads,
config.smem_size, 2)
};
const auto code = SM100FP8FP4MegaMoERuntime::generate(args);
const auto runtime = compiler->build("sm100_fp8_fp4_mega_moe", code);
SM100FP8FP4MegaMoERuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -79,21 +79,21 @@ static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a,
DG_HOST_ASSERT(n <= 128 and n % 8 == 0);
DG_HOST_ASSERT(k % block_k == 0);
const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
block_m, block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, a.element_size()), 0,
true);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
block_n, block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, b.element_size()), 0,
true);
const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
block_m, block_n,
static_cast<int>(d.stride(-2)), 1,
swizzle_cd_mode)
const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
block_m, block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, a.element_size()), 0,
true);
const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
block_n, block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, b.element_size()), 0,
true);
const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
block_m, block_n,
static_cast<int>(d.stride(-2)), 1,
swizzle_cd_mode)
: make_tma_3d_desc(d, n, m, num_splits,
block_n, block_m, 1,
static_cast<int>(d.stride(-2)),
@@ -135,14 +135,14 @@ static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a,
.num_stages = num_stages,
.num_mma_threads = num_mma_threads,
.num_cast_and_reduce_threads = num_cast_and_reduce_threads,
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1),
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.sqr_sum = sqr_sum.data_ptr<float>()
};
const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args);
const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code);
const auto code = SM100BF16HCPrenormGemmRuntime::generate(args);
const auto runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code);
SM100BF16HCPrenormGemmRuntime::launch(runtime, args);
}

View File

@@ -14,9 +14,7 @@ namespace deep_gemm {
class SM90BF16GemmRuntime final: public LaunchRuntime<SM90BF16GemmRuntime> {
public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -49,24 +47,29 @@ static void __instantiate_kernel() {{
}};
)",
// TODO: add CD dtype
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation,
to_string(args.gemm_config.cd_dtype));
to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b),
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_desc.num_groups,
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_config.storage_config.swizzle_a_mode,
args.gemm_config.storage_config.swizzle_b_mode,
args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads,
// TODO: refactor with cluster M/N
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms,
to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation,
to_string(args.gemm_desc.cd_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout,
args.m, args.n, args.k,
args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_cd));
}
@@ -79,46 +82,50 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Normal,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = 1,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_gemm", code);
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_gemm", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
@@ -128,51 +135,67 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 64 == 0);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedContiguous, KernelType::KernelNoSF,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto gemm_type = use_psum_layout ?
GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
// Only psum layout can use expected m
if (expected_m_for_psum_layout)
DG_HOST_ASSERT(use_psum_layout);
const auto desc = GemmDesc {
.gemm_type = gemm_type,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m_for_psum_layout.value_or(m),
.expected_n = n, .expected_k = k,
.expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = m_indices.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
@@ -188,46 +211,51 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
DG_HOST_ASSERT(k % 64 == 0);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedMasked, KernelType::KernelNoSF,
expected_m, n, k, num_groups, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::MGroupedMasked,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m, .expected_n = 0, .expected_k = 0, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
@@ -242,54 +270,59 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
int sum_k = 0;
for (const auto& k: ks) {
for (const auto k: ks) {
sum_k += k;
DG_HOST_ASSERT(k % 128 == 0);
}
const auto& num_groups = static_cast<int>(ks.size());
const auto num_groups = static_cast<int>(ks.size());
// Get config using max K for better performance
const auto& max_k = *std::max_element(ks.begin(), ks.end());
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::KGroupedContiguous, KernelType::KernelNoSF,
m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto max_k = *std::max_element(ks.begin(), ks.end());
const auto desc = GemmDesc {
.gemm_type = GemmType::KGroupedContiguous,
.kernel_type = KernelType::KernelNoSF,
.m = m, .n = n, .k = sum_k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Create tensor descriptors
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(0)), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(0)), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(1)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(0)), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(0)), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(1)), num_groups,
config.storage_config.swizzle_cd_mode);
// Launch kernel
const SM90BF16GemmRuntime::Args& args = {
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = ks_tensor.data_ptr(),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_k_grouped_gemm", code);
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_k_grouped_gemm", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
@@ -298,45 +331,50 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
const torch::Tensor& tensor_d,
const int& b, const int& h, const int& r, const int& d,
const std::string& compiled_dims = "nk") {
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::KernelNoSF,
.m = b, .n = d, .k = r, .num_groups = h,
.a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
.cd_dtype = tensor_d.scalar_type(),
.major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
const auto& tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h,
config.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.smem_config.swizzle_a_mode);
const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
config.block_k, load_block_n, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.smem_config.swizzle_b_mode);
const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m);
const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n);
const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h,
const int load_block_m = config.storage_config.load_block_m;
const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h,
config.layout.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.storage_config.swizzle_a_mode);
const int load_block_n = config.storage_config.load_block_n;
const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
config.layout.block_k, load_block_n, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.storage_config.swizzle_b_mode);
const int store_block_m = config.storage_config.store_block_m;
const int store_block_n = config.storage_config.store_block_n;
const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h,
store_block_n, store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.smem_config.swizzle_cd_mode);
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = b, .n = d, .k = r,
.num_groups = h,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code);
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code);
SM90BF16GemmRuntime::launch(runtime, args);
}
@@ -345,45 +383,49 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
const torch::Tensor& tensor_d,
const int& b, const int& h, const int& r, const int& d,
const std::string& compiled_dims = "nk") {
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Batched, KernelType::KernelNoSF,
b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN,
tensor_a.scalar_type(), tensor_b.scalar_type(),
tensor_d.scalar_type(), false,
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::KernelNoSF,
.m = b, .n = r, .k = d, .num_groups = h,
.a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
.cd_dtype = tensor_d.scalar_type(),
.major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
const auto& tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h,
config.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.smem_config.swizzle_a_mode);
const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
const auto& tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
load_block_n, config.block_k, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.smem_config.swizzle_b_mode);
const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m);
const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n);
const auto& tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h,
const int load_block_m = config.storage_config.load_block_m;
const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h,
config.layout.block_k, load_block_m, 1,
tensor_a.stride(0), tensor_a.stride(1),
config.storage_config.swizzle_a_mode);
const int load_block_n = config.storage_config.load_block_n;
const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
load_block_n, config.layout.block_k, 1,
tensor_b.stride(1), tensor_b.stride(0),
config.storage_config.swizzle_b_mode);
const int store_block_m = config.storage_config.store_block_m;
const int store_block_n = config.storage_config.store_block_n;
const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h,
store_block_n, store_block_m, 1,
tensor_d.stride(0), tensor_d.stride(1),
config.smem_config.swizzle_cd_mode);
config.storage_config.swizzle_cd_mode);
// Launch
const SM90BF16GemmRuntime::Args& args = {
.m = b, .n = r, .k = d,
.num_groups = h,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90BF16GemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code);
const auto code = SM90BF16GemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code);
SM90BF16GemmRuntime::launch(runtime, args);
}

View File

@@ -84,9 +84,9 @@ static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
// Select best number of stages
int num_stages = 4, smem_size = 0;
while (true) {
const int& smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
const int& smem_barrier = SM90ArchSpec::get_barrier_smem_size(num_stages);
const int smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
const int smem_barrier = num_stages * 8 * 2;
smem_size = 0;
smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
@@ -108,8 +108,8 @@ static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
num_stages, smem_size, swizzle_ab_mode);
}
const auto& tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
const auto& tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
const SM90BmkBnkMnRuntime::Args& args = {
.s = s, .m = m, .n = n, .k = k,
@@ -123,8 +123,8 @@ static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
.tensor_map_b = tensor_map_b,
.d = d.data_ptr<float>()
};
const auto& code = SM90BmkBnkMnRuntime::generate(args);
const auto& runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
const auto code = SM90BmkBnkMnRuntime::generate(args);
const auto runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
SM90BmkBnkMnRuntime::launch(runtime, args);
}

View File

@@ -15,9 +15,7 @@ namespace deep_gemm {
class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime<SM90FP8Gemm1D1DRuntime> {
public:
struct Args {
int m, n, k, num_groups;
const std::string& compiled_dims;
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
@@ -52,15 +50,17 @@ static void __instantiate_kernel() {{
>);
}};
)",
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
to_string(args.gemm_config.cd_dtype));
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_desc.num_groups,
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads,
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type),
to_string(args.gemm_desc.cd_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -68,7 +68,7 @@ static void __instantiate_kernel() {{
args.gmem_a_ptr, args.gmem_b_ptr,
args.grouped_layout,
args.tensor_map_buffer,
args.m, args.n, args.k,
args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a_base, args.tensor_map_b_base,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_cd));
@@ -85,44 +85,48 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::Kernel1D1D,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Normal,
.kernel_type = KernelType::Kernel1D1D,
.m = m, .n = n, .k = k, .num_groups = 1,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k, k, 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k, k, 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.block_n, config.block_k, 1, 0);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
0);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k, k, 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k, k, 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, config.layout.block_k, 1, 0);
const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
config.layout.block_n, config.layout.block_k, 1, 0);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
0);
// Launch
const SM90FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.gmem_a_ptr = nullptr,
.gmem_b_ptr = nullptr,
.grouped_layout = nullptr,
@@ -133,8 +137,8 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
const auto code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}
@@ -151,54 +155,61 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te
DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
// Get config using max K for better performance
const auto& num_groups = static_cast<int>(ks.size());
const auto& max_k = *std::max_element(ks.begin(), ks.end());
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
m, n, max_k, num_groups, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
int first_k = 0, sum_k = 0, sum_sf_k = 0;
// TODO: refactor with the mk alignment function
const auto num_groups = static_cast<int>(ks.size());
int first_k = 0, sum_k = 0, sum_sf_k = 0, max_k = 0;
for (int i = 0; i < num_groups; ++ i) {
if (first_k == 0 and ks[i] != 0)
first_k = ks[i];
sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128);
max_k = std::max(max_k, ks[i]);
DG_HOST_ASSERT(ks[i] % 128 == 0);
}
const auto& tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k, first_k, 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k, first_k, 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
config.block_m, config.block_k, 1, 0);
const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
config.block_n, config.block_k, 1, 0);
const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m, true),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
// Get config using max K for better performance
const auto desc = GemmDesc {
.gemm_type = GemmType::KGroupedContiguous,
.kernel_type = KernelType::Kernel1D1D,
.m = m, .n = n, .k = sum_k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);
const auto tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k,
config.storage_config.load_block_m,
config.layout.block_k, first_k, 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k,
config.storage_config.load_block_n,
config.layout.block_k, first_k, 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
config.layout.block_m, config.layout.block_k, 1, 0);
const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
config.layout.block_n, config.layout.block_k, 1, 0);
const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), num_groups,
config.storage_config.swizzle_cd_mode);
// Launch
const SM90FP8Gemm1D1DRuntime::Args& args = {
.m = m, .n = n, .k = sum_k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.gmem_a_ptr = a.data_ptr(),
.gmem_b_ptr = b.data_ptr(),
.grouped_layout = ks_tensor.data_ptr(),
@@ -209,8 +220,8 @@ static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Te
.tensor_map_sfb = tensor_map_sfb,
.tensor_map_cd = tensor_map_cd,
};
const auto& code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
const auto code = SM90FP8Gemm1D1DRuntime::generate(args);
const auto runtime = compiler->build("sm90_fp8_gemm_1d1d", code);
SM90FP8Gemm1D1DRuntime::launch(runtime, args);
}

View File

@@ -17,14 +17,13 @@ namespace deep_gemm {
class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime> {
public:
struct Args {
cute::UMMA::Major major_sfb;
int m, n, k, num_groups;
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
// TODO: move this into `gemm_desc`
const std::optional<std::string>& epilogue_type;
cute::UMMA::Major major_sfb;
void *sfb, *grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
@@ -45,7 +44,7 @@ static void __instantiate_kernel() {{
{},
{}, {}, {},
{}, {}, {},
{}, {},
{},
{}, {},
{}, {},
{}, {},
@@ -55,14 +54,16 @@ static void __instantiate_kernel() {{
)",
// TODO: add CD dtype
to_string(args.major_sfb),
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.num_groups,
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages, args.gemm_config.num_last_stages,
args.gemm_config.thread_config.num_tma_threads, args.gemm_config.thread_config.num_math_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms, to_string(args.gemm_config.gemm_type),
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_desc.num_groups,
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads,
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type),
get_default_epilogue_type(args.epilogue_type));
}
@@ -70,7 +71,7 @@ static void __instantiate_kernel() {{
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.sfb, args.grouped_layout,
args.m, args.n, args.k,
args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
}
@@ -87,45 +88,49 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Normal, KernelType::Kernel1D2D,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Normal,
.kernel_type = KernelType::Kernel1D2D,
.m = m, .n = n, .k = k, .num_groups = 1,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
config.storage_config.swizzle_b_mode);
const auto tensor_map_d = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, config.layout.block_k, 1, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = k,
.num_groups = 1,
.compiled_dims = compiled_dims,
.epilogue_type = epilogue_type,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = epilogue_type,
.major_sfb = major_sfb,
.sfb = sfb.data_ptr(),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
@@ -133,8 +138,8 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
const auto code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
@@ -144,49 +149,65 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
const torch::Tensor& m_indices,
const int& num_groups, const int& m, const int& n, const int& k,
const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
const std::string& compiled_dims) {
const std::string& compiled_dims,
const bool& use_psum_layout,
const std::optional<int>& expected_m_for_psum_layout) {
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedContiguous, KernelType::Kernel1D2D,
m, n, k, 1, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto gemm_type = use_psum_layout ?
GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;
// Only psum layout can use expected m
if (expected_m_for_psum_layout)
DG_HOST_ASSERT(use_psum_layout);
const auto desc = GemmDesc {
.gemm_type = gemm_type,
.kernel_type = KernelType::Kernel1D2D,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m_for_psum_layout.value_or(m),
.expected_n = n, .expected_k = k,
.expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), 1,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, 1, 0);
DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_d = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), 1,
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, config.layout.block_k, 1, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = std::nullopt,
.major_sfb = major_sfb,
.sfb = sfb.data_ptr(),
.grouped_layout = m_indices.data_ptr(),
.tensor_map_a = tensor_map_a,
@@ -194,8 +215,8 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code);
const auto code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
@@ -210,45 +231,50 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::MGroupedMasked, KernelType::Kernel1D2D,
expected_m, n, k, num_groups, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), false,
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::MGroupedMasked,
.kernel_type = KernelType::Kernel1D2D,
.m = m, .n = n, .k = k, .num_groups = num_groups,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = false,
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
.expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
config.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.smem_config.swizzle_a_mode);
const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
config.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.smem_config.swizzle_b_mode);
const auto& tensor_map_d = make_tma_cd_desc(d, m, n,
SM90ArchSpec::get_cd_store_block_m(config.block_m),
SM90ArchSpec::get_cd_store_block_n(config.block_n),
static_cast<int>(d.stride(-2)), num_groups,
config.smem_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, num_groups, 0);
DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);
const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
config.storage_config.load_block_m,
config.layout.block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
config.storage_config.swizzle_a_mode);
const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
config.storage_config.load_block_n,
config.layout.block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
config.storage_config.swizzle_b_mode);
const auto tensor_map_d = make_tma_cd_desc(d, m, n,
config.storage_config.store_block_m,
config.storage_config.store_block_n,
static_cast<int>(d.stride(-2)), num_groups,
config.storage_config.swizzle_cd_mode);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, config.layout.block_k, num_groups, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = k,
.num_groups = num_groups,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = std::nullopt,
.major_sfb = major_sfb,
.sfb = sfb.data_ptr(),
.grouped_layout = masked_m.data_ptr(),
.tensor_map_a = tensor_map_a,
@@ -256,8 +282,8 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
const auto code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
@@ -271,51 +297,55 @@ static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
const auto& config = get_best_config<SM90ArchSpec>(
GemmType::Batched, KernelType::Kernel1D2D,
m, n, k, batch_size, major_a, major_b,
a.scalar_type(), b.scalar_type(),
d.scalar_type(), c.has_value(),
device_runtime->get_num_sms());
const auto desc = GemmDesc {
.gemm_type = GemmType::Batched,
.kernel_type = KernelType::Kernel1D2D,
.m = m, .n = n, .k = k, .num_groups = batch_size,
.a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
.cd_dtype = d.scalar_type(),
.major_a = major_a, .major_b = major_b,
.with_accumulation = c.has_value(),
.num_sms = device_runtime->get_num_sms(),
.tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
};
const auto config = get_best_config<SM90ArchSpec>(desc);
// Requires no TMA splits
DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k);
DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k);
const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m);
const auto& tensor_map_a = make_tma_3d_desc(a, k, m, batch_size,
config.block_k, load_block_m, 1,
a.stride(1),
a.stride(0),
config.smem_config.swizzle_a_mode);
DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);
const int load_block_m = config.storage_config.load_block_m;
const auto tensor_map_a = make_tma_3d_desc(a, k, m, batch_size,
config.layout.block_k, load_block_m, 1,
a.stride(1),
a.stride(0),
config.storage_config.swizzle_a_mode);
const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n);
const auto& tensor_map_b = make_tma_3d_desc(b, k, n, batch_size,
config.block_k, load_block_n, 1,
b.stride(1),
b.stride(0),
config.smem_config.swizzle_b_mode);
const int load_block_n = config.storage_config.load_block_n;
const auto tensor_map_b = make_tma_3d_desc(b, k, n, batch_size,
config.layout.block_k, load_block_n, 1,
b.stride(1),
b.stride(0),
config.storage_config.swizzle_b_mode);
const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m);
const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n);
const auto& tensor_map_d = make_tma_3d_desc(d, n, m, batch_size,
store_block_n, store_block_m, 1,
d.stride(1), d.stride(0),
config.smem_config.swizzle_cd_mode);
const int store_block_m = config.storage_config.store_block_m;
const int store_block_n = config.storage_config.store_block_n;
const auto tensor_map_d = make_tma_3d_desc(d, n, m, batch_size,
store_block_n, store_block_m, 1,
d.stride(1), d.stride(0),
config.storage_config.swizzle_cd_mode);
const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.block_m, config.block_k, batch_size, 0);
const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
config.layout.block_m, config.layout.block_k, batch_size, 0);
// Launch
const SM90FP8Gemm1D2DRuntime::Args& args = {
.major_sfb = major_sfb,
.m = m, .n = n, .k = k,
.num_groups = batch_size,
.compiled_dims = compiled_dims,
.epilogue_type = std::nullopt,
.gemm_desc = desc,
.gemm_config = config,
.launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
config.smem_config.smem_size,
config.multicast_config.num_multicast),
.launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
config.pipeline_config.smem_size,
config.layout.get_cluster_size()),
.epilogue_type = std::nullopt,
.major_sfb = major_sfb,
.sfb = sfb.data_ptr(),
.grouped_layout = nullptr,
.tensor_map_a = tensor_map_a,
@@ -323,8 +353,8 @@ static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
.tensor_map_d = tensor_map_d,
.tensor_map_sfa = tensor_map_sfa,
};
const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
const auto code = SM90FP8Gemm1D2DRuntime::generate(args);
const auto runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}

View File

@@ -81,21 +81,21 @@ static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a,
DG_HOST_ASSERT(n <= 32 and n % 8 == 0);
DG_HOST_ASSERT(k % block_k == 0);
const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
block_m, block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, a.element_size()), 0,
true);
const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
block_n, block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, b.element_size()), 0,
true);
const auto& tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
block_m, block_n,
static_cast<int>(d.stride(-2)), 1,
swizzle_cd_mode)
const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
block_m, block_k,
static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, a.element_size()), 0,
true);
const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
block_n, block_k,
static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
get_swizzle_mode(block_k, b.element_size()), 0,
true);
const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
block_m, block_n,
static_cast<int>(d.stride(-2)), 1,
swizzle_cd_mode)
: make_tma_3d_desc(d, n, m, num_splits,
block_n, block_m, 1,
static_cast<int>(d.stride(-2)),
@@ -138,14 +138,14 @@ static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a,
.num_stages = num_stages,
.num_math_threads = num_math_threads,
.num_tma_threads = num_tma_threads,
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1),
.launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size),
.tensor_map_a = tensor_map_a,
.tensor_map_b = tensor_map_b,
.tensor_map_d = tensor_map_d,
.sqr_sum = sqr_sum.data_ptr<float>()
};
const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args);
const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code);
const auto code = SM90BF16HCPrenormGemmRuntime::generate(args);
const auto runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code);
SM90BF16HCPrenormGemmRuntime::launch(runtime, args);
}

View File

@@ -17,7 +17,8 @@ public:
int* cu_seq_len_k_start;
int* cu_seq_len_k_end;
float* logits;
void* logits;
at::ScalarType logits_dtype;
int block_kv;
int num_warps;
@@ -33,10 +34,10 @@ using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&smxx_clean_logits<
{}, {}, {}
{}, {}, {}, {}
>);
}};
)", args.next_n, args.block_kv, args.num_warps);
)", args.next_n, args.block_kv, args.num_warps, to_string(args.logits_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
@@ -65,14 +66,15 @@ static void smxx_clean_logits(const torch::Tensor& logits,
.stride_logits = stride_logits,
.cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr<int>() : nullptr,
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
.logits = logits.data_ptr<float>(),
.logits = logits.data_ptr(),
.logits_dtype = logits.scalar_type(),
.block_kv = block_kv,
.num_warps = num_warps,
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
num_warps * 32, smem_size)
};
const auto& code = SMXXCleanLogitsRuntime::generate(args);
const auto& runtime = compiler->build("smxx_clean_logits", code);
const auto code = SMXXCleanLogitsRuntime::generate(args);
const auto runtime = compiler->build("smxx_clean_logits", code);
SMXXCleanLogitsRuntime::launch(runtime, args);
}

View File

@@ -46,7 +46,7 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));
#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE
const int& math_sms = device_runtime->get_num_sms();
const int math_sms = device_runtime->get_num_sms();
DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
#endif
@@ -57,10 +57,10 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
#endif
// Get cuBLASLt handle, workspace, and stream
const auto& handle = device_runtime->get_cublaslt_handle();
const auto& workspace = device_runtime->get_cublaslt_workspace();
const auto& workspace_bytes = workspace.nbytes();
const auto& stream = at::cuda::getCurrentCUDAStream();
const auto handle = device_runtime->get_cublaslt_handle();
const auto workspace = device_runtime->get_cublaslt_workspace();
const auto workspace_bytes = workspace.nbytes();
const auto stream = at::cuda::getCurrentCUDAStream();
// Algorithm selection
cublasLtMatmulPreference_t pref;
@@ -77,7 +77,7 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");
// Call: D = alpha * (A @ B) + beta * C
const float& alpha = 1.0, beta = accumulate ? 1.0 : 0.0;
const float alpha = 1.0, beta = accumulate ? 1.0 : 0.0;
DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle
desc, // Operation description
&alpha, // Alpha
@@ -99,47 +99,36 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a,
}
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
const std::optional<torch::Tensor>& acc,
const torch::Tensor& out,
const int& m, const int& n, const int& k,
const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major) {
const auto& trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto& trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T;
// Duplicate the accumulator if necessary
// TODO: remove this
if (acc.has_value()) {
if (acc->data_ptr() == out.data_ptr()) {
DG_HOST_ASSERT(acc->sizes() == out.sizes() and acc->strides() == out.strides());
} else {
out.copy_(acc.value());
}
}
const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major,
const bool& accumulate) {
const auto trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N;
const auto trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T;
// Matrix layouts
const auto& cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
const auto& cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
const auto& cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());
const auto& layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0))
: get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1));
const auto& layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0))
: get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1));
const auto& layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0));
const auto cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
const auto cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
const auto cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());
const auto layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0))
: get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1));
const auto layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0))
: get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1));
const auto layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0));
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, acc.has_value());
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, accumulate);
}
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
const int& b, const int& h, const int& r, const int& d) {
const auto& m = d, n = b, k = r;
const auto& trans_a = CUBLAS_OP_T;
const auto& trans_b = CUBLAS_OP_N;
const auto m = d, n = b, k = r;
const auto trans_a = CUBLAS_OP_T;
const auto trans_b = CUBLAS_OP_N;
// Matrix layouts
const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0));
const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0));
const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}
@@ -147,14 +136,14 @@ static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor&
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
const int& b, const int& h, const int& r, const int& d) {
const auto& m = r, n = b, k = d;
const auto& trans_a = CUBLAS_OP_N;
const auto& trans_b = CUBLAS_OP_N;
const auto m = r, n = b, k = d;
const auto trans_a = CUBLAS_OP_N;
const auto trans_b = CUBLAS_OP_N;
// Matrix layouts
const auto& layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0));
const auto& layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
const auto& layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0));
const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));
call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}

View File

@@ -0,0 +1,328 @@
#pragma once
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../heuristics/sm90.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SMXXFP8MQALogitsRuntime final: public LaunchRuntime<SMXXFP8MQALogitsRuntime> {
public:
struct Args {
int seq_len;
int seq_len_kv;
int max_seqlen_k;
int stride_logits;
int num_heads, head_dim;
bool is_compressed_logits;
int num_q_stages;
int num_kv_stages;
int block_q;
int block_kv;
int* cu_seq_len_k_start;
int* cu_seq_len_k_end;
void* logits;
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_scales;
CUtensorMap tensor_map_weights;
at::ScalarType logits_dtype;
int num_specialized_threads;
int num_math_threads;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
DG_HOST_ASSERT(128 % args.num_heads == 0);
const auto arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm{}_fp8_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_mqa_logits<
{}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)", arch, arch,
args.num_heads, args.head_dim,
args.is_compressed_logits,
args.block_q, args.block_kv,
args.num_q_stages, args.num_kv_stages,
args.launch_args.grid_dim.first,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.seq_len, args.seq_len_kv,
args.max_seqlen_k, args.stride_logits,
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
args.logits,
args.tensor_map_q, args.tensor_map_kv,
args.tensor_map_kv_scales, args.tensor_map_weights
));
}
};
static void smxx_fp8_mqa_logits(const torch::Tensor& q,
const torch::Tensor& kv, const torch::Tensor& kv_scales,
const torch::Tensor& weights,
const torch::Tensor& cu_seq_len_k_start,
const torch::Tensor& cu_seq_len_k_end,
const torch::Tensor& logits,
const at::ScalarType& logits_dtype,
const int& seq_len, const int& seq_len_kv,
const int& max_seqlen_k, const int& stride_logits,
const int& num_heads, const int& head_dim,
const int& block_q, const int& block_kv) {
constexpr int num_specialized_threads = 128;
constexpr int num_q_stages = 3, num_kv_stages = 3;
const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512);
// Use compressed logits format when max_seqlen_k is specified
const bool is_compressed_logits = (max_seqlen_k > 0);
// Construct TMAs
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
head_dim, block_q * num_heads, head_dim, head_dim);
const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
head_dim, block_kv, head_dim, head_dim);
// According to the driver API, the minimal alignment is 256 bytes
// So it is safe for us to do a 16-byte OOB
const auto tensor_map_kv_scales = make_tma_2d_desc(kv_scales,
get_tma_aligned_size(seq_len_kv, static_cast<int>(kv_scales.element_size())),
1, block_kv, 1, 0, 0);
const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
num_heads, block_q, num_heads, 0);
// Calculate shared memory size
int smem_size = 0;
const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast<int>(q.element_size());
const int smem_weight_size_per_stage = block_q * num_heads * static_cast<int>(weights.element_size());
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv.element_size());
const int kv_scale_size_per_stage = block_kv * static_cast<int>(kv_scales.element_size());
smem_size += num_q_stages * smem_q_size_per_stage;
smem_size += num_kv_stages * smem_kv_size_per_stage;
smem_size += num_q_stages * smem_weight_size_per_stage;
smem_size += num_kv_stages * kv_scale_size_per_stage;
smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8;
smem_size += 4;
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
// Launch
const SMXXFP8MQALogitsRuntime::Args args = {
.seq_len = seq_len,
.seq_len_kv = seq_len_kv,
.max_seqlen_k = max_seqlen_k,
.stride_logits = stride_logits,
.num_heads = num_heads, .head_dim = head_dim,
.is_compressed_logits = is_compressed_logits,
.num_q_stages = num_q_stages,
.num_kv_stages = num_kv_stages,
.block_q = block_q,
.block_kv = block_kv,
.cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
.logits = logits.data_ptr(),
.tensor_map_q = tensor_map_q,
.tensor_map_kv = tensor_map_kv,
.tensor_map_kv_scales = tensor_map_kv_scales,
.tensor_map_weights = tensor_map_weights,
.logits_dtype = logits_dtype,
.num_specialized_threads = num_specialized_threads,
.num_math_threads = num_math_threads,
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
num_specialized_threads + num_math_threads,
smem_size)
};
const auto code = SMXXFP8MQALogitsRuntime::generate(args);
const auto runtime = compiler->build("smxx_fp8_mqa_logits", code);
SMXXFP8MQALogitsRuntime::launch(runtime, args);
}
class SM100FP4MQALogitsRuntime final: public LaunchRuntime<SM100FP4MQALogitsRuntime> {
public:
struct Args {
int seq_len;
int seq_len_kv;
int max_seqlen_k;
int stride_logits;
int num_heads, head_dim;
bool is_compressed_logits;
int num_q_stages;
int num_kv_stages;
int block_q;
int block_kv;
int* cu_seq_len_k_start;
int* cu_seq_len_k_end;
void* logits;
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_sf_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_sf_kv;
CUtensorMap tensor_map_weights;
at::ScalarType logits_dtype;
int num_specialized_threads;
int num_math_threads;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
DG_HOST_ASSERT(128 % args.num_heads == 0);
const auto arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp4_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp4_mqa_logits<
{}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)", args.num_heads, args.head_dim,
args.is_compressed_logits,
args.block_q, args.block_kv,
args.num_q_stages, args.num_kv_stages,
args.launch_args.grid_dim.first,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.seq_len, args.seq_len_kv,
args.max_seqlen_k, args.stride_logits,
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
args.logits,
args.tensor_map_q, args.tensor_map_sf_q,
args.tensor_map_kv, args.tensor_map_sf_kv,
args.tensor_map_weights
));
}
};
static void sm100_fp4_mqa_logits(const torch::Tensor& q, const torch::Tensor& sf_q,
const torch::Tensor& kv, const torch::Tensor& sf_kv,
const torch::Tensor& weights,
const torch::Tensor& cu_seq_len_k_start,
const torch::Tensor& cu_seq_len_k_end,
const torch::Tensor& logits,
const at::ScalarType& logits_dtype,
const int& seq_len, const int& seq_len_kv,
const int& max_seqlen_k, const int& stride_logits,
const int& num_heads, const int& head_dim,
const int& block_q, const int& block_kv) {
constexpr int num_specialized_threads = 128;
const int num_math_threads = 2 * 128;
constexpr int num_q_stages = 3, num_kv_stages = 6, num_tmem_stages = 3;
// Use compressed logits format when max_seqlen_k is specified
const bool is_compressed_logits = (max_seqlen_k > 0);
// Construct TMAs
// `head_dim` must be 128 for 64B swizzling
DG_HOST_ASSERT(head_dim == 128);
const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
head_dim, block_q * num_heads,
static_cast<int>(q.stride(1)),
head_dim / 2, 0, false, false);
const auto tensor_map_sf_q = make_tma_2d_desc(sf_q, num_heads, seq_len,
num_heads, block_q,
static_cast<int>(sf_q.stride(0)), 0);
const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
num_heads, block_q,
static_cast<int>(weights.stride(0)), 0);
const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
head_dim, block_kv,
static_cast<int>(kv.stride(0)),
head_dim / 2, 0, false, false);
// According to the driver API, the minimal alignment is 256 bytes
// So it is safe for us to do a 16-byte OOB
const auto tensor_map_sf_kv = make_tma_2d_desc(sf_kv,
get_tma_aligned_size(seq_len_kv, static_cast<int>(sf_kv.element_size())), 1,
block_kv, 1, 0, 0);
// Calculate shared memory size
const int smem_q_size_per_stage = block_q * num_heads * head_dim / 2;
const int smem_sf_q_size_per_stage = align(block_q * num_heads, 128) * sizeof(int);
const int smem_kv_size_per_stage = block_kv * head_dim / 2;
const int smem_sf_kv_size_per_stage = align(block_kv, 128) * sizeof(int);
const int smem_weight_size_per_stage = block_q * num_heads * sizeof(float);
const int smem_barriers = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8;
const int smem_tmem_ptr = 4;
const int smem_size = num_q_stages * (smem_q_size_per_stage + smem_sf_q_size_per_stage + smem_weight_size_per_stage) +
num_kv_stages * (smem_kv_size_per_stage + smem_sf_kv_size_per_stage) +
smem_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
// Launch
const SM100FP4MQALogitsRuntime::Args args = {
.seq_len = seq_len,
.seq_len_kv = seq_len_kv,
.max_seqlen_k = max_seqlen_k,
.stride_logits = stride_logits,
.num_heads = num_heads, .head_dim = head_dim,
.is_compressed_logits = is_compressed_logits,
.num_q_stages = num_q_stages,
.num_kv_stages = num_kv_stages,
.block_q = block_q,
.block_kv = block_kv,
.cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
.cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
.logits = logits.data_ptr(),
.tensor_map_q = tensor_map_q,
.tensor_map_sf_q = tensor_map_sf_q,
.tensor_map_kv = tensor_map_kv,
.tensor_map_sf_kv = tensor_map_sf_kv,
.tensor_map_weights = tensor_map_weights,
.logits_dtype = logits_dtype,
.num_specialized_threads = num_specialized_threads,
.num_math_threads = num_math_threads,
.launch_args = LaunchArgs(device_runtime->get_num_sms(),
num_specialized_threads + num_math_threads,
smem_size)
};
const auto code = SM100FP4MQALogitsRuntime::generate(args);
const auto runtime = compiler->build("sm100_fp4_mqa_logits", code);
SM100FP4MQALogitsRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,445 @@
#pragma once
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQALogitsMetadataRuntime> {
public:
struct Args {
int aligned_batch_size;
int split_kv;
int num_sms;
int batch_size;
int next_n;
bool is_context_lens_2d;
int* context_lens;
int* schedule_metadata;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sched::smxx_paged_mqa_logits_metadata<
{}, {}, {}
>);
}};
)", args.aligned_batch_size, args.split_kv, args.num_sms);
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
args.next_n,
args.is_context_lens_2d,
args.context_lens,
args.schedule_metadata
));
}
};
static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
const torch::Tensor& schedule_metadata,
const int& batch_size, const int& next_n,
const int& block_kv, const int& num_sms,
const bool& is_context_lens_2d) {
constexpr int split_kv = 256;
constexpr int num_threads = 32;
const int aligned_batch_size = align(batch_size, 32);
DG_HOST_ASSERT(split_kv % block_kv == 0);
// Calculate shared memory size
const int smem_size = aligned_batch_size * static_cast<int>(sizeof(int));
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
// Launch
const SMXXPagedMQALogitsMetadataRuntime::Args& args = {
.aligned_batch_size = aligned_batch_size,
.split_kv = split_kv,
.num_sms = num_sms,
.batch_size = batch_size,
.next_n = next_n,
.is_context_lens_2d = is_context_lens_2d,
.context_lens = context_lens.data_ptr<int>(),
.schedule_metadata = schedule_metadata.data_ptr<int>(),
.launch_args = LaunchArgs(1, num_threads, smem_size)
};
const auto code = SMXXPagedMQALogitsMetadataRuntime::generate(args);
const auto runtime = compiler->build("smxx_paged_mqa_logits_metadata", code);
SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args);
}
class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime<SMXXFP8PagedMQALogitsRuntime> {
public:
struct Args {
int batch_size;
int next_n;
int num_heads;
int head_dim;
int block_kv;
bool is_context_lens_2d;
int block_table_stride;
int logits_stride;
int num_q_stages;
int num_kv_stages;
int split_kv;
int* context_lens;
void* logits;
int* block_table;
int* schedule_meta;
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_scales;
CUtensorMap tensor_map_weights;
at::ScalarType logits_dtype;
int num_specialized_threads;
int num_math_threads;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
DG_HOST_ASSERT(128 % args.num_heads == 0);
const auto arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_paged_mqa_logits<
{}, {},
{}, {},
{},
{}, {},
{},
{}, {},
{}
>);
}};
)", arch, arch,
args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.is_context_lens_2d,
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
args.logits_stride, args.block_table_stride,
args.context_lens, args.logits,
args.block_table, args.schedule_meta,
args.tensor_map_q, args.tensor_map_kv,
args.tensor_map_kv_scales, args.tensor_map_weights
));
}
};
static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
const torch::Tensor& kv_cache,
const torch::Tensor& kv_cache_scales,
const torch::Tensor& weights,
const torch::Tensor& context_lens,
const torch::Tensor& logits,
const torch::Tensor& block_table,
const torch::Tensor& schedule_meta,
const at::ScalarType& logits_dtype,
const int& batch_size, const int& next_n,
const int& num_heads, const int& head_dim,
const int& num_kv_blocks, const int& block_kv,
const bool& is_context_lens_2d,
const int& logits_stride,
const int& block_table_stride,
const int& num_sms,
const int& split_kv) {
const int num_specialized_threads = 128;
const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64);
const int num_math_warp_groups = split_kv / mma_m;
const int num_math_threads = num_math_warp_groups * 128;
const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3);
DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);
// Construct TMAs
const int next_n_atom = (next_n % 2 == 0) ? 2 : 1;
const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
head_dim, next_n_atom * num_heads,
static_cast<int>(q.stride(2)),
head_dim);
const auto tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
head_dim, block_kv, 1,
static_cast<int>(kv_cache.stride(1)),
static_cast<int>(kv_cache.stride(0)),
head_dim);
const auto tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks,
block_kv, 1,
static_cast<int>(kv_cache_scales.stride(0)), 0);
const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, batch_size * next_n,
num_heads, next_n_atom,
static_cast<int>(weights.stride(0)), 0);
// Calculate shared memory size
int smem_size = 0;
if (device_runtime->get_arch_major() == 9) {
const int swizzle_alignment = head_dim * 8;
const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment);
const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment);
// Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
const int smem_tmem_ptr = 4;
smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
DG_HOST_ASSERT(next_n == 1 or next_n == 2);
} else {
const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim * static_cast<int>(q.element_size());
const int smem_kv_size_per_stage = split_kv * head_dim * static_cast<int>(kv_cache.element_size());
const int smem_kv_scale_size_per_stage = split_kv * static_cast<int>(kv_cache_scales.element_size());
const int smem_weight_size_per_stage = next_n_atom * num_heads * static_cast<int>(weights.element_size());
const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8;
const int smem_umma_barriers = num_math_warp_groups * 2 * 8;
const int smem_tmem_ptr = 4;
smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) +
num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) +
smem_barriers + smem_umma_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
}
// Launch
const SMXXFP8PagedMQALogitsRuntime::Args args = {
.batch_size = batch_size,
.next_n = next_n,
.num_heads = num_heads,
.head_dim = head_dim,
.block_kv = block_kv,
.is_context_lens_2d = is_context_lens_2d,
.block_table_stride = block_table_stride,
.logits_stride = logits_stride,
.num_q_stages = num_q_stages,
.num_kv_stages = num_kv_stages,
.split_kv = split_kv,
.context_lens = context_lens.data_ptr<int>(),
.logits = logits.data_ptr(),
.block_table = block_table.data_ptr<int>(),
.schedule_meta = schedule_meta.data_ptr<int>(),
.tensor_map_q = tensor_map_q,
.tensor_map_kv = tensor_map_kv,
.tensor_map_kv_scales = tensor_map_kv_scales,
.tensor_map_weights = tensor_map_weights,
.logits_dtype = logits_dtype,
.num_specialized_threads = num_specialized_threads,
.num_math_threads = num_math_threads,
.launch_args = LaunchArgs(num_sms,
num_specialized_threads + num_math_threads,
smem_size)
};
const auto code = SMXXFP8PagedMQALogitsRuntime::generate(args);
const auto runtime = compiler->build("smxx_fp8_paged_mqa_logits", code);
SMXXFP8PagedMQALogitsRuntime::launch(runtime, args);
}
class SM100FP4PagedMQALogitsRuntime final: public LaunchRuntime<SM100FP4PagedMQALogitsRuntime> {
public:
struct Args {
int batch_size;
int next_n;
int num_heads;
int head_dim;
int block_kv;
bool is_context_lens_2d;
int block_table_stride;
int logits_stride;
int num_q_stages;
int num_kv_stages;
int split_kv;
int* context_lens;
void* logits;
int* block_table;
int* schedule_meta;
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_sf_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_sf_kv;
CUtensorMap tensor_map_weights;
at::ScalarType logits_dtype;
int num_specialized_threads;
int num_math_threads;
LaunchArgs launch_args;
};
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp4_paged_mqa_logits<
{}, {},
{}, {},
{},
{}, {},
{},
{}, {},
{}
>);
}};
)", args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.is_context_lens_2d,
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
args.logits_stride, args.block_table_stride,
args.context_lens, args.logits,
args.block_table, args.schedule_meta,
args.tensor_map_q, args.tensor_map_sf_q,
args.tensor_map_kv, args.tensor_map_sf_kv,
args.tensor_map_weights
));
}
};
static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
const torch::Tensor& sf_q,
const torch::Tensor& kv_cache,
const torch::Tensor& kv_cache_sf,
const torch::Tensor& weights,
const torch::Tensor& context_lens,
const torch::Tensor& logits,
const torch::Tensor& block_table,
const torch::Tensor& schedule_meta,
const at::ScalarType& logits_dtype,
const int& batch_size, const int& next_n,
const int& num_heads, const int& head_dim,
const int& num_kv_blocks, const int& block_kv,
const bool& is_context_lens_2d,
const int& logits_stride,
const int& block_table_stride,
const int& num_sms,
const int& split_kv) {
const int num_specialized_threads = 128;
const int num_math_threads = 2 * 128;
DG_HOST_ASSERT(split_kv == 256 and logits_stride % split_kv == 0);
// TODO: tuning num_stages
const int num_q_stages = 3, num_kv_stages = 6, num_tmem_stages = 3;
const int next_n_atom = (next_n % 2 == 0) ? 2 : 1;
// `head_dim` must be 128 for 64B swizzling
DG_HOST_ASSERT(head_dim == 128);
// Using 2D TMA as tensor q is asserted contiguous
const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
head_dim, next_n_atom * num_heads,
static_cast<int>(q.stride(2)),
head_dim / 2, 0, false, false);
// NOTES: `sf_q` is a 3D tensor, while `weights` is a 2D tensor
const auto tensor_map_sf_q = make_tma_2d_desc(sf_q, num_heads, batch_size * next_n,
num_heads, next_n_atom,
static_cast<int>(sf_q.stride(1)), 0);
const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, batch_size * next_n,
num_heads, next_n_atom,
static_cast<int>(weights.stride(0)), 0);
const auto tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
head_dim, block_kv, 1,
static_cast<int>(kv_cache.stride(1)),
static_cast<int>(kv_cache.stride(0)),
head_dim / 2, 0, false, false);
const auto tensor_map_sf_kv = make_tma_2d_desc(kv_cache_sf, block_kv, num_kv_blocks,
block_kv, 1,
static_cast<int>(kv_cache_sf.stride(0)), 0);
// Calculate shared memory size
const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim / 2;
const int smem_sf_q_size_per_stage = align(next_n_atom * num_heads, 128) * sizeof(int);
const int smem_kv_size_per_stage = split_kv * head_dim / 2;
const int smem_sf_kv_size_per_stage = align(split_kv, 128) * sizeof(int);
const int smem_weight_size_per_stage = next_n_atom * num_heads * sizeof(float);
const int smem_barriers = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8;
const int smem_tmem_ptr = 4;
const int smem_size = num_q_stages * (smem_q_size_per_stage + smem_sf_q_size_per_stage + smem_weight_size_per_stage) +
num_kv_stages * (smem_kv_size_per_stage + smem_sf_kv_size_per_stage) +
smem_barriers + smem_tmem_ptr;
DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
// Launch
const SM100FP4PagedMQALogitsRuntime::Args args = {
.batch_size = batch_size,
.next_n = next_n,
.num_heads = num_heads,
.head_dim = head_dim,
.block_kv = block_kv,
.is_context_lens_2d = is_context_lens_2d,
.block_table_stride = block_table_stride,
.logits_stride = logits_stride,
.num_q_stages = num_q_stages,
.num_kv_stages = num_kv_stages,
.split_kv = split_kv,
.context_lens = context_lens.data_ptr<int>(),
.logits = logits.data_ptr(),
.block_table = block_table.data_ptr<int>(),
.schedule_meta = schedule_meta.data_ptr<int>(),
.tensor_map_q = tensor_map_q,
.tensor_map_sf_q = tensor_map_sf_q,
.tensor_map_kv = tensor_map_kv,
.tensor_map_sf_kv = tensor_map_sf_kv,
.tensor_map_weights = tensor_map_weights,
.logits_dtype = logits_dtype,
.num_specialized_threads = num_specialized_threads,
.num_math_threads = num_math_threads,
.launch_args = LaunchArgs(num_sms,
num_specialized_threads + num_math_threads,
smem_size)
};
const auto code = SM100FP4PagedMQALogitsRuntime::generate(args);
const auto runtime = compiler->build("sm100_fp4_paged_mqa_logits", code);
SM100FP4PagedMQALogitsRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -72,7 +72,7 @@ static void __instantiate_kernel() {{
class PackFP32IntoUE8M0Runtime final: public LaunchRuntime<PackFP32IntoUE8M0Runtime> {
public:
struct Args {
int num_groups, mn, sf_k, packed_sf_k;
int num_groups, mn, sf_k, packed_sf_k, gran_k;
int block_mn, block_packed_sf_k;
void *sf, *out, *ks;
@@ -95,32 +95,32 @@ static void __instantiate_kernel() {{
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k));
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k, args.gran_k));
}
};
static std::tuple<int, int, int, int, int, torch::Tensor> preprocess_sf(const torch::Tensor& sf) {
// NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
const auto& dim = sf.dim();
const auto dim = sf.dim();
DG_HOST_ASSERT(dim == 2 or dim == 3);
DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat);
const auto& batched_sf = dim == 2 ? sf.unsqueeze(0) : sf;
const auto batched_sf = dim == 2 ? sf.unsqueeze(0) : sf;
const auto& [num_groups, mn, sf_k] = get_shape<3>(batched_sf);
const auto& tma_aligned_mn = get_tma_aligned_size(mn, static_cast<int>(sf.element_size()));
const auto [num_groups, mn, sf_k] = get_shape<3>(batched_sf);
const auto tma_aligned_mn = get_tma_aligned_size(mn, static_cast<int>(sf.element_size()));
return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf};
}
static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
// The last kernel already gives a column-major TMA aligned layout
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
const auto& out = torch::empty_strided({num_groups, mn, sf_k},
{tma_aligned_mn * sf_k, 1, tma_aligned_mn},
batched_sf.options());
const auto out = torch::empty_strided({num_groups, mn, sf_k},
{tma_aligned_mn * sf_k, 1, tma_aligned_mn},
batched_sf.options());
if (not batched_sf.is_contiguous()) {
// Fallback to PyTorch's slow copy if not contiguous
@@ -129,7 +129,7 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
} else {
constexpr int block_mn = 64;
constexpr int num_threads = 512;
const auto& smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast<int>(sizeof(float));
const auto smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast<int>(sizeof(float));
const TransposeFP32Runtime::Args& args = {
.mn = mn,
.sf_k = sf_k,
@@ -139,25 +139,25 @@ static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size)
};
const auto& code = TransposeFP32Runtime::generate(args);
const auto& runtime = compiler->build("transpose_fp32", code);
const auto code = TransposeFP32Runtime::generate(args);
const auto runtime = compiler->build("transpose_fp32", code);
TransposeFP32Runtime::launch(runtime, args);
}
return (dim == 2) ? out.squeeze(0) : out;
}
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) {
const auto& sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf;
const auto sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf;
// First, convert into UE8M0 `uint8_t`
const auto& ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8);
const auto ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8);
// Second, make padded packed tensors
const auto& [num_groups, mn, k] = get_shape<3>(sf_reshaped);
const auto& aligned_mn = get_tma_aligned_size(mn, 4);
const auto& aligned_k = align(k, 4);
const auto [num_groups, mn, k] = get_shape<3>(sf_reshaped);
const auto aligned_mn = get_tma_aligned_size(mn, 4);
const auto aligned_k = align(k, 4);
const auto& options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8);
const auto options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8);
auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options);
// ReSharper disable once CppExpressionWithoutSideEffects
padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor);
@@ -172,11 +172,11 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const to
}
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
const auto& [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto& packed_sf_k = ceil_div(sf_k, 4);
const auto& out = torch::empty_strided({num_groups, mn, packed_sf_k},
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto packed_sf_k = ceil_div(sf_k, 4);
const auto out = torch::empty_strided({num_groups, mn, packed_sf_k},
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
// Launch the kernel
if (batched_sf.is_contiguous()) {
if ((mn * sf_k) % 4 != 0 and num_groups > 1)
@@ -193,8 +193,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4)
};
const auto& code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
const auto code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args);
const auto runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
} else {
if (mn % 4 != 0 or num_groups > 1)
@@ -217,8 +217,8 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
const auto code = PackFP32IntoUE8M0Runtime::generate(args);
const auto runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
}
return (dim == 2) ? out.squeeze(0) : out;
@@ -226,18 +226,20 @@ static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::T
static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf,
const torch::Tensor& ks_tensor,
const std::vector<int>& ks) {
const auto& [sf_k, mn] = get_shape<2>(sf);
const auto& num_groups = static_cast<int>(ks.size());
const std::vector<int>& ks,
const int gran_k) {
DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);
const auto [sf_k, mn] = get_shape<2>(sf);
const auto num_groups = static_cast<int>(ks.size());
int ref_sf_k = 0, packed_sf_k = 0;
for (const auto& k: ks)
ref_sf_k += ceil_div(k, 128), packed_sf_k += ceil_div(k, 512);
for (const auto k: ks)
ref_sf_k += ceil_div(k, gran_k), packed_sf_k += ceil_div(k, gran_k * 4);
DG_HOST_ASSERT(sf.is_contiguous());
DG_HOST_ASSERT(ref_sf_k == sf_k);
DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0);
const auto& out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt));
const auto out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt));
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
@@ -247,6 +249,7 @@ static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(cons
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.gran_k = gran_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = sf.data_ptr(),
@@ -255,8 +258,8 @@ static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(cons
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto& code = PackFP32IntoUE8M0Runtime::generate(args);
const auto& runtime = compiler->build("pack_fp32_into_ue8m0", code);
const auto code = PackFP32IntoUE8M0Runtime::generate(args);
const auto runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
return out;
}

View File

@@ -6,6 +6,7 @@
#include "apis/hyperconnection.hpp"
#include "apis/gemm.hpp"
#include "apis/layout.hpp"
#include "apis/mega.hpp"
#include "apis/runtime.hpp"
#ifndef TORCH_EXTENSION_NAME
@@ -22,5 +23,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
deep_gemm::hyperconnection::register_apis(m);
deep_gemm::gemm::register_apis(m);
deep_gemm::layout::register_apis(m);
deep_gemm::mega::register_apis(m);
deep_gemm::runtime::register_apis(m);
}

View File

@@ -42,7 +42,7 @@ do { \
#ifndef DG_NVRTC_CHECK
#define DG_NVRTC_CHECK(cmd) \
do { \
const auto& e = (cmd); \
const auto e = (cmd); \
if (e != NVRTC_SUCCESS) { \
throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \
} \
@@ -52,7 +52,7 @@ do { \
#ifndef DG_CUDA_DRIVER_CHECK
#define DG_CUDA_DRIVER_CHECK(cmd) \
do { \
const auto& e = (cmd); \
const auto e = (cmd); \
if (e != CUDA_SUCCESS) { \
std::stringstream ss; \
const char *name, *info; \
@@ -66,7 +66,7 @@ do { \
#ifndef DG_CUDA_RUNTIME_CHECK
#define DG_CUDA_RUNTIME_CHECK(cmd) \
do { \
const auto& e = (cmd); \
const auto e = (cmd); \
if (e != cudaSuccess) { \
std::stringstream ss; \
ss << static_cast<int>(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \
@@ -97,7 +97,7 @@ inline const char* cublasGetStatusString(cublasStatus_t status) {
#define DG_CUBLASLT_CHECK(cmd) \
do { \
const auto& e = (cmd); \
const auto e = (cmd); \
if (e != CUBLAS_STATUS_SUCCESS) { \
std::ostringstream ss; \
ss << static_cast<int>(e) << " (" << cublasGetStatusString(e) << ")"; \

View File

@@ -6,7 +6,7 @@ namespace deep_gemm {
static uint64_t fnv1a(const std::vector<char>& data, const uint64_t& seed) {
uint64_t h = seed;
const uint64_t& prime = 0x100000001b3ull;
const uint64_t prime = 0x100000001b3ull;
for (const char& c: data) {
h ^= static_cast<uint8_t>(c);
h *= prime;
@@ -15,11 +15,11 @@ static uint64_t fnv1a(const std::vector<char>& data, const uint64_t& seed) {
}
static std::string get_hex_digest(const std::vector<char>& data) {
const auto& state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull);
const auto& state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull);
const auto state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull);
const auto state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull);
// Split-mix 64
const auto& split_mix = [](uint64_t z) {
const auto split_mix = [](uint64_t z) {
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull;
z = (z ^ (z >> 27)) * 0x94d049bb133111ebull;
return z ^ (z >> 31);

View File

@@ -116,9 +116,4 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf,
return sf;
}
// Value matrix layout
static int get_mk_alignment_for_contiguous_layout() {
return 128;
}
} // namespace deep_gemm

View File

@@ -1,3 +1,4 @@
// TODO: merge this file with `math.cuh` (the device part)
#pragma once
#include <torch/python.h>
@@ -6,8 +7,8 @@
namespace deep_gemm {
// TODO: Use `torch::kFloat4_e2m1fn_x2`
constexpr auto kPackedFP4 = torch::kUInt8;
// TODO: use `torch::kFloat4_e2m1fn_x2`
constexpr auto kPackedFP4 = torch::kInt8;
template <typename T>
static T ceil_div(const T& a, const T& b) {

View File

@@ -16,7 +16,7 @@ namespace deep_gemm {
// ReSharper disable once CppNotAllPathsReturnValue
template <typename dtype_t>
static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) {
const auto& c_str = std::getenv(name.c_str());
const auto c_str = std::getenv(name.c_str());
if (c_str == nullptr)
return default_value;
@@ -34,7 +34,7 @@ static dtype_t get_env(const std::string& name, const dtype_t& default_value = d
static std::tuple<int, std::string> call_external_command(std::string command) {
command = command + " 2>&1";
const auto& deleter = [](FILE* f) { if (f) pclose(f); };
const auto deleter = [](FILE* f) { if (f) pclose(f); };
std::unique_ptr<FILE, decltype(deleter)> pipe(popen(command.c_str(), "r"), deleter);
DG_HOST_ASSERT(pipe != nullptr);
@@ -42,7 +42,10 @@ static std::tuple<int, std::string> call_external_command(std::string command) {
std::string output;
while (fgets(buffer.data(), buffer.size(), pipe.get()))
output += buffer.data();
const auto& exit_code = WEXITSTATUS(pclose(pipe.release()));
const auto status = pclose(pipe.release());
// NOTES: if the child was killed by a signal (e.g., SIGINT from Ctrl+C),
// WEXITSTATUS would incorrectly return 0. Treat signal death as failure.
const auto exit_code = WIFEXITED(status) ? WEXITSTATUS(status) : 128 + WTERMSIG(status);
return {exit_code, output};
}
@@ -68,7 +71,7 @@ static std::vector<std::filesystem::path> collect_files(const std::filesystem::p
static std::filesystem::path make_dirs(const std::filesystem::path& path) {
// OK if existed
std::error_code capture;
const bool& created = std::filesystem::create_directories(path, capture);
const bool created = std::filesystem::create_directories(path, capture);
if (not (created or capture.value() == 0)) {
DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}",
path.c_str(), created, capture.value()));
@@ -94,4 +97,32 @@ static std::string get_uuid() {
return ss.str();
}
static void safe_remove_all(const std::filesystem::path& path) {
std::error_code ec;
if (not std::filesystem::exists(path, ec) or ec)
return;
// A single file
if (not std::filesystem::is_directory(path, ec) or ec) {
std::filesystem::remove(path, ec);
return;
}
// Remove directory
auto it = std::filesystem::directory_iterator(path,
std::filesystem::directory_options::skip_permission_denied, ec);
for (auto end = std::filesystem::directory_iterator(); it != end and not ec;) {
const auto entry_path = it->path();
// Increase firstly to avoid failures
it.increment(ec);
if (ec)
break;
// Recursively clean
safe_remove_all(entry_path);
}
std::filesystem::remove(path, ec);
}
} // deep_gemm

View File

@@ -19,6 +19,10 @@ from ._C import (
get_num_sms,
set_tc_util,
get_tc_util,
set_ignore_compile_dims,
set_block_size_multiple_of,
set_pdl,
get_pdl,
)
# cuBLASLt Kernels
@@ -56,14 +60,16 @@ try:
einsum,
fp8_einsum,
# Attention kernels
fp8_mqa_logits,
fp8_fp4_mqa_logits,
get_paged_mqa_logits_metadata,
fp8_fp4_paged_mqa_logits,
# Attention kernels (legacy)
fp8_mqa_logits,
fp8_paged_mqa_logits,
# Hyperconnection kernels
tf32_hc_prenorm_gemm,
# Layout kernels
transform_sf_into_required_layout,
get_mk_alignment_for_contiguous_layout
)
# Some alias for legacy supports
@@ -74,6 +80,14 @@ except ImportError:
# Expected behavior for CUDA runtime version before 12.1
pass
# Mega kernels
from .mega import (
SymmBuffer,
get_symm_buffer_for_mega_moe,
transform_weights_for_mega_moe,
fp8_fp4_mega_moe,
)
# Some utils
from . import testing
from . import utils
@@ -109,4 +123,4 @@ _C.init(
_find_cuda_home() # CUDA home
)
__version__ = '2.3.0'
__version__ = '2.4.2'

View File

@@ -0,0 +1,74 @@
#pragma once
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/layout/sym_buffer.cuh>
#include <deep_gemm/layout/mega_moe.cuh>
namespace deep_gemm::comm {
template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
const uint32_t& sm_idx, const uint32_t& thread_idx,
const sync_scope_t& sync_scope) {
// NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()`
static constexpr uint32_t kFinishSumTag = 0x80000000u;
sync_scope();
if (thread_idx == 0) {
const auto count_ptr = workspace.get_grid_sync_count_ptr<kGridSyncIndex>();
const auto old_value = ptx::atomic_add_rel(
count_ptr, sm_idx == 0 ? (kFinishSumTag - (kNumSMs - 1)) : 1);
uint32_t new_value;
do {
new_value = ptx::ld_acq(count_ptr);
} while (((new_value ^ old_value) & kFinishSumTag) == 0);
}
sync_scope();
}
template <uint32_t kNumRanks, uint32_t kNumSMs, uint32_t kNumThreads, uint32_t kGridSyncIndex, uint32_t kTag, typename sync_scope_t>
CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace,
const layout::SymBuffer<kNumRanks>& sym_buffer,
const uint32_t& sm_idx, const uint32_t& thread_idx,
const sync_scope_t& sync_scope,
const bool& sync_prologue = true,
const bool& sync_epilogue = true) {
DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads");
// Grid sync before NVLink signaling
if (sync_prologue)
grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);
// NVLink cross-rank barrier, only SM 0 participates
if (sm_idx == 0) {
auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr();
const auto status = (*counter_ptr) & 3;
const auto signal_phase = status & 1, signal_sign = status >> 1;
auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase);
// Send signals to remote ranks
if (thread_idx < kNumRanks)
ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1);
sync_scope();
// Update status and wait arrival (with 30s timeout, at 2 GHz)
constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll;
if (thread_idx == 0) {
ptx::red_add(counter_ptr, 1);
const int target = signal_sign ? 0 : static_cast<int>(kNumRanks);
const auto start_clock = clock64();
while (ptx::ld_acq_sys(signal_ptr) != target) {
if (clock64() - start_clock >= kNumTimeoutCycles) {
printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n",
sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag);
DG_DEVICE_ASSERT(false and "NVLink barrier timeout");
}
}
}
}
// Grid sync after NVLink completion
if (sync_epilogue)
grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);
}
} // namespace deep_gemm::comm

View File

@@ -0,0 +1,18 @@
#pragma once
#include <cutlass/detail/helper_macros.hpp>
#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__)
#define DG_IN_CUDA_COMPILATION
#endif
#if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__))
#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__
#define CUTLASS_DEVICE_NOINLINE __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE_NOINLINE __device__
#define CUTLASS_DEVICE_NOINLINE __device__
#else
#define CUTLASS_HOST_DEVICE_NOINLINE
#define CUTLASS_DEVICE_NOINLINE
#endif

View File

@@ -1,5 +1,7 @@
#pragma once
#include <cute/int_tuple.hpp>
namespace cute {
struct ignore_t {

View File

@@ -0,0 +1,43 @@
#pragma once
#include <cuda/std/cstdint>
#include <deep_gemm/common/compile.cuh>
#ifdef __CLION_IDE__
CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) {
asm volatile("trap;");
}
#define printf host_device_printf
#endif
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
#ifndef DG_UNIFIED_ASSERT
#ifdef DG_IN_CUDA_COMPILATION
#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond)
#else
#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond)
#endif
#endif

View File

@@ -0,0 +1,149 @@
#pragma once
#include <cuda/std/cstdint>
#include <deep_gemm/common/compile.cuh>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::math {
/// Pointer operations
template <typename dtype_t = void>
CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) {
return reinterpret_cast<dtype_t*>(static_cast<uint8_t*>(ptr) + num_bytes);
}
/// Math functions
template <typename T>
CUTLASS_HOST_DEVICE T ceil_div(T a, T b) {
return (a + b - 1) / b;
}
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) {
return (a + b - 1) / b;
}
template <typename T, bool kDoCeilAlignment = true>
CUTLASS_HOST_DEVICE T align(T a, T b) {
return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b;
}
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) {
return constexpr_ceil_div(a, b) * b;
}
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) {
return b == 0 ? a : constexpr_gcd(b, a % b);
}
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) {
return a < b ? a : b;
}
template <typename T>
CUTLASS_DEVICE void swap(T& a, T& b) {
T temp = a;
a = b;
b = temp;
}
#ifdef DG_IN_CUDA_COMPILATION
CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) {
#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
return __ffma2_rn(a, b, c);
#else
return make_float2(
__fmaf_rn(a.x, b.x, c.x),
__fmaf_rn(a.y, b.y, c.y)
);
#endif
}
CUTLASS_HOST_DEVICE float fast_rcp(const float& x) {
float ret;
asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x));
return ret;
}
/// Casting
template <typename old_t>
CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) {
auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
return *reinterpret_cast<int*>(&bf16x2);
}
CUTLASS_DEVICE float fast_pow2(const int& x) {
uint32_t bits_x = (x + 127) << 23;
return *reinterpret_cast<float*>(&bits_x);
}
CUTLASS_DEVICE int fast_log2_ceil(float x) {
const auto bits = *reinterpret_cast<uint32_t*>(&x);
const auto exp = bits >> 23;
const auto man = bits & ((1 << 23) - 1);
return exp - 127 + (man != 0);
}
template <bool kUseUE8M0 = true>
CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) {
DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0");
const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0};
const auto scaled = __fmul2_rn(amax, finfo_factor);
const auto exp_x = fast_log2_ceil(scaled.x);
const auto exp_y = fast_log2_ceil(scaled.y);
sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x);
sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y);
}
/// Reduction
CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) {
#pragma unroll
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset);
if (lane_idx >= offset)
value += synced;
}
return value;
}
// Operation functors
template <typename T> struct ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } };
template <typename T> struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } };
template <typename T> struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; } };
template <typename T> struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } };
template <typename T> struct ReduceOr { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } };
// Unified reduction function
template <uint32_t kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
CUTLASS_DEVICE T warp_reduce(T value, Op op) {
DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
"Invalid number of lanes");
constexpr uint32_t mask = 0xffffffff;
if constexpr (kIntergroupReduce) {
if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1));
if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16));
} else {
if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16));
if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8));
if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4));
if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2));
if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1));
}
return value;
}
// Convenience aliases
template <uint32_t kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
CUTLASS_DEVICE T warp_reduce_sum(T value) {
return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, ReduceSum<T>{});
}
#endif
} // namespace deep_gemm

View File

@@ -0,0 +1,92 @@
#pragma once
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/arch/copy_sm100_tma.hpp>
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::tma {
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
}
template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
uint32_t kSwizzleMode,
typename dtype_t, bool kIs3DTMA = false>
CUTLASS_DEVICE void
copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();
if constexpr (not kIs3DTMA) {
if (num_tma_multicast == 1) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
} else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
// 2-CTA function will send signals to the leader CTA only
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
if (cute::block_rank_in_cluster() == 0) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
(1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
}
#endif
}
} else {
if (num_tma_multicast == 1) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
} else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
// 2-CTA function will send signals to the leader CTA only
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
if (cute::block_rank_in_cluster() == 0) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
(1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
}
#endif
}
}
}
} // namespace deep_gemm::tma

View File

@@ -0,0 +1,43 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
namespace deep_gemm {
enum class MmaKind {
BF16 = 0,
MXFP8FP4 = 1,
};
constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) {
switch (mma_kind) {
case MmaKind::BF16: return 2;
case MmaKind::MXFP8FP4: return 1;
default: return 0;
}
}
enum class GemmType {
Normal = 0,
MGroupedContiguous = 1,
MGroupedMasked = 2,
KGroupedContiguous = 3,
Batched = 4,
MGroupedContiguousWithPsumLayout = 5,
};
constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) {
switch (gemm_type) {
case GemmType::MGroupedContiguous: return true;
case GemmType::MGroupedContiguousWithPsumLayout: return true;
default: return false;
}
}
enum class KernelType {
Kernel1D1D = 0,
Kernel1D2D = 1,
KernelNoSF = 2
};
} // namespace deep_gemm

View File

@@ -1,167 +1,24 @@
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda/std/cstdint>
#include <cuda/std/utility>
#include <cute/container/tuple.hpp>
#include "cute_tie.cuh"
#include <deep_gemm/common/exception.cuh>
#ifdef __CLION_IDE__
__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) {
asm volatile("trap;");
}
#define printf host_device_printf
#endif
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
namespace deep_gemm {
namespace deep_gemm::utils {
template <typename FuncT>
struct PatternVisitor {
FuncT func;
__device__ __host__
CUTLASS_HOST_DEVICE
explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
__device__ __host__
auto operator [](const uint32_t& i) {
CUTLASS_HOST_DEVICE
auto operator [](const uint32_t& i) const {
return func(i);
}
};
template <typename T>
__device__ __host__ T ceil_div(T a, T b) {
return (a + b - 1) / b;
}
template <typename T>
__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) {
return (a + b - 1) / b;
}
template <typename T>
__device__ __host__ T align(T a, T b) {
return ceil_div(a, b) * b;
}
template <typename T>
__device__ __host__ constexpr T constexpr_align(T a, T b) {
return constexpr_ceil_div(a, b) * b;
}
template <typename T>
__device__ __host__ constexpr T constexpr_gcd(T a, T b) {
return b == 0 ? a : constexpr_gcd(b, a % b);
}
template<typename T>
__forceinline__ __device__ void swap(T& a, T& b) {
T temp = a;
a = b;
b = temp;
}
__forceinline__ __device__ uint32_t get_sm_idx() {
uint32_t sm_idx;
asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
return sm_idx;
}
__forceinline__ __device__ uint32_t get_lane_idx() {
uint32_t lane_id;
asm ("mov.u32 %0, %laneid;" : "=r"(lane_id));
return lane_id;
}
__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) {
uint32_t ret;
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
__device__ __forceinline__ float2 ld_shared(const float2* ptr) {
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
__device__ __forceinline__ float4 ld_shared(const float4* ptr) {
float4 ret;
asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) {
uint4 ret;
asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
__device__ __forceinline__ float ld_shared(const float* ptr) {
float ret;
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
__device__ __forceinline__ void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val));
}
__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) {
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y));
}
__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val));
}
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) {
asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y));
}
__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
}
__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) {
asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}
template <typename old_t>
__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) {
auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)});
return *reinterpret_cast<int*>(&bf16x2);
}
__device__ __forceinline__ void prefetch_l1(void *ptr) {
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
}
template <uint32_t kNumBytes>
struct Vectorized {
static auto zeros() {
@@ -180,4 +37,14 @@ struct Vectorized {
using vec_t = decltype(zeros());
};
} // namespace `deep_gemm`
template <uint32_t kNumCols>
CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() {
DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
if constexpr (kNumCols <= 32) return 32;
if constexpr (kNumCols <= 64) return 64;
if constexpr (kNumCols <= 128) return 128;
if constexpr (kNumCols <= 256) return 256;
return 512;
}
} // namespace deep_gemm::utils

View File

@@ -0,0 +1,137 @@
#pragma once
#include <cute/atom/copy_traits_sm100.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
namespace deep_gemm::epilogue {
template <uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
uint32_t kSwizzleCDMode,
uint32_t kNumTMAStoreStages,
uint32_t kNumUMMAStoreThreads,
GemmType kGemmType, bool kWithAccumulation,
typename cd_dtype_t,
typename epilogue_type_t,
typename pattern_cd_t>
CUTLASS_DEVICE void
sm100_store_cd(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
const uint32_t& tmem_base_addr,
const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
const cute::TmaDescriptor& tensor_map_cd) {
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Share store pipeline between blocks
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Iterate over M waves
constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M;
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]);
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// The pipeline stage
const auto m_idx = base_m_idx + w * STORE_BLOCK_M;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(base_n_idx + s * STORE_BLOCK_N);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = tmem_base_addr + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = smem_base_ptr + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
ptx::st_shared(
smem_ptr,
math::cast_into_bf16_and_pack(values[0], values[1]),
math::cast_into_bf16_and_pack(values[2], values[3]),
math::cast_into_bf16_and_pack(values[4], values[5]),
math::cast_into_bf16_and_pack(values[6], values[7])
);
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
ptx::tcgen05_before_thread_sync();
tmem_empty_barrier->arrive(0u);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kGemmType == GemmType::Batched) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx);
} else {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx);
}
cute::tma_store_arrive();
}
__syncwarp();
}
}
}
} // namespace deep_gemm::epilogue

View File

@@ -0,0 +1,144 @@
#pragma once
#include <cute/atom/copy_traits_sm100.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
namespace deep_gemm::epilogue {
template <uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
uint32_t kSwizzleCDMode,
uint32_t kNumTMAStoreStages,
uint32_t kNumUMMAStoreThreads,
GemmType kGemmType, bool kWithAccumulation,
typename cd_dtype_t,
typename epilogue_type_t,
typename pattern_cd_t>
CUTLASS_DEVICE void
sm100_store_cd_swap_ab(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
const uint32_t& tmem_base_addr,
const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
const uint32_t& effective_m,
const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
const cute::TmaDescriptor& tensor_map_cd) {
// NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows,
// implying STORE_BLOCK_N must be 128.
DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows");
// TMA checks
constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t);
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumSwizzleAtomRows = 8;
DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled");
DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling");
DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling");
// Share store pipeline between blocks
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Iterate over M blocks
const auto num_stores = effective_m / STORE_BLOCK_M;
for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) {
uint32_t tmem_addr = tmem_base_addr +
s * STORE_BLOCK_M + // Store stage offset
i * kNumSwizzleAtomRows; // In-block offset
uint32_t values[kNumSwizzleAtomRows];
// Warps cooperatively write an atomic block to shared memory
DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes");
constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32;
uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode;
uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode;
auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset;
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// NOTES: Swizzling is not required in this case, but used here for consistency with other cases
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
uint32_t col = lane_idx / 4;
#pragma unroll
for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) {
auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
+ (col ^ row) * kNumBankGroupBytes
+ (lane_idx % 4) * sizeof(float);
ptx::st_shared(reinterpret_cast<uint32_t*>(smem_ptr), values[row]);
}
} else {
// Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements
// Start from lane index 0
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
// Start from lane index 16
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
// Destination shared memory address
uint32_t row = lane_idx % 8;
uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8;
auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
+ (col ^ row) * kNumBankGroupBytes;
// Store matrix with transposition
ptx::SM90_U32x4_STSM_T<int>::copy(math::cast_into_bf16_and_pack(values[0], values[1]),
math::cast_into_bf16_and_pack(values[2], values[3]),
math::cast_into_bf16_and_pack(values[4], values[5]),
math::cast_into_bf16_and_pack(values[6], values[7]),
smem_ptr);
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (s == num_stores - 1) {
ptx::tcgen05_before_thread_sync();
tmem_empty_barrier->arrive(0u);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) {
auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM;
uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M;
uint32_t n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N_ATOM>(base_n_idx + i * STORE_BLOCK_N_ATOM);
// Issue 2D or 3D TMA store
if constexpr (kGemmType == GemmType::Batched) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx);
} else {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx);
}
}
cute::tma_store_arrive();
}
__syncwarp();
}
}
} // namespace deep_gemm::epilogue

View File

@@ -0,0 +1,24 @@
#pragma once
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::epilogue::transform {
struct EpilogueIdentity {
template <uint32_t STORE_BLOCK_N>
CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
return n_idx;
}
};
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
struct EpilogueHeadSplits: EpilogueIdentity {
template <uint32_t STORE_BLOCK_N>
CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 and
kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
}
};
} // namespace deep_gemm::epilogue::transform

View File

@@ -4,14 +4,18 @@
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/epilogue/sm100_store_cd.cuh>
#include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
@@ -21,9 +25,10 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
bool kSwapAB,
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
uint64_t kTensorCoreUtilControl>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_bf16_gemm_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -48,41 +53,31 @@ sm100_bf16_gemm_impl(int* grouped_layout,
if constexpr (kWithAccumulation)
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
// Configs
// MMA Configs
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
constexpr uint32_t kNumTMAStoreStages = 2;
DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode");
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// 2-CTA MMA
constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
constexpr uint32_t UMMA_K = 16;
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
(not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
// Epilogue configs
// Always enable pipeline for better performance
constexpr uint32_t kNumEpilogueStages = 2;
constexpr uint32_t kNumTMAStoreStages = 2;
// NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
// per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
@@ -91,50 +86,60 @@ sm100_bf16_gemm_impl(int* grouped_layout,
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
// NOTES: Make sure we have enough shared memory for UMMA padding
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 1 : 2;
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA");
// Real tensor memory size and offsets
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols>();
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols>();
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Synchronize the cluster before 2-CTA TMEM allocation
kNumMulticast > 1 ? cute::cluster_sync() : void();
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
if (warp_idx == 0) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_cd);
}
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// D/A/B shared memory
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
if (kNumMulticast > 1)
cute::cluster_sync();
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
@@ -162,9 +167,13 @@ sm100_bf16_gemm_impl(int* grouped_layout,
}
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
@@ -181,16 +190,20 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
// Use dynamic load block M, when swap-AB is enabled
const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
// For k-grouped layout, the number of block K is variable
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
@@ -198,14 +211,14 @@ sm100_bf16_gemm_impl(int* grouped_layout,
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
@@ -213,16 +226,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
// Arrive at full barriers
@@ -238,17 +251,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
// TODO: refactor `UMMA_M` calculation
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, kMajorA, kMajorB>();
auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
UMMA_M, UMMA_N, kMajorB, kMajorA>()
: cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
// Merged stages only happens in NT normal GEMM cases
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
@@ -265,7 +277,7 @@ sm100_bf16_gemm_impl(int* grouped_layout,
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// UMMA and empty barrier arrival alias
auto umma_arrive = [](const uint64_t* barrier) {
@@ -282,36 +294,45 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
__syncwarp();
};
// Dynamic update of UMMA N based on effective M, when swap-AB is enabled
if constexpr (kSwapAB) {
uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
}
// Launch MMAs
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_F16BF16_SS, SM100_MMA_F16BF16_2x1SM_SS>;
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
using mma_t = cute::conditional_t<kNumMulticast == 1, ptx::SM100_MMA_F16BF16_SS, ptx::SM100_MMA_F16BF16_2x1SM_SS>;
const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
runtime_instr_desc);
a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(
a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(
b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
if (kSwapAB) {
mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc);
} else {
mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc);
}
}
}
__syncwarp();
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
@@ -322,15 +343,16 @@ sm100_bf16_gemm_impl(int* grouped_layout,
if constexpr (kTensorCoreUtilControl < 100) {
// For utilization control
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
__syncwarp();
// Wait for last UMMA to be done
tensor_core_full_barrier->wait(tensor_core_phase);
tensor_core_phase ^= 1;
// Sleep for certain cycles
constexpr static uint64_t kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull;
constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull;
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
const auto& start_clock = clock64();
const auto start_clock = clock64();
if (cute::elect_one_sync())
while (clock64() - start_clock < kNumDummyCycles) {}
__syncwarp();
@@ -339,9 +361,9 @@ sm100_bf16_gemm_impl(int* grouped_layout,
}
// To safely deconstruct barriers, we need another round of waits
const auto& iter_idx = scheduler.current_iter - 1;
const auto iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
@@ -351,19 +373,10 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
@@ -372,106 +385,44 @@ sm100_bf16_gemm_impl(int* grouped_layout,
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
const auto base_m_idx = scheduler.template get_global_idx<
(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto base_n_idx = n_block_idx * BLOCK_N;
// Iterate over M waves
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// The pipeline stage
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
__syncwarp();
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kGemmType == GemmType::Batched) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
n_idx, m_idx, scheduler.current_group_idx);
} else {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
}
cute::tma_store_arrive();
}
}
if constexpr (kSwapAB) {
const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
epilogue::sm100_store_cd_swap_ab<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue::transform::EpilogueIdentity>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
effective_m,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
} else {
epilogue::sm100_store_cd<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue::transform::EpilogueIdentity>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
}
}
}
// Deallocate tensor memory
// TODO: Remove redundant synchronization
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Deallocate tensor memory
if (warp_idx == 0)
Allocator().free(0, kNumTmemCols);

View File

@@ -5,18 +5,19 @@
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumThreads>
__global__ void __launch_bounds__(kNumThreads, 1)
CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1)
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -30,7 +31,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
const auto lane_idx = ptx::get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
@@ -51,24 +52,24 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
}
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_N>();
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_N>();
// Fill D/A/B
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
// Fill the tensor memory pointer
@@ -93,14 +94,17 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
__syncthreads();
// Block indices
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (warp_idx == 0) {
// TMA load warp
for (uint32_t s = 0; s < num_total_stages; ++ s) {
@@ -115,8 +119,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
// Issue TMAs
if (cute::elect_one_sync()) {
tma_copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
tma::copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
}
// Arrive at full barriers
@@ -134,8 +138,8 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
auto b_desc = make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
@@ -147,14 +151,14 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
"Invalid MMA instruction shape");
// Wait tensor memory empty barrier arrival
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Issue UMMA in the leader CTA
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
@@ -163,9 +167,11 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
a_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(b_desc_base_lo, 0, k * UMMA_K);
SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
a_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(
a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(
b_desc_base_lo, 0, k * UMMA_K);
ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
}
}
@@ -180,7 +186,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
// i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
if (warp_idx == 2)
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
@@ -191,7 +197,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
// Wait UMMA arrival
tmem_full_barrier->wait(0);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
@@ -239,7 +245,7 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
}
// Synchronize all threads and issue TMA
@@ -251,7 +257,6 @@ sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
}
}
__syncthreads();
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is doing TMA stores
if (warp_idx == 1)

View File

@@ -0,0 +1,456 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
template <uint32_t kNumHeads, uint32_t kHeadDim,
bool kIsCompressedLogits,
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t kNumSMs,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
typename logits_dtype_t,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const uint32_t max_seqlen_k,
const uint32_t logits_stride,
const uint32_t* cu_seq_len_k_start,
const uint32_t* cu_seq_len_k_end,
logits_dtype_t* logits,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Utils
const auto sm_idx = blockIdx.x;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
// Prefetch TMA descriptors
if (warp_idx == kSpecWarpStart) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_sf_q);
cute::prefetch_tma_descriptor(&tensor_map_weights);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
}
// UMMA configs
static constexpr uint32_t kNumTmemStages = 3;
static constexpr uint32_t kNumUTCCPAlignedElems = 128;
static constexpr uint32_t UMMA_M = 128;
static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
static constexpr uint32_t UMMA_K = 64;
static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems);
static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems);
static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads;
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`");
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2);
static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2);
static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q and KV data on shared memory
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
});
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
});
const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
});
auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
});
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
+ SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
// Barriers and TMEM pointer on shared memory
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);
// Tensor memory configs
constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQ / 32 + kNumSFKV / 32>();
constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
// Initialize barriers
if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads + 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(1);
}
#pragma unroll
for (uint32_t i = 0; i < kNumTmemStages; ++i) {
full_tmem_barriers[i]->init(1);
empty_tmem_barriers[i]->init(128);
}
cutlass::arch::fence_barrier_init();
}
// Allocate tensor memory
if (warp_idx == kSpecWarpStart + 2)
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
__syncthreads();
// Scheduler
const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple<uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv);
seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv);
start = cute::min(start, seq_k_start[i]);
end = cute::max(end, seq_k_end[i]);
}
// TMA alignment requirements for SF KV
start = start / 4 * 4;
return {start, math::ceil_div(end - start, BLOCK_KV)};
};
// Make Q, KV and TMEM pipeline
auto make_pipeline = [](const uint32_t& num_stages) {
// Return current stage and phase, and advance pipeline by steps
return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
uint32_t current_idx = iter_idx;
iter_idx += step;
return {current_idx % num_stages, (current_idx / num_stages) & 1};
};
};
auto advance_q_pipeline = make_pipeline(kNumQStages);
auto advance_kv_pipeline = make_pipeline(kNumKVStages);
auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 56;
constexpr uint32_t kNumMathRegisters = 224;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (warp_idx == kSpecWarpStart) {
// TMA warp for loading Q
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
// Enumerate Q blocks
if (cute::elect_one_sync()) {
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
// Wait Q consumer release
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
// Issue TMA Q
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads);
tma::copy<BLOCK_Q * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q);
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q);
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
}
}
__syncwarp();
} else if (warp_idx == kSpecWarpStart + 1) {
// TMA warp for loading KV cache
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
if (cute::elect_one_sync()) {
// Enumerate Q blocks
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
// Load KV block ranges
CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
// Enumerate KV blocks
for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
// Wait KV consumer release
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
smem_sf_kv[kv_stage_idx],
kv_start + kv_idx * BLOCK_KV, 0);
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
}
}
}
} else if (warp_idx == kSpecWarpStart + 2) {
// UMMA warp
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// UTCCP transposer
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
// Make UMMA desc
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
// Enumerate Q blocks
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
// Load KV block ranges
CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
// Wait TMA Q arrivals
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
full_q_barriers[q_stage_idx]->wait(q_phase);
// Transpose and copy SF Q
#pragma unroll
for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
if (cute::elect_one_sync())
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
__syncwarp();
}
// Enumerate KV blocks
for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
// Wait TMA KV arrivals
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Transpose
#pragma unroll
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
}
// UMMA with SF
if (cute::elect_one_sync()) {
// Copy SF KV
#pragma unroll
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
// Wait TMEM release
CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
ptx::tcgen05_after_thread_sync();
// Issue UMMA with SF
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
// TODO: generalize umma desc
DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
auto a_desc = mma::sm100::make_smem_desc(
cute::UMMA::LayoutType::SWIZZLE_64B,
smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
8 * (kHeadDim / 2), 0);
auto b_desc = mma::sm100::make_smem_desc(
cute::UMMA::LayoutType::SWIZZLE_64B,
smem_q[q_stage_idx] + k * UMMA_K / 2,
8 * (kHeadDim / 2), 0);
ptx::SM100_MMA_MXF4_SS::fma(
a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
}
// TODO: move this into `deep_gemm/ptx/tcgen05.cuh`
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
}
}
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
}
// UMMA warp must also arrive on empty_q to prevent running ahead
// of math warps in the Q pipeline. Without this, UMMA can consume
// kNumQStages Q blocks before math warps release any, causing a
// circular dependency: UMMA waits full_q -> TMA_Q waits empty_q
// -> Math waits full_tmem -> UMMA (already moved on).
empty_q_barriers[q_stage_idx]->arrive();
}
} else if (warp_idx == kSpecWarpStart + 3) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
const auto math_warpgroup_idx = warpgroup_idx;
const auto math_thread_idx = threadIdx.x;
// Helper lambda for loading tensor memory
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
constexpr uint32_t N = decltype(num_elems_c)::value;
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
// Math warpgroups process TMEM stages alternately
// Advance pipeline to align with the assigned stage
advance_tmem_pipeline(math_warpgroup_idx);
// Local register buffers
float accum[kNumHeads];
float weights[BLOCK_Q][kNumHeads];
// Enumerate Q blocks
for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
// Load KV block ranges
CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);
// Wait TMA Q arrivals
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
full_q_barriers[q_stage_idx]->wait(q_phase);
// Read weights
// TODO: optimize bank conflicts
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
// Enumerate KV blocks
for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
// Calculate KV offset in advance
auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx;
// Advance pipeline by `kNumMathWarpGroups` steps
// Wait UMMA arrival
CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
ptx::tcgen05_after_thread_sync();
// Reduce over the head dim and store
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
// Release TMEM empty
if (i == BLOCK_Q - 1) {
ptx::tcgen05_before_thread_sync();
empty_tmem_barriers[tmem_stage_idx]->arrive();
}
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
// TODO: optimize performance
const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast<uint64_t>(logits_stride);
if constexpr (kIsCompressedLogits) {
if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
logits[q_offset + kv_offset - seq_k_start[i]] = result;
} else {
logits[q_offset + kv_offset] = result;
}
__syncwarp();
}
}
// Release last Q empty
empty_q_barriers[q_stage_idx]->arrive();
}
// Free tensor memory
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
if (warp_idx == 0)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,496 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
typename logits_dtype_t,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Utils
const auto sm_idx = blockIdx.x;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
// Prefetch TMA descriptors
if (warp_idx == kSpecWarpStart) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_sf_q);
cute::prefetch_tma_descriptor(&tensor_map_weights);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
}
// Next-N atom configs
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
// UMMA configs
static constexpr uint32_t kNumTmemStages = 3;
static constexpr uint32_t kNumUTCCPAlignedElems = 128;
static constexpr uint32_t UMMA_M = 128;
static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
static constexpr uint32_t UMMA_K = 64;
static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems);
static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems);
static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads;
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`");
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2);
static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2);
static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q and KV data on shared memory
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
});
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
});
const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
});
auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
});
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
+ SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
// Barriers and TMEM pointer on shared memory
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);
// Tensor memory configs
constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQAtom / 32 + kNumSFKV / 32>();
constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
// Initialize barriers
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads + 32);
}
cutlass::arch::fence_barrier_init();
}
if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(1);
}
cutlass::arch::fence_barrier_init();
}
if (warp_idx == kSpecWarpStart + 2) {
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumTmemStages; ++i) {
full_tmem_barriers[i]->init(1);
empty_tmem_barriers[i]->init(128);
}
cutlass::arch::fence_barrier_init();
}
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Scheduler
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Make Q, KV and TMEM pipeline
auto make_pipeline = [](const uint32_t& num_stages) {
// Return current stage and phase, and advance pipeline by steps
return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
uint32_t current_idx = iter_idx;
iter_idx += step;
return {current_idx % num_stages, (current_idx / num_stages) & 1};
};
};
auto advance_q_pipeline = make_pipeline(kNumQStages);
auto advance_kv_pipeline = make_pipeline(kNumKVStages);
auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 56;
constexpr uint32_t kNumMathRegisters = 224;
if (warp_idx == kSpecWarpStart) {
// TMA warp for loading Q
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
if (cute::elect_one_sync()) {
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
// Persistently schedule over blocks
// Initialize outside valid range to indicate no previous task
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, _, __;
while (scheduler.fetch_next_task(q_atom_idx, _, __)) {
// Issue TMA Q when (q_idx, atom_idx) changes
if (q_atom_idx != last_q_atom_idx) {
// Wait Q consumer release
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
// Issue TMA Q
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_q[q_stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_atom_idx * kNextNAtom);
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_atom_idx * kNextNAtom);
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
}
last_q_atom_idx = q_atom_idx;
}
}
__syncwarp();
} else if (warp_idx == kSpecWarpStart + 1) {
// TMA warp for loading KV cache
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
// Persistently schedule over blocks
uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, num_kv;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) {
// Reset block table cache on kv restart
if (q_atom_idx != last_q_atom_idx)
kv_block_idx_ptr = 32;
last_q_atom_idx = q_atom_idx;
// Coalesced load of block table
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
// Broadcast KV block indices
int kv_block_idx[kNumBlocksPerSplit];
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
kv_block_idx_ptr += kNumBlocksPerSplit;
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`");
// Wait KV consumer release
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
// Issue TMA KV
if (cute::elect_one_sync()) {
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i,
0, 0, kv_block_idx[i]);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
smem_sf_kv[kv_stage_idx] + BLOCK_KV * i,
0, kv_block_idx[i]);
}
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
}
}
} else if (warp_idx == kSpecWarpStart + 2) {
// UMMA warp
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// UTCCP transposer
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
// Make UMMA desc
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
// Persistently schedule over blocks
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, _;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
// Wait TMA Q arrivals
uint32_t q_stage_idx, q_phase;
if (q_atom_idx != last_q_atom_idx) {
CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase);
// Release previous Q empty (UMMA warp must participate to prevent
// running ahead of math warps in the Q pipeline)
if (last_q_atom_idx != batch_size * kNumNextNAtoms)
empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
full_q_barriers[q_stage_idx]->wait(q_phase);
// Transpose and copy SF Q
#pragma unroll
for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
if (cute::elect_one_sync())
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
__syncwarp();
}
}
last_q_atom_idx = q_atom_idx;
// Wait TMA KV arrivals
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Transpose
#pragma unroll
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
}
// UMMA with SF
if (cute::elect_one_sync()) {
// Copy SF KV
#pragma unroll
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
// Wait TMEM release
CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
ptx::tcgen05_after_thread_sync();
// Issue UMMA with SF
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
// TODO: generalize UMMA desc
DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
auto a_desc = mma::sm100::make_smem_desc(
cute::UMMA::LayoutType::SWIZZLE_64B,
smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
8 * (kHeadDim / 2), 0);
auto b_desc = mma::sm100::make_smem_desc(
cute::UMMA::LayoutType::SWIZZLE_64B,
smem_q[q_stage_idx] + k * UMMA_K / 2,
8 * (kHeadDim / 2), 0);
ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
}
// TODO: move this PTX into headers
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
}
}
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
}
} else if (warp_idx == kSpecWarpStart + 3) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
const auto math_warpgroup_idx = warpgroup_idx;
const auto math_thread_idx = warp_idx * 32 + lane_idx;
// Helper lambda for loading tensor memory
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
constexpr int N = decltype(num_elems_c)::value;
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
// Math warpgroups process TMEM stages alternately
// Advance pipeline to align with the assigned stage
advance_tmem_pipeline(math_warpgroup_idx);
// Local register buffers
float accum[kNumHeads];
float weights[kNextNAtom][kNumHeads];
// Persistently schedule over blocks
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, _;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
if (q_atom_idx != last_q_atom_idx) {
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
// Release last Q empty
if (last_q_atom_idx != batch_size * kNumNextNAtoms)
empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
// Wait TMA Q arrivals
full_q_barriers[q_stage_idx]->wait(q_phase);
// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j));
weights[i][j + 0] = raw.x;
weights[i][j + 1] = raw.y;
weights[i][j + 2] = raw.z;
weights[i][j + 3] = raw.w;
}
}
}
last_q_atom_idx = q_atom_idx;
// Calculate KV offset in advance
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
// Advance pipeline by `kNumMathWarpGroups` steps
// Wait UMMA arrival
CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
ptx::tcgen05_after_thread_sync();
// Reduce over the head dim and store
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads>{}, tmem_addr, accum);
// Release TMEM empty
if (i == kNextNAtom - 1) {
ptx::tcgen05_before_thread_sync();
empty_tmem_barriers[tmem_stage_idx]->arrive();
}
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
// Store into the global memory
const auto dst_offset = kv_offset + i * static_cast<uint64_t>(logits_stride);
if constexpr(sizeof(logits_dtype_t) == 2) {
// Pack two adjacent bf16 lanes into uint32 for wider store
uint16_t my_bits = *reinterpret_cast<const uint16_t*>(&result);
uint16_t neighbor_bits = __shfl_down_sync(0xffffffff, my_bits, 1);
uint32_t packed;
asm volatile("mov.b32 %0, {%1, %2};" : "=r"(packed) : "h"(my_bits), "h"(neighbor_bits));
if (lane_idx % 2 == 0)
*reinterpret_cast<uint32_t*>(logits + dst_offset) = packed;
} else {
logits[dst_offset] = result;
}
// this sync warp prevent the next load tmem from reordering
// nvcc may reorder it to overlap with the current tmem load, lead to large register usage
__syncwarp();
}
}
// Free tensor memory
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
if (warp_idx == 0)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,514 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/epilogue/sm100_store_cd.cuh>
#include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kGranKA, uint32_t kGranKB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
bool kSwapAB,
GemmType kGemmType, bool kWithAccumulation,
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
typename epilogue_type_t>
CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
// MMA Configs
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
constexpr uint32_t UMMA_K = 32;
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
(not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
// SF configs
constexpr uint32_t kNumUTCCPAlignedElems = 128;
constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB");
// Epilogue configs
// Always enable pipeline for better performance
constexpr uint32_t kNumEpilogueStages = 2;
constexpr uint32_t kNumTMAStoreStages = 2;
// NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
// per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
// Share memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
"Shared memory of A/B must be aligned to 1024 bytes");
// NOTES: Make sure we have enough shared memory for UMMA padding
constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
// Tensor memory size and offsets
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Synchronize the cluster before 2-CTA TMEM allocation
kNumMulticast > 1 ? cute::cluster_sync() : void();
// Utils
const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_sfb);
cute::prefetch_tma_descriptor(&tensor_map_cd);
}
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4);
const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4);
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// D/A/B shared memory
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// SFA/SFB shared memory
auto sf_start_ptr = reinterpret_cast<uint8_t*>(smem_b[kNumStages]);
auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
});
// Barriers and tensor memory pointer
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_sfb[kNumStages]);;
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
full_barriers[i]->init(1);
empty_barriers[i]->init(1);
// Arrive only at the leader CTA
with_sf_full_barriers[i]->init(kNumMulticast * 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
// Arrive at all CTAs
tmem_full_barriers[i]->init(1);
// Arrive only at the leader CTA
tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs, kGranKA * 4>(
shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Use dynamic load block M, when swap-AB is enabled
const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
// For k-grouped layout, the number of block K is variable
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
uint32_t sfa_m_idx = m_block_idx * BLOCK_M;
uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>(
shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad));
tma::copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx);
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
uint32_t sfb_k_idx = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(
shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx);
tma::copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx);
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
}
// Arrive at full barriers
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled<b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorB, kMajorA>()
: cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Wait tensor memory empty barrier arrival
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
ptx::tcgen05_after_thread_sync();
// Empty barrier arrival
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
__syncwarp();
};
// Dynamic update of UMMA N based on effective M, when swap-AB is enabled
if constexpr (kSwapAB) {
uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
}
// Launch MMAs
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
#pragma unroll 4
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[stage_idx]->wait(phase);
ptx::tcgen05_after_thread_sync();
const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx);
const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
if (cute::elect_one_sync()) {
// Do SF copy at certain stages
// TODO: process shared memory descriptor by addition
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
if (sfa_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
}
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
if (sfb_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
// Issue UMMA
using mma_t = cute::conditional_t<
kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>;
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
const auto runtime_instr_desc = kSwapAB ?
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id):
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
if constexpr (kSwapAB) {
mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc,
kTmemStartColOfSFB, kTmemStartColOfSFA);
} else {
mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc,
kTmemStartColOfSFA, kTmemStartColOfSFB);
}
}
}
__syncwarp();
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
}
}
// To safely deconstruct barriers, we need another round of waits
const auto iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx == 2) {
// UTCCP transposer
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
// Transpose for UTCCP at certain stages
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Arrive
with_sf_full_barriers[stage_idx]->arrive(0u);
}
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
// Epilogue warp groups
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
ptx::tcgen05_after_thread_sync();
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto base_n_idx = n_block_idx * BLOCK_N;
if constexpr (kSwapAB) {
const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
epilogue::sm100_store_cd_swap_ab<
BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue_type_t>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
effective_m,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
} else {
epilogue::sm100_store_cd<
BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue_type_t>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
}
}
}
// TODO: Remove redundant synchronization
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Deallocate tensor memory
if (warp_idx == 0)
Allocator().free(0, kNumTmemCols);
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

File diff suppressed because it is too large Load Diff

View File

@@ -6,27 +6,31 @@
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
using namespace deep_gemm::sm100;
template <uint32_t kNumHeads, uint32_t kHeadDim,
bool kIsCompressedLogits,
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t kNumSMs,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
typename logits_dtype_t,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const uint32_t max_seqlen_k, const uint64_t stride_logits,
const uint32_t max_seqlen_k, const uint32_t stride_logits,
uint32_t* cu_seq_len_k_start,
uint32_t* cu_seq_len_k_end,
float* logits,
logits_dtype_t* logits,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
@@ -35,26 +39,26 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
// Q should be load only at once for a block
const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
// Types
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const auto& warp_in_group_idx = warp_idx % 4;
const auto& warpgroup_idx = warp_idx / 4;
const auto& lane_idx = get_lane_idx();
// Utils
const auto sm_idx = blockIdx.x;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
// Prefetch TMA descriptors
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
if (warp_idx == kSpecWarpStart) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
__syncwarp();
// Shared memory configs
// NOTES: weight may be unaligned
@@ -62,7 +66,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
// Align to 512 bytes for swizzle-64B
extern __shared__ __align__(512) uint8_t smem_buffer[];
@@ -75,19 +79,19 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
// Data on shared memory
auto smem_q = PatternVisitor([&](const uint32_t& i) {
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
});
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
@@ -95,76 +99,77 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// TMA barriers
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
// Tensor memory allocation
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
// Initialize barriers
DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32));
const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1));
if (is_tma_load_warp and cute::elect_one_sync()) {
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
empty_q_barriers[i]->init(kNumMathThreads + 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(kNumMathThreads);
}
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
full_umma_barriers[i]->init(1);
empty_umma_barriers[i]->init(128);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (is_umma_warp) {
}
if (warp_idx == kSpecWarpStart + 1) {
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
full_umma_barriers[i]->init(1);
empty_umma_barriers[i]->init(128);
}
cutlass::arch::fence_barrier_init();
}
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 24;
constexpr uint32_t kNumMathRegisters = 240;
constexpr uint32_t kNumSpecializedRegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Block scheduler
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + gridDim.x, q_iter_idx + 1};
uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + kNumSMs, q_iter_idx + 1};
};
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = cu_seq_len_k_start[q_idx];
seq_k_end[i] = cu_seq_len_k_end[q_idx];
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
// TMA alignment requirements for SF KV
start = start / 4 * 4;
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
start, ceil_div(end - start, BLOCK_KV)}; // Task info
start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
};
// KV pipeline
uint32_t num_total_kv_blocks = 0;
const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
return {
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
@@ -177,13 +182,16 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
if (is_tma_load_warp) {
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (warp_idx == kSpecWarpStart) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
// Prefetch
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
};
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
@@ -209,10 +217,10 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
num_total_kv_blocks += num_kv_blocks;
@@ -221,11 +229,11 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
}
} else if (is_umma_warp) {
} else if (warp_idx == kSpecWarpStart + 1) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
// Require full allocation
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Make UMMA desc
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
@@ -252,12 +260,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_q[q_stage_idx], 0, k * UMMA_K);
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
}
@@ -266,23 +274,37 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
}
num_total_kv_blocks += num_kv_blocks;
// UMMA warp must also arrive on empty_q to prevent running ahead
// of math warps in the Q pipeline
empty_q_barriers[q_stage_idx]->arrive();
// Jump to the next block
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
} else if (warp_idx >= kNumMathThreads / 32) {
} else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
} else if (warp_idx < kNumMathThreads / 32) {
} else if (warp_idx < kSpecWarpStart) {
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// Offsets
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
const auto& warp_offset = warp_idx * 32;
const auto& v_offset = lane_idx;
const auto tmem_start = warpgroup_idx * UMMA_N;
const auto math_thread_idx = warp_idx * 32 + lane_idx;
// Preload weights
constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads);
float weights[BLOCK_Q][kNumWeightsInReg];
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
// Helper lambda for loading tensor memory
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
constexpr int N = decltype(num_elems_c)::value;
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
// Local register buffers
float weights[BLOCK_Q][kNumHeads];
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
@@ -293,9 +315,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// Read weights
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) {
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
// Compute over KV blocks
@@ -307,82 +329,59 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset);
float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
// Reduce over the head dim and store
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx;
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q;
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems");
uint32_t shifted_accum[kNumLDTMElems];
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
tcgen05_before_thread_sync();
empty_umma_barriers[warpgroup_idx]->arrive();
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
// Load accumulator from TMEM
float accum[kNumHeads];
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Release TMEM empty
if (i == BLOCK_Q - 1) {
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[warpgroup_idx]->arrive();
}
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (int j = 0; j < kNumWeightsInReg; j += 4) {
sum_0 = transform_reg(j, sum_0);
sum_1 = transform_reg(j + 2, sum_1);
}
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
sum_0 = transform_smem(j, sum_0);
sum_1 = transform_smem(j + 2, sum_1);
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
float result = scale_kv * (sum.x + sum.y);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
if constexpr (kIsCompressedLogits) {
if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result;
if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
logits[q_offset + kv_offset - seq_k_start[i]] = result;
} else {
logits[q_idx * stride_logits + kv_offset + v_offset] = result;
logits[q_offset + kv_offset] = result;
}
__syncwarp();
}
}
num_total_kv_blocks += num_kv_blocks;
@@ -393,12 +392,12 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// Jump to the next block
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
}
// Free tensor memory
__syncthreads();
if (is_tma_load_warp)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
// Free tensor memory
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
if (warp_idx == 0)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
}
} // namespace deep_gemm

View File

@@ -6,28 +6,30 @@
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
using namespace deep_gemm::sm100;
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
typename logits_dtype_t,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint64_t logits_stride, const uint64_t block_table_stride,
const uint32_t* context_lens, float* logits,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
@@ -35,27 +37,33 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const auto& warpgroup_idx = warp_idx / 4;
const auto& lane_idx = get_lane_idx();
// Utils
const auto sm_idx = blockIdx.x;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
// Prefetch TMA descriptors
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
if (warp_idx == kSpecWarpStart) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
__syncwarp();
// Next-N atom configs
static constexpr uint32_t kNextNAtom = (kNextN % 2 == 0) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = kNextN / kNextNAtom;
static constexpr bool kSingleAtom = (kNumNextNAtoms == 1);
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
@@ -63,43 +71,40 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q and KV data on shared memory
auto smem_q = PatternVisitor([&](const uint32_t& i) {
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
});
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
// Barriers and TMEM pointer on shared memory
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups;
constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4);
const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4);
const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1);
// Initialize barriers
if (is_tma_load_warp and cute::elect_one_sync()) {
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
empty_q_barriers[i]->init(kNumMathThreads + 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
@@ -108,7 +113,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
}
cutlass::arch::fence_barrier_init();
}
if (is_umma_warp) {
if (warp_idx == kSpecWarpStart + 1) {
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
@@ -123,66 +128,76 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
constexpr uint32_t kNumSpecializedRegisters = 56;
constexpr uint32_t kNumMathRegisters = 224;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Scheduler
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit>(batch_size, blockIdx.x, context_lens, schedule_meta);
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Q and KV pipeline
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
};
const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
};
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// UMMA settings
// Construct instruction with layout D
constexpr uint32_t UMMA_M = 128;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = kNextN * kNumHeads;
constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
if (is_tma_load_warp) {
// TMA warp-group for loading data
if (warp_idx == kSpecWarpStart) {
// TMA warp for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_atom_idx) {
if (cute::elect_one_sync()) {
tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_atom_idx * kNextNAtom * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_atom_idx * kNextNAtom);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
}
};
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
uint32_t q_idx = batch_size, kv_idx, num_kv;
uint32_t next_q_idx, next_kv_idx, next_num_kv;
// Initialize outside valid range to indicate no previous task
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv;
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
bool fetched_next_task;
// Prefetch the first Q
if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
issue_tma_q(0, next_q_idx), q_iter_idx = 1;
if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)))
issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1;
int kv_block_idx_ptr = 32;
uint32_t kv_block_idx_ptr = 32;
uint32_t kv_block_idx_storage;
while (fetched_next_task) {
// Prefetch next Q when current Q changes
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
q_idx = next_q_idx;
// Prefetch next Q when (q, atom) changes
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + 1);
if (q_atom_idx != next_q_atom_idx)
kv_block_idx_ptr = 32;
q_atom_idx = next_q_atom_idx;
kv_idx = next_kv_idx;
num_kv = next_num_kv;
// Read KV block index
// TODO: deal with `-1`?
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
// TODO(xuzhean): consider -1
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0);
const auto block_table_offset = (q_atom_idx / kNumNextNAtoms) * static_cast<uint64_t>(block_table_stride);
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
@@ -190,12 +205,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
if (prefetch_q) {
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
issue_tma_q(q_stage_idx, q_idx + 1);
issue_tma_q(q_stage_idx, q_atom_idx + 1);
}
int kv_block_idx[kNumBlocksPerSplit];
uint32_t kv_block_idx[kNumBlocksPerSplit];
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i)
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
kv_block_idx_ptr += kNumBlocksPerSplit;
@@ -205,45 +220,53 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
if (cute::elect_one_sync()) {
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
0, 0, 1, kv_block_idx[i]);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
0, kv_block_idx[i]);
for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) {
tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
0, 0, 1, kv_block_idx[i]);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
0, kv_block_idx[i]);
}
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
// Fetch next task
fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv);
}
} else if (is_umma_warp) {
} else if (warp_idx == kSpecWarpStart + 1) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Require full allocation
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Make UMMA desc
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
uint32_t q_idx = batch_size, kv_idx;
uint32_t next_q_idx, next_kv_idx, next_num_kv;
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
uint32_t umma_phase = 1;
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
if (q_idx != next_q_idx) {
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
if (q_atom_idx != next_q_atom_idx) {
// Release previous Q empty (UMMA warp must participate to prevent
// running ahead of math warps in the Q pipeline)
if (q_iter_idx > 0)
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
full_q_barriers[q_stage_idx]->wait(q_phase);
}
q_idx = next_q_idx;
q_atom_idx = next_q_atom_idx;
kv_idx = next_kv_idx;
// Wait KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
@@ -251,12 +274,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
empty_umma_barriers[i]->wait(umma_phase);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
auto b_desc = make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_q[q_stage_idx], 0, k * UMMA_K);
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
}
@@ -264,29 +287,45 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
}
umma_phase ^= 1;
}
} else if (is_math_warp) {
// Math warp-groups for WGMMA
} else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, context_lens, schedule_meta);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Offsets
const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0);
const uint32_t thread_idx = threadIdx.x;
const auto math_warpgroup_idx = warpgroup_idx;
const auto tmem_start = math_warpgroup_idx * UMMA_N;
const auto math_thread_idx = warp_idx * 32 + lane_idx;
// Weights
constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads));
float weights[kNextN][kNumWeightsInReg];
DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers");
// Helper lambda for loading tensor memory
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
constexpr int N = decltype(num_elems_c)::value;
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
uint32_t q_idx = batch_size, kv_idx;
uint32_t next_q_idx, next_kv_idx, next_num_kv;
// Local register buffers
float weights[kNextNAtom][kNumHeads];
// Initialize outside valid range to indicate no previous task
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
uint32_t umma_phase = 0;
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
// Current Q changes
if (q_idx != next_q_idx) {
// Release Last Q empty
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
// Q or atom changes
if (q_atom_idx != next_q_atom_idx) {
// Release last Q empty
if (q_iter_idx > 0)
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
@@ -296,30 +335,30 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
for (uint32_t j = 0; j < kNumWeightsInReg; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
}
// Get current Q and KV index
q_idx = next_q_idx;
// Get current task indices
q_atom_idx = next_q_atom_idx;
kv_idx = next_kv_idx;
// Calculate KV offset in advance
auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV;
auto kv_offset = q_atom_idx * kNextNAtom * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
// Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx);
float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait(umma_phase);
tcgen05_after_thread_sync();
full_umma_barriers[math_warpgroup_idx]->wait(umma_phase);
ptx::tcgen05_after_thread_sync();
umma_phase ^= 1;
// Release KV empty
@@ -327,72 +366,49 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
// Reduce over the head dim and store
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN;
uint32_t shifted_accum[kNumLDTMElems];
DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM");
auto tmem_load = [&](auto... Is) {
if constexpr (kNumLDTMElems == 32) {
cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 64) {
cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...);
} else if constexpr (kNumLDTMElems == 128) {
cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...);
}
};
[&]<size_t... Is>(cute::index_sequence<Is...>) { tmem_load(Is...); }(cute::make_index_sequence<kNumLDTMElems>{});
cutlass::arch::fence_view_async_tmem_load();
tcgen05_before_thread_sync();
empty_umma_barriers[warpgroup_idx]->arrive();
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
auto accum = reinterpret_cast<float*>(shifted_accum + i * kNumHeads);
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
// Load accumulator from TMEM
float accum[kNumHeads];
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Release TMEM empty
if (i == kNextNAtom - 1) {
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[math_warpgroup_idx]->arrive();
}
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto& transform_reg = [&](const uint32_t& j, const float2& sum) {
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (int j = 0; j < kNumWeightsInReg; j += 4) {
sum_0 = transform_reg(j, sum_0);
sum_1 = transform_reg(j + 2, sum_1);
}
const auto& transform_smem = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j),
ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1));
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) {
sum_0 = transform_smem(j, sum_0);
sum_1 = transform_smem(j + 2, sum_1);
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
float result = scale_kv * (sum.x + sum.y);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
logits[kv_offset + i * logits_stride + thread_idx] = result;
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
__syncwarp();
}
}
} else {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
}
// Free tensor memory
__syncthreads();
if (is_umma_warp)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
// Free tensor memory
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
if (warp_idx == 0)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
}
} // namespace deep_gemm

View File

@@ -4,20 +4,22 @@
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/reduction.cuh>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
__device__ __forceinline__
CUTLASS_DEVICE
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
// Calculate the index of the bank group to be written in the atom
const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)`
@@ -37,7 +39,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -58,7 +60,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
const auto lane_idx = ptx::get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
@@ -70,7 +72,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
@@ -82,20 +84,20 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
// Data on shared memory (layout as ordered below)
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
// Fill the tensor memory pointer
@@ -121,7 +123,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
}
__syncthreads();
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
@@ -131,6 +133,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const uint32_t m_offset = shape_m * k_split_idx;
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Dispatch warps into different roles
if (warp_idx < kNumMMAThreads / 32) {
// TMA load warp
@@ -145,8 +150,8 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
@@ -168,7 +173,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto b_desc = make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
@@ -185,7 +190,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const auto& stage_idx = s % kNumStages;
const auto& cast_stage_idx = s % kNumCastStages;
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Issue UMMA
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
@@ -194,7 +199,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
b_desc.lo = advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
}
@@ -218,7 +223,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
// Wait UMMA arrival
tmem_full_barrier->wait(0);
tcgen05_after_thread_sync();
ptx::tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
@@ -239,7 +244,7 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
if constexpr (BLOCK_M == 64)
__syncwarp();
}
@@ -290,9 +295,9 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; i += 2) {
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
uint32_values[0][i + 1], uint32_values[1][i + 1],
smem_ptr);
ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
uint32_values[0][i + 1], uint32_values[1][i + 1],
smem_ptr);
}
// Wait tensor memory empty
@@ -321,15 +326,15 @@ sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
cutlass::arch::fence_view_async_tmem_store();
// Arrive for issuing MMAs
tcgen05_before_thread_sync();
ptx::tcgen05_before_thread_sync();
full_cast_barriers[cast_stage_idx]->arrive();
}
// Intra-warp reduction and write back
#pragma unroll
for (uint32_t u = 0; u < 2; ++ u) {
const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y);
const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y);
const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
if (lane_idx % 4 == 0 and m_idx < shape_m)
sqr_sum[m_offset + m_idx] = reduced_sum;
}

View File

@@ -11,14 +11,19 @@
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/arch/mma_sm100_desc.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumGroups,
@@ -30,7 +35,7 @@ template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kNumSMs,
GemmType kGemmType, bool kWithAccumulation,
typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bf16_gemm_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -51,7 +56,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
// Types
using WGMMA = typename BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
@@ -61,7 +66,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
@@ -71,7 +76,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// Configs
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_idx();
const uint32_t lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -88,17 +93,17 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// D/A/B shared memory
auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
@@ -119,9 +124,12 @@ sm90_bf16_gemm_impl(int* grouped_layout,
constexpr uint32_t kNumTMARegisters = 48;
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
@@ -151,7 +159,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
@@ -159,31 +167,30 @@ sm90_bf16_gemm_impl(int* grouped_layout,
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[stage_idx];
const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Issue TMAs
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
tma::copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
}
}
@@ -203,8 +210,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// Merged stages only happens in NT normal GEMM cases
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
auto a_desc = make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
auto b_desc = make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
auto a_desc = mma::sm90::make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
auto b_desc = mma::sm90::make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
@@ -229,10 +236,10 @@ sm90_bf16_gemm_impl(int* grouped_layout,
};
// TODO: remove some useless computation for unaligned Ms
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
// Wait TMA arrivals
full_barriers[stage_idx]->wait(phase);
@@ -240,26 +247,26 @@ sm90_bf16_gemm_impl(int* grouped_layout,
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
a_desc.reg32_[0] = advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
b_desc.reg32_[0] = advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
}
}
warpgroup_commit_batch();
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(stage_idx);
@@ -324,7 +331,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr
@@ -341,8 +348,8 @@ sm90_bf16_gemm_impl(int* grouped_layout,
auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
}
}
}
@@ -350,7 +357,7 @@ sm90_bf16_gemm_impl(int* grouped_layout,
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);
// Use TMA store to write back to global memory
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;

View File

@@ -4,26 +4,32 @@
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
float *d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Types
using WGMMA = typename BF16MMASelector<BLOCK_N>::type;
using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
@@ -33,7 +39,7 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
// Configs
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_idx();
const uint32_t lane_idx = ptx::get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
@@ -48,17 +54,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
// Align to 1024 bytes for swizzle-128B
// Fill shared memory pointers
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
@@ -80,14 +86,17 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
constexpr uint32_t kNumMathRegisters = 232;
// Block indices
const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M);
const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
@@ -98,18 +107,18 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
#pragma unroll
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
const auto stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
const uint32_t& k_idx = sk_idx % SHAPE_K;
const uint32_t& s_idx = sk_idx / SHAPE_K;
const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
const uint32_t k_idx = sk_idx % SHAPE_K;
const uint32_t s_idx = sk_idx / SHAPE_K;
constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
tma_copy<BLOCK_K, BLOCK_M, kSwizzle>(
tma::copy<BLOCK_K, BLOCK_M, kSwizzle>(
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
tma_copy<BLOCK_K, BLOCK_N, kSwizzle>(
tma::copy<BLOCK_K, BLOCK_N, kSwizzle>(
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
}
@@ -125,32 +134,32 @@ sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrivals
const auto& stage_idx = s % kNumStages;
const auto stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, 1);
}
warpgroup_commit_batch();
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
empty_barriers[stage_idx]->arrive();
}
const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
if (col + i * 8 >= SHAPE_N)

View File

@@ -6,18 +6,26 @@
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/int_tuple.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tma.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumGroups,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
@@ -27,7 +35,7 @@ template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, typename cd_dtype_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
int* grouped_layout,
cute::TmaDescriptor* tensor_map_buffer,
@@ -45,7 +53,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
@@ -55,13 +63,13 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0);
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 2 : 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
// Configs
@@ -83,47 +91,41 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Tensor maps on shared and global memory
auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * i);
});
auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cute::TmaDescriptor*>(smem_buffer + static_cast<uint32_t>(sizeof(cute::TmaDescriptor)) * (2 + i));
});
auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; });
auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; });
auto smem_tensor_map_a = reinterpret_cast<cute::TmaDescriptor*>(smem_buffer);
auto smem_tensor_map_b = smem_tensor_map_a + 1;
auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2;
auto gmem_tensor_map_b = gmem_tensor_map_a + 1;
// Data on shared memory
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
});
auto smem_sfb = PatternVisitor([&](const uint32_t& i) {
auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
});
// Barriers on shared memory
constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
auto full_barriers = PatternVisitor([&](const uint32_t& i) {
auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
});
auto empty_barriers = PatternVisitor([&](const uint32_t& i) {
auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
});
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
// Load tensormap A/B to shared memory
if constexpr (kGemmType == GemmType::KGroupedContiguous) {
*smem_tensor_map_a[0] = tensor_map_a_base;
*smem_tensor_map_a[1] = tensor_map_a_base;
*smem_tensor_map_b[0] = tensor_map_b_base;
*smem_tensor_map_b[1] = tensor_map_b_base;
*smem_tensor_map_a = tensor_map_a_base;
*smem_tensor_map_b = tensor_map_b_base;
}
// Initialize barriers
@@ -149,12 +151,15 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
// TMA and MMA pipeline
const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
};
uint32_t iter_idx = 0;
@@ -165,10 +170,7 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
// NOTES: only one thread (or warp) will be used
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base;
const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base;
uint32_t last_group_idx = kNumGroups;
uint32_t prefetched_next_group_idx = kNumGroups; // Track which group was prefetched
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
@@ -179,63 +181,26 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
const uint32_t& m_idx = m_block_idx * BLOCK_M;
const uint32_t& n_idx = n_block_idx * BLOCK_N;
const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
const uint32_t m_idx = m_block_idx * BLOCK_M;
const uint32_t n_idx = n_block_idx * BLOCK_N;
if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) {
const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1;
const uint32_t& next_stage_idx = stage_idx ^ 1;
last_group_idx = scheduler.current_group_idx;
if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) {
last_group_idx = scheduler.current_group_idx;
// Check if the current group matches the prefetched group
// If not, we need to prepare the correct tensor map for the current group
if (scheduler.current_num_valid_groups > 0 &&
scheduler.current_group_idx != prefetched_next_group_idx) {
// The prefetched tensor map doesn't match current group
// This happens when block count is small (< num_SMs) and scheduler skips groups
// Need to prepare the correct tensor map for current group
// Use scheduler.current_k_cumsum which correctly tracks k offset even when groups are skipped
const uint64_t current_k_offset = scheduler.current_k_cumsum;
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[stage_idx],
gmem_a_ptr + current_k_offset * shape_m);
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[stage_idx],
gmem_b_ptr + current_k_offset * shape_n);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[stage_idx],
scheduler.current_shape_k, scheduler.current_shape_k);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[stage_idx],
scheduler.current_shape_k, scheduler.current_shape_k);
*(gmem_tensor_map_a[stage_idx]) = *(smem_tensor_map_a[stage_idx]);
*(gmem_tensor_map_b[stage_idx]) = *(smem_tensor_map_b[stage_idx]);
// NOTE: Don't call tensor_map_release_cta() here!
// We're preparing the current tensor map, not the next one.
// It will be acquired immediately in the "Get current tensor map" section below.
}
// Directly update current tensor map
const uint64_t current_k_offset = scheduler.current_k_cumsum;
ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m);
ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n);
ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k);
ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k);
*(gmem_tensor_map_a) = *(smem_tensor_map_a);
*(gmem_tensor_map_b) = *(smem_tensor_map_b);
ptx::tensor_map_release_gpu();
// Prepare next tensor map (prefetch for next group)
if (scheduler.next_group_idx < kNumGroups) {
// Calculate next group's k offset using scheduler-provided information
// This ensures consistency even when groups are skipped
const uint64_t next_k_offset = static_cast<uint64_t>(scheduler.current_k_cumsum) + scheduler.current_shape_k;
tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + next_k_offset * shape_m);
tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + next_k_offset * shape_n);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k);
*(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]);
*(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]);
tensor_map_release_cta();
prefetched_next_group_idx = scheduler.next_group_idx; // Record which group was prefetched
} else {
prefetched_next_group_idx = kNumGroups; // No more groups to prefetch
}
// Get current tensor map
if (scheduler.current_num_valid_groups > 0) {
tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]);
tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]);
current_tensor_map_a = gmem_tensor_map_a[stage_idx];
current_tensor_map_b = gmem_tensor_map_b[stage_idx];
}
// Immediately acquire current tensor map
ptx::tensor_map_acquire_gpu(gmem_tensor_map_a);
ptx::tensor_map_acquire_gpu(gmem_tensor_map_b);
}
#pragma unroll kNumPipelineUnrolls
@@ -246,12 +211,14 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
// Issue TMA
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t& k_idx = k_block_idx * BLOCK_K;
const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
tma_copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
const uint32_t k_idx = k_block_idx * BLOCK_K;
const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base);
const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base);
tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
tma::copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
}
}
@@ -278,9 +245,9 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Accumulation for WGMMA or CUDA promotion
DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K);
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K);
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
float2 scales_b[WGMMA::kNumAccum / 4];
@@ -302,30 +269,30 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0);
auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1);
auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0);
auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1);
// Read B scales
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
scales_b[i] = ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
scales_b[i] = ptx::ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Notify barrier arrival
empty_barrier_arrive(stage_idx);
@@ -348,12 +315,12 @@ sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
// Store to D shared memory
const auto& smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
const auto& smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
const auto smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
const auto smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);

View File

@@ -10,17 +10,21 @@
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
if (num_former_iters == kNumFormerIters) {
func(cute::Int<kNumFormerIters>{});
return;
@@ -35,12 +39,12 @@ template <cute::UMMA::Major kMajorSFB,
uint32_t kNumGroups,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
uint32_t kNumStages, uint32_t kNumLastStages,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType,
typename epilogue_type_t>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
@@ -50,10 +54,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
DG_STATIC_ASSERT(
math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or
(math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Types
using WGMMA = typename FP8MMASelector<BLOCK_N>::type;
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
@@ -64,23 +70,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Shared memory
static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K);
const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K);
const uint32_t& smem_sfb_size = align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K);
const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K);
const uint32_t smem_sfb_size = math::align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
// NOTES: Make sure we have enough shared memory for WGMMA padding
static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
// Configs
const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K);
const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = get_lane_idx();
const uint32_t lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -97,22 +103,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = PatternVisitor([&](const uint32_t& i) {
auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
@@ -136,9 +142,12 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
@@ -177,15 +186,15 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t k_idx = k_block_idx * BLOCK_K;
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a, batch_idx);
tma_copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
num_tma_multicast_a);
// Issue TMA B
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b, batch_idx);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
@@ -206,8 +215,8 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
auto b_desc = make_smem_desc(smem_b[0], 1);
auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1);
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
@@ -225,14 +234,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto previous_group_offset = scheduler.template get_global_idx<true, IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
auto previous_group_offset = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb));
ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]);
}
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
@@ -259,22 +268,22 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Skip useless computations
if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
// The compiler must know the dynamic variable `num_former_iters`'s real value
constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// Dispatch `num_former_iters` and launch MMAs
dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
#pragma unroll 8
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
// Read B scales
float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1;
float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales);
scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales);
// Wait TMA arrivals
full_barriers[stage_idx]->wait(phase);
@@ -286,25 +295,25 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
auto scale_a_1 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
WGMMA::wgmma(a_desc, b_desc, accum, k);
}
warpgroup_commit_batch();
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
@@ -325,7 +334,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters;
const bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
@@ -399,7 +408,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
}
// NOTES: only 16 lanes' addresses are used
SM90_U32x2_STSM_N<nv_bfloat162>::copy(
ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr

View File

@@ -7,36 +7,31 @@
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/mma_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
// ReSharper disable once CppNotAllPathsReturnValue
template <uint32_t kHeadDim>
static constexpr int to_swizzle_cute_type() {
DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
if constexpr (kHeadDim == 32)
return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
if constexpr (kHeadDim == 64)
return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
if constexpr (kHeadDim == 128)
return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
}
template <uint32_t kNumHeads, uint32_t kHeadDim,
bool kIsCompressedLogits,
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
uint32_t kNumSMs,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
typename logits_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const uint32_t max_seqlen_k, const uint64_t stride_logits,
const uint32_t max_seqlen_k, const uint32_t stride_logits,
uint32_t* cu_seq_len_k_start,
uint32_t* cu_seq_len_k_end,
float* logits,
logits_dtype_t* logits,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
@@ -44,10 +39,10 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// TODO: consider TMA multicast
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
// Q should be load only at once for a block
const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q);
const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
// Types
using WGMMA = typename FP8MMASelector<BLOCK_Q * kNumHeads>::type;
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_Q * kNumHeads>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Prefetch TMA descriptors
@@ -74,19 +69,19 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Data on shared memory
auto smem_q = PatternVisitor([&](const uint32_t& i) {
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
@@ -94,13 +89,13 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// TMA barriers
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
// Initialize barriers
const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
if (is_tma_load_warp and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
@@ -123,38 +118,43 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
constexpr uint32_t kNumMathRegisters = 112;
// Block scheduler
uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0;
const auto& get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + gridDim.x, q_iter_idx + 1};
const auto sm_idx = blockIdx.x;
uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + kNumSMs, q_iter_idx + 1};
};
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx);
seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx);
const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = cu_seq_len_k_start[q_idx];
seq_k_end[i] = cu_seq_len_k_end[q_idx];
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
// TMA alignment requirements for SF KV
start = start / 4 * 4;
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
start, ceil_div(end - start, BLOCK_KV)}; // Task info
start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
};
// KV pipeline
uint32_t num_total_kv_blocks = 0;
const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
return {
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
};
};
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
@@ -165,8 +165,8 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
// Prefetch
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
tma_copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
tma_copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
};
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
@@ -192,9 +192,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
tma_copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
@@ -212,7 +212,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const auto& thread_idx = threadIdx.x % kNumMathThreads;
const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
const auto& warpgroup_idx = warp_idx / 4;
const auto& lane_idx = get_lane_idx();
const auto& lane_idx = ptx::get_lane_idx();
float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
const auto& warp_offset = warp_idx * 16;
@@ -230,7 +230,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
}
// Compute over KV blocks
@@ -242,29 +242,31 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
// Issue WGMMA
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K,
to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_a = mma::sm90::make_smem_desc(
smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_b = mma::sm90::make_smem_desc(
smem_q[q_stage_idx] + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_wait<0>();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
@@ -278,7 +280,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
auto shifted_accum = accum + i * kNumAccumPerReduce;
const auto& transform = [&](const uint32_t& j) {
const auto transform = [&](const uint32_t& j) {
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
};
@@ -302,16 +304,15 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
}
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
const uint32_t& q_idx = block_q_idx * BLOCK_Q + i;
const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
if constexpr (kIsCompressedLogits) {
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0;
logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_0);
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1;
logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_1);
} else {
logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0;
logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1;
logits[q_offset + kv_offset + v_0_offset] = static_cast<logits_dtype_t>(v_0);
logits[q_offset + kv_offset + v_1_offset] = static_cast<logits_dtype_t>(v_1);
}
}
}

View File

@@ -6,133 +6,43 @@
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
namespace deep_gemm {
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
__global__ __launch_bounds__(32, 1)
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
const uint32_t* context_lens, uint32_t* schedule_metadata) {
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
const uint32_t lane_idx = get_lane_idx();
uint32_t num_segs[kAlignedBatchSize / 32];
#pragma unroll
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
const uint32_t q_idx = k * 32 + lane_idx;
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0);
num_segs[k] = ceil_div(context_len, SPLIT_KV);
}
__shared__ uint32_t prefix_sum[kAlignedBatchSize];
uint32_t sum = 0;
#pragma unroll
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
uint32_t x = num_segs[k];
#pragma unroll
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset);
x += (lane_idx >= offset ? y : 0);
}
x += sum;
prefix_sum[k * 32 + lane_idx] = x;
sum = __shfl_sync(0xffffffff, x, 31);
}
const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t q_idx = 0;
while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts)
++ q_idx;
const uint32_t& kv_split_idx = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1]);
__syncwarp();
schedule_metadata[sm_idx * 2] = q_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
}
}
template <uint32_t kNextN, bool kIsContextLens2D,
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit>
struct PagedMQALogitsScheduler {
uint32_t batch_size;
const uint32_t* context_lens;
uint32_t current_q_idx, current_kv_idx;
uint32_t end_q_idx, end_kv_idx;
uint32_t current_num_kv;
__device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) {
const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0;
}
__device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx,
const uint32_t* context_lens, const uint32_t* schedule_meta) {
this->batch_size = batch_size;
this->context_lens = context_lens;
const auto& current_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx);
const auto& end_pack = __ldg(reinterpret_cast<const uint2*>(schedule_meta) + sm_idx + 1);
current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
current_num_kv = get_num_kv(current_q_idx);
}
__device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) {
q_idx = current_q_idx;
kv_idx = current_kv_idx;
num_kv = current_num_kv;
if (q_idx == end_q_idx and kv_idx == end_kv_idx)
return false;
current_kv_idx += kNumBlocksPerSplit;
if (current_kv_idx >= current_num_kv) {
++ current_q_idx;
current_kv_idx = 0;
current_num_kv = get_num_kv(current_q_idx);
}
return true;
}
__device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const {
return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx;
}
};
using namespace deep_gemm::sm90;
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
typename logits_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint64_t logits_stride, const uint64_t block_table_stride,
const uint32_t* context_lens, float* logits,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
// Types
using WGMMA = typename FP8MMASelector<kNextN * kNumHeads>::type;
using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const auto& warpgroup_idx = warp_idx / 4;
const auto& lane_idx = get_lane_idx();
const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
@@ -150,15 +60,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
@@ -166,31 +76,31 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q data and barriers on shared memory
auto smem_q = PatternVisitor([&](const uint32_t& i) {
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_weights = PatternVisitor([&](const uint32_t& i) {
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
// Separate math warpgroups and tma load warps into KV groups
// Each math warpgroup corresponds to a tma load warp
const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
// Per group KV data and barriers on shared memory
const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
auto smem_kv = PatternVisitor([&](const uint32_t& i) {
const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
});
auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) {
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
// Initialize barriers
if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
@@ -218,15 +128,19 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
constexpr uint32_t kNumTMARegisters = 64;
constexpr uint32_t kNumMathRegisters = 104;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Scheduler
auto scheduler = PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups>(batch_size, blockIdx.x, context_lens, schedule_meta);
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, BLOCK_KV, kNumMathWarpGroups, 1>(
blockIdx.x, context_lens, schedule_meta);
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
// Q and KV pipeline
const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
};
const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
};
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
@@ -237,10 +151,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
if (kv_group_idx >= kNumMathWarpGroups)
return;
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
if (kv_group_idx == 0 and cute::elect_one_sync()) {
tma_copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
tma_copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx);
tma::copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
tma::copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
}
};
@@ -259,7 +173,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
while (fetched_next_task) {
// Prefetch next Q when current Q changes
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1));
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1));
q_idx = next_q_idx;
kv_idx = next_kv_idx;
num_kv = next_num_kv;
@@ -276,9 +190,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
__ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0);
block_table[q_idx * static_cast<uint64_t>(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0);
}
const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
// Wait KV consumer release
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
@@ -286,10 +200,10 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
// Issue TMA KV
if (cute::elect_one_sync()) {
tma_copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
tma_copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
@@ -301,9 +215,9 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
const auto& sub_warp_offset = (warp_idx % 4) * 16;
const auto& v_0_offset = lane_idx / 4 + 0;
const auto& v_1_offset = lane_idx / 4 + 8;
const auto sub_warp_offset = (warp_idx % 4) * 16;
const auto v_0_offset = lane_idx / 4 + 0;
const auto v_1_offset = lane_idx / 4 + 8;
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
uint32_t q_idx = batch_size, kv_idx;
@@ -326,7 +240,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
for (uint32_t i = 0; i < kNextN; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
}
}
@@ -335,7 +249,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
kv_idx = next_kv_idx;
// Calculate KV offset in advance
auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
auto kv_offset = q_idx * kNextN * static_cast<uint64_t>(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
// Wait TMA KV arrival
@@ -347,25 +261,29 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_a = mma::sm90::make_smem_desc(
smem_kv[kv_stage_idx] + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_b = mma::sm90::make_smem_desc(
smem_q[q_stage_idx] + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
warpgroup_commit_batch();
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
ptx::warpgroup_fence_operand(accum[i]);
// Read per-KV scales
float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
// Wait WGMMA
warpgroup_wait<0>();
ptx::warpgroup_wait<0>();
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
@@ -378,7 +296,7 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
auto shifted_accum = accum + i * kNumAccumPerReduce;
const auto& transform = [&](const uint32_t& j) {
const auto transform = [&](const uint32_t& j) {
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
};
@@ -396,15 +314,15 @@ void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
// Inter-thread reduction
#pragma unroll
for (uint32_t j = 0; j < 2; ++ j) {
const auto& offset = static_cast<int>(1u << j);
const auto offset = static_cast<int>(1u << j);
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
}
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
logits[kv_offset + i * logits_stride + v_0_offset] = v_0;
logits[kv_offset + i * logits_stride + v_1_offset] = v_1;
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_0_offset] = static_cast<logits_dtype_t>(v_0);
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_1_offset] = static_cast<logits_dtype_t>(v_1);
}
}
}

View File

@@ -5,20 +5,23 @@
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/reduction.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm90_utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm90;
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
__device__ __forceinline__
CUTLASS_DEVICE
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;
const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange;
constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups;
@@ -35,7 +38,7 @@ template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
@@ -56,7 +59,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
const auto lane_idx = ptx::get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
@@ -76,17 +79,17 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
// Data on shared memory (layout as ordered below)
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = PatternVisitor([&](const uint32_t& i) {
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
@@ -101,7 +104,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
}
__syncthreads();
constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
@@ -113,12 +116,15 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 256;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// TMA load warp
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
const auto stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
// Compute offsets
@@ -126,8 +132,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma_copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma_copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
@@ -135,7 +141,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
}
for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
const auto& stage_idx = s % kNumStages;
const auto stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
}
} else if (warp_idx < kNumMathThreads / 32) {
@@ -148,7 +154,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
constexpr uint32_t WGMMA_N = BLOCK_N;
constexpr uint32_t WGMMA_K = 8;
using WGMMA = typename TF32MMASelector<WGMMA_N, true>::type;
using WGMMA = typename mma::sm90::TF32MMASelector<WGMMA_N, true>::type;
float accum[WGMMA::kNumAccum] = {0};
constexpr uint32_t kNumBankGroupBytes = 16;
@@ -196,14 +202,14 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
}
warpgroup_wait<0>();
ptx::warpgroup_wait<0>();
if (s > 0)
empty_barriers[(s - 1) % kNumStages]->arrive();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
warpgroup_arrive();
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
@@ -213,18 +219,19 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
#pragma unroll
for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
auto b_desc = mma::sm90::make_smem_desc(
smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
}
}
warpgroup_commit_batch();
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
warpgroup_fence_operand(accum[i]);
ptx::warpgroup_fence_operand(accum[i]);
}
const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0);
const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1);
const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0);
const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1);
const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
if (lane_idx % 4 == 0) {
@@ -233,7 +240,7 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
if (m_idx + 8 < shape_m)
sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
}
warpgroup_wait<0>();
ptx::warpgroup_wait<0>();
empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
// Write accum to shared memory
@@ -260,8 +267,8 @@ sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
// 0/1 write to the same row, 2/3 write to another row
auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
st_shared(smem_ptr, values[0], values[1]);
st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
ptx::st_shared(smem_ptr, values[0], values[1]);
ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, 1);

View File

@@ -3,21 +3,24 @@
#include <cutlass/arch/barrier.h>
#include <cute/arch/cluster_sm90.hpp>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
namespace deep_gemm {
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps>
__global__ __launch_bounds__(kNumWarps * 32, 1)
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps, typename logits_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1)
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) {
const uint32_t& num_sms = gridDim.x;
const uint32_t& sm_idx = blockIdx.x;
const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
constexpr float neg_inf = -cute::numeric_limits<float>::infinity();
const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) {
const uint32_t num_sms = gridDim.x;
const uint32_t sm_idx = blockIdx.x;
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t);
const logits_dtype_t neg_inf = -cute::numeric_limits<logits_dtype_t>::infinity();
// Allocate filled `-inf` shared memory
extern __shared__ __align__(1024) float smem_buffer[];
extern __shared__ __align__(1024) logits_dtype_t smem_buffer[];
#pragma unroll
for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
smem_buffer[i] = neg_inf;
@@ -25,38 +28,42 @@ void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const
__syncthreads();
// Assign sequence to each warp
const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx,
const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
const auto& per = total / num, rem = total % num;
return {start + idx * per + min(idx, rem), per + (idx < rem)};
const auto assign_task = [&](const uint32_t& num, const uint32_t& idx,
const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
const auto per = total / num, rem = total % num;
return {start + idx * per + cute::min(idx, rem), per + (idx < rem)};
};
CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (cute::elect_one_sync()) {
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
const auto& right = min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
const auto right = cute::min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
if (right <= ks or ke <= left) {
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float));
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t));
} else {
if (left < aligned_ks)
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float));
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t));
if (aligned_ke < right)
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float));
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t));
}
}
}
}
__syncwarp();
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN);
const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1;
const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4;
const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
for (uint32_t j = aligned_ks; j < ks; ++ j)
logits[i * stride_logits + j] = neg_inf;
for (uint32_t j = ke; j < aligned_ke; ++ j)

View File

@@ -1,13 +1,16 @@
#pragma once
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
typedef typename Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
typedef typename utils::Vectorized<sizeof(float) * SF_K>::vec_t in_vec_t;
constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;
@@ -15,16 +18,19 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
extern __shared__ float smem_buffer[];
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
const auto tma_aligned_mn = align<uint32_t>(mn, kNumTMAAlignedElems);
const auto tma_aligned_mn = math::align<uint32_t>(mn, kNumTMAAlignedElems);
// Shift into the block
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
const auto& local_sf = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Load
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) {
auto in_vec = __ldg(local_sf + i);
auto in_vec = local_sf[i];
const auto& in_values = reinterpret_cast<float*>(&in_vec);
const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec;
@@ -39,26 +45,29 @@ __global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) {
const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn;
const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx;
out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx);
}
}
// NOTES: the two kernels below always pack the K dimension
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
extern __shared__ uint32_t smem_buffer[];
// Shapes and strides
constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u);
constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u);
constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
const auto tma_aligned_mn = align<uint64_t>(mn, kNumTMAAlignedElems);
const auto tma_aligned_mn = math::align<uint64_t>(mn, kNumTMAAlignedElems);
// Shift into the group
sf = sf + static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
out = out + static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Load FP32 SFs
DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
const auto local_sf = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
@@ -66,13 +75,13 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
const auto num_uint4 = num_values / 4;
#pragma unroll
for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) {
const auto& [x, y, z, w] = __ldg(reinterpret_cast<uint4*>(local_sf) + i);
st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
const auto& [x, y, z, w] = reinterpret_cast<const uint4*>(local_sf)[i];
ptx::st_shared(reinterpret_cast<uint4*>(smem_buffer) + i, x, y, z, w);
}
// Fill unaligned values as well
if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values)
st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx));
ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]);
__syncthreads();
// Pack into UE8M0 and store
@@ -85,7 +94,7 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
#pragma unroll
for (uint32_t j = 0; j < 4; ++ j) {
const auto sf_k_idx = sf_k_pack_idx * 4 + j;
values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
values[j] = sf_k_idx < SF_K ? ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0;
}
// Pack and store
@@ -101,8 +110,9 @@ __global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, con
template <uint32_t kNumGroups, uint32_t kNumThreads,
uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) {
CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k,
const uint32_t gran_k) {
// Always packing the K dimension
// NOTES: should also assert `mn % 4 == 0` at launch
DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
@@ -120,11 +130,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
// Each warp is responsible for a packed row
const auto warp_idx = threadIdx.x / 32;
const auto lane_idx = get_lane_idx();
const auto lane_idx = ptx::get_lane_idx();
const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
if (warp_idx >= in_block_packed_sf_k)
return;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Make an offset on the input
uint32_t input_offset = 0;
if constexpr (kNumGroups > 1) {
@@ -134,18 +147,18 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i) {
const auto group_idx = lane_idx * 4 + i;
group_ks[i] = group_idx < kNumGroups ? __ldg(ks + group_idx) : 0;
group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0;
}
__syncwarp();
// Make the offset
sf_k = 0;
auto sum_packed_sf_k = 0;
uint32_t sum_packed_sf_k = 0;
#pragma unroll
for (uint32_t i = 0; i < kNumGroups; ++ i) {
const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4);
const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4);
sf_k += sf_k_in_group;
sum_packed_sf_k += ceil_div(sf_k_in_group, 4u);
sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u);
if (packed_sf_k_idx < sum_packed_sf_k)
break;
if (const auto remainder = sf_k_in_group % 4; remainder > 0)
@@ -153,14 +166,14 @@ __global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
}
}
for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
// Load
uint4 values[4];
#pragma unroll
for (uint32_t j = 0; j < 4; ++ j) {
values[j] = make_uint4(0, 0, 0, 0);
if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
values[j] = __ldg(reinterpret_cast<uint4*>(sf + sf_k_idx * mn) + mn_idx);
values[j] = reinterpret_cast<const uint4*>(sf + sf_k_idx * mn)[mn_idx];
}
// Pack and store

View File

@@ -0,0 +1,255 @@
#pragma once
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::layout {
// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
T num_experts_per_rank, T block_m) {
const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank;
const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
return math::constexpr_align(
num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (block_m - 1),
block_m);
}
// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast<T>(128));
}
// Per-token source metadata for combine write-back
struct TokenSrcMetadata {
uint32_t rank_idx;
uint32_t token_idx;
uint32_t topk_idx;
};
struct Workspace {
void* base;
uint32_t num_ranks, num_experts;
uint32_t num_experts_per_rank;
uint32_t num_max_tokens_per_rank;
uint32_t num_max_recv_tokens_per_expert;
// Pool capacity: all local experts share a contiguous token pool
uint32_t num_max_pool_tokens;
uint32_t num_max_pool_blocks;
// For both grid barrier and NVLink barrier
static constexpr uint64_t kNumBarrierSignalBytes = 32;
CUTLASS_HOST_DEVICE
Workspace(void* base,
const uint32_t& num_ranks,
const uint32_t& num_experts,
const uint32_t& num_max_tokens_per_rank,
const uint32_t& num_topk,
const uint32_t& block_m):
base(base),
num_ranks(num_ranks), num_experts(num_experts),
num_max_tokens_per_rank(num_max_tokens_per_rank) {
num_experts_per_rank = num_experts / num_ranks;
num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
num_max_pool_tokens = get_num_max_pool_tokens(
num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank, block_m);
num_max_pool_blocks = num_max_pool_tokens / block_m;
DG_UNIFIED_ASSERT(num_max_tokens_per_rank % block_m == 0);
}
CUTLASS_HOST_DEVICE
uint64_t get_num_bytes() const {
uint64_t num_bytes = 0;
// Barrier
num_bytes += kNumBarrierSignalBytes;
// Expert send/recv count
num_bytes += num_experts * sizeof(uint64_t) * 2;
// Expert recv count sum
num_bytes += num_experts_per_rank * sizeof(uint64_t);
// L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
// L2 block arrival mask
num_bytes += num_max_pool_blocks * sizeof(uint64_t);
// Dispatch pulling source token-topk
num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int);
// Combine push source indices
num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
// Align to TMA descriptor requirements
num_bytes = math::align<uint64_t>(num_bytes, 16);
return num_bytes;
}
CUTLASS_HOST_DEVICE
void* get_end_ptr() const {
return math::advance_ptr(base, get_num_bytes());
}
// Grid sync counters: `kNumBarrierSignalBytes` layout
// [ 0..15]: 4 x `uint32_t` grid sync counters
// [16..20]: `uint32_t` NVLink barrier counter
// [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1)
static constexpr uint32_t kNumMaxGridSyncCounters = 4;
template <uint32_t kIndex = 0>
CUTLASS_DEVICE
uint32_t* get_grid_sync_count_ptr() const {
DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
return static_cast<uint32_t*>(base) + kIndex;
}
CUTLASS_DEVICE
uint32_t* get_nvl_barrier_counter_ptr() const {
return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters;
}
CUTLASS_DEVICE
int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
// NOTES: the signal is signed, as we may minus
return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int));
}
CUTLASS_DEVICE
uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx;
}
CUTLASS_DEVICE
uint64_t* get_expert_recv_count_ptr(
const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx;
}
CUTLASS_DEVICE
uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
return get_expert_send_count_ptr(num_experts * 2) + expert_idx;
}
CUTLASS_DEVICE
uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank);
return reinterpret_cast<uint32_t*>(base) + pool_block_idx;
}
CUTLASS_DEVICE
uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
// Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u));
return reinterpret_cast<uint64_t*>(base) + pool_block_idx;
}
// For dispatch pulling
CUTLASS_DEVICE
uint32_t* get_src_token_topk_idx_ptr(
const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks);
return reinterpret_cast<uint32_t*>(base) +
expert_idx * (num_ranks * num_max_recv_tokens_per_expert) +
rank_idx * num_max_recv_tokens_per_expert + token_idx;
}
// For combine usages
CUTLASS_DEVICE
TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
const auto base = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank));
return base + pool_token_idx;
}
};
struct Data {
uint32_t num_bytes;
bool require_tma_alignment;
void* base;
CUTLASS_HOST_DEVICE
constexpr explicit Data(
const uint32_t& num_bytes,
const bool& require_tma_alignment = true,
void* base = nullptr) :
num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) {
DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment);
}
template <typename dtype_t = uint32_t>
CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const {
return static_cast<dtype_t>(num_bytes);
}
template <typename dtype_t = void>
CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
return static_cast<dtype_t*>(base);
}
CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) {
base = ptr;
}
};
struct Buffer {
Data data_layout;
uint32_t num_ranks;
uint32_t num_max_tokens_per_rank;
void* base;
CUTLASS_HOST_DEVICE
Buffer(const Data& data_layout,
const uint32_t& num_ranks,
const uint32_t& max_num_tokens_per_rank,
void* base = nullptr) :
data_layout(data_layout),
num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank),
base(base) {}
CUTLASS_HOST_DEVICE
uint64_t get_num_bytes_per_rank() const {
return num_max_tokens_per_rank * data_layout.get_num_bytes<uint64_t>();
}
CUTLASS_HOST_DEVICE
uint64_t get_num_bytes() const {
return get_num_bytes_per_rank() * num_ranks;
}
template <typename dtype_t = void>
CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
return static_cast<dtype_t*>(base);
}
CUTLASS_HOST_DEVICE
void* get_end_ptr() const {
return math::advance_ptr(base, get_num_bytes());
}
CUTLASS_HOST_DEVICE
Buffer get_rank_buffer(const uint32_t& rank_idx) const {
return {
data_layout,
1, num_max_tokens_per_rank,
math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx)
};
}
CUTLASS_HOST_DEVICE
Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const {
DG_DEVICE_ASSERT(num_ranks == 1 or global);
return Data(
data_layout.num_bytes,
data_layout.require_tma_alignment,
math::advance_ptr(base, data_layout.get_num_bytes<uint64_t>() * token_idx)
);
}
};
} // namespace deep_gemm::layout

View File

@@ -0,0 +1,41 @@
#pragma once
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::layout {
constexpr static uint32_t kNumMaxRanks = 72;
template <uint32_t kNumRanks = kNumMaxRanks>
struct SymBuffer {
int64_t base;
int64_t offsets[kNumMaxRanks];
uint32_t rank_idx;
DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks");
SymBuffer() = default;
template <typename Container>
explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) {
const auto size = static_cast<uint32_t>(c.size());
base = c[rank_idx];
for (uint32_t i = 0; i < kNumMaxRanks; ++ i)
offsets[i] = i < size ? (c[i] - base) : 0;
}
#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__)
template <typename ptr_t = void*>
CUTLASS_DEVICE ptr_t get_base_ptr() const {
return reinterpret_cast<ptr_t>(base);
}
template <typename ptr_t>
CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const {
int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast<int64_t>(ptr);
return *reinterpret_cast<ptr_t*>(&mapped_ptr);
}
#endif
};
} // namespace deep_gemm::layout

View File

@@ -0,0 +1,151 @@
#pragma once
#include <cute/atom/mma_traits_sm100.hpp>
#include <cute/arch/mma_sm100_umma.hpp>
#include <deep_gemm/common/exception.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
namespace deep_gemm::mma::sm100 {
/// Shared memory descriptor
CUTLASS_DEVICE
cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr,
const uint32_t& stride_byte_offset, const uint32_t& leading_byte_offset) {
cute::UMMA::SmemDescriptor desc;
// Set the version for SM100
desc.version_ = 1;
// Legacy mode
desc.lbo_mode_ = 0;
// Layout
desc.layout_type_ = static_cast<uint8_t>(layout);
// Start address
const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
// Base offset
desc.base_offset_ = 0;
// SBO and LBO
desc.stride_byte_offset_ = stride_byte_offset >> 4;
desc.leading_byte_offset_ = leading_byte_offset >> 4;
return desc;
}
CUTLASS_DEVICE
cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
// NOTES: the UTCCP layout is K-major by default
// Atom size: 8 x 128 bits
// {SBO, LBO} means the byte stride between atoms on {MN, K}
// Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero
return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0);
}
CUTLASS_DEVICE
void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) {
const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr);
desc.start_address_ = static_cast<uint16_t>(uint_ptr >> 4);
}
CUTLASS_DEVICE
static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) {
return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16;
}
/// UMMA descriptors
// ReSharper disable once CppNotAllPathsReturnValue
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
kSwizzleMode == 32 or kSwizzleMode == 64 or
kSwizzleMode == 128, "Invalid swizzling mode");
// A special case
if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B;
}
// Normal cases
if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE;
if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE;
if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B;
if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B;
if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B;
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
CUTLASS_DEVICE
constexpr uint32_t get_umma_desc_stride_k() {
return kMajorMode == cute::UMMA::Major::K ? 1 : tma::get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
CUTLASS_DEVICE
uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) {
return base + (((offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
CUTLASS_DEVICE
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
const auto layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
const auto num_non_contiguous = 128 / get_atom_base(layout_type);
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
// also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
// {SBO, LBO} means the byte stride between atoms on {MN, K}
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
const uint32_t leading_byte_offset = 0;
return make_smem_desc(layout_type,
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
} else {
constexpr uint32_t BLOCK_MN_ATOM = tma::get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
// Must have no in-atom MN-idx
// NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
// Atom size: `kSwizzleMode` (in bytes, on MN) x 8
// NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
if constexpr (kSwizzleMode == 16)
math::swap(stride_byte_offset, leading_byte_offset);
return make_smem_desc(layout_type,
base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k,
stride_byte_offset, leading_byte_offset);
}
}
CUTLASS_DEVICE uint64_t make_runtime_instr_desc_with_sf_id(
cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id;
return static_cast<uint64_t>(static_cast<uint32_t>(desc)) << 32;
}
CUTLASS_DEVICE void update_instr_desc_with_umma_n(
cute::UMMA::InstrDescriptorBlockScaled& desc, const uint32_t& umma_n) {
desc.n_dim_ = umma_n >> 3;
}
CUTLASS_DEVICE void update_instr_desc_with_umma_n(
cute::UMMA::InstrDescriptor& desc, const uint32_t& umma_n) {
desc.n_dim_ = umma_n >> 3;
}
} // namespace deep_gemm::mma::sm100

View File

@@ -0,0 +1,293 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/mma_sm90_desc.hpp>
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <cute/arch/mma_sm100_desc.hpp>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::mma::sm90 {
/// MMA
template <int N_, typename MMA>
struct FP8MMA {
template <size_t ...Idx>
CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_ / 2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 32;
static constexpr int kNumAccum = M * N / 128;
};
template <int N>
struct FP8MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN();
if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN();
}
static constexpr auto select_type() {
return FP8MMA<N, decltype(select_mma())>();
}
using type = decltype(select_type());
};
template <int N_, typename MMA>
struct BF16MMA {
template <size_t ...Idx>
CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 16;
static constexpr int kNumAccum = M * N / 128;
};
template <cute::UMMA::Major kMajor>
constexpr cute::SM90::GMMA::Major to_sm90_major() {
DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness");
return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN;
}
template <int N,
cute::UMMA::Major kMajorA = cute::UMMA::Major::K,
cute::UMMA::Major kMajorB = cute::UMMA::Major::K>
struct BF16MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
constexpr auto kGMMAMajorA = to_sm90_major<kMajorA>();
constexpr auto kGMMAMajorB = to_sm90_major<kMajorB>();
if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
}
static constexpr auto select_type() {
return BF16MMA<N, decltype(select_mma())>();
}
using type = decltype(select_type());
};
template <int N_, typename MMA>
struct TF32MMARS {
template <size_t ...Idx>
CUTLASS_DEVICE static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
CUTLASS_DEVICE static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 8;
static constexpr int kNumAccum = M * N / 128;
};
template <int N, bool kUseRS = true>
struct TF32MMASelector {
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (kUseRS) {
if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
}
}
static constexpr auto select_type() {
if constexpr (kUseRS) {
return TF32MMARS<N, decltype(select_mma())>();
} else {
DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
}
}
using type = decltype(select_type());
};
/// Shared memory descriptor
template <class PointerType>
CUTLASS_DEVICE cute::GmmaDescriptor
make_smem_desc(PointerType smem_ptr, const int& layout_type,
const uint32_t& leading_byte_offset = 0,
const uint32_t& stride_byte_offset = 1024) {
// NOTES: the default LBO and SBO are for K-major types
cute::GmmaDescriptor desc;
const auto uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
desc.bitfield.start_address_ = uint_ptr >> 4;
desc.bitfield.layout_type_ = layout_type;
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
desc.bitfield.base_offset_ = 0;
return desc;
}
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t);
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
CUTLASS_DEVICE
constexpr uint32_t get_gmma_desc_stride_k() {
return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
}
// ReSharper disable once CppNotAllPathsReturnValue
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, typename dtype_t>
constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() {
DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
kSwizzleMode == 32 or kSwizzleMode == 64 or
kSwizzleMode == 128, "Invalid swizzling mode");
// Normal cases
if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE;
if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32;
if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64;
if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128;
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
CUTLASS_DEVICE
uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) {
return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>()) * static_cast<uint32_t>(sizeof(dtype_t))) >> 4u);
}
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
CUTLASS_DEVICE
cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
const uint32_t stride_k = get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
const auto layout_type = to_gmma_layout_type<kMajorMode, kSwizzleMode, dtype_t>();
constexpr uint32_t num_non_contiguous = 128 / 16;
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
// {SBO, LBO} means the byte stride between atoms on {MN, K}
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
const uint32_t leading_byte_offset = 0;
return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
leading_byte_offset, stride_byte_offset);
} else {
constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
// Must have no in-atom MN-idx
// NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
// Atom size: `kSwizzleMode` (in bytes, on MN) x 8
// NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
if constexpr (kSwizzleMode == 16)
math::swap(stride_byte_offset, leading_byte_offset);
return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
leading_byte_offset, stride_byte_offset);
}
}
// ReSharper disable once CppNotAllPathsReturnValue
template <uint32_t kHeadDim>
static constexpr int to_swizzle_cute_type() {
DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling");
if constexpr (kHeadDim == 32)
return static_cast<int>(cute::SM90::GMMA::LayoutType::B32);
if constexpr (kHeadDim == 64)
return static_cast<int>(cute::SM90::GMMA::LayoutType::B64);
if constexpr (kHeadDim == 128)
return static_cast<int>(cute::SM90::GMMA::LayoutType::B128);
}
} // namespace deep_gemm::mma::sm90

View File

@@ -0,0 +1,247 @@
#pragma once
#include <cuda/std/cstdint>
#include <cuda_bf16.h>
namespace deep_gemm::ptx {
// Compatibility: 256 bits LD/ST instructions
#if defined(CUDART_VERSION) and CUDART_VERSION >= 13000
using longlong4_t = longlong4_32a;
#define make_longlong4_t make_longlong4_32a
#else
struct alignas(32) longlong4_t { long long x, y, z, w; };
CUTLASS_HOST_DEVICE longlong4_t make_longlong4_t(
const long long& x, const long long& y, const long long& z, const long long& w) {
return {x, y, z, w};
}
#endif
/// LD/ST matrix
// TODO: remove `struct`
struct SM90_U32x2_LDSM_N {
CUTLASS_DEVICE static void
copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst_0), "=r"(dst_1)
: "l"(__cvta_generic_to_shared(smem_src)));
}
};
struct SM90_U32x4_LDSM_N {
CUTLASS_DEVICE static void
copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
: "l"(__cvta_generic_to_shared(smem_src)));
}
};
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
CUTLASS_DEVICE static void
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
}
};
template <typename dtype_t>
struct SM90_U32x4_STSM_T {
CUTLASS_DEVICE static void
copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) {
DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
const uint32_t src[4] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1),
*reinterpret_cast<uint32_t*>(&src_2), *reinterpret_cast<uint32_t*>(&src_3)};
asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n"
:: "l"(__cvta_generic_to_shared(smem_dst)),
"r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3]));
}
};
template <typename dtype_t>
struct SM100_U8x4_STSM_T {
__device__ __forceinline__ static void
copy(dtype_t src_0, void* smem_dst) {
DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
const uint32_t src = *reinterpret_cast<uint32_t*>(&src_0);
asm volatile("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n"
:: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src));
}
};
template <typename dtype_t>
struct SM100_U8x8_STSM_T {
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype");
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n"
:: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
}
};
/// Shared memory
CUTLASS_DEVICE uint32_t ld_shared(const uint32_t* ptr) {
uint32_t ret;
asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
CUTLASS_DEVICE float2 ld_shared(const float2* ptr) {
float2 ret;
asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
CUTLASS_DEVICE float4 ld_shared(const float4* ptr) {
float4 ret;
asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
CUTLASS_DEVICE uint4 ld_shared(const uint4* ptr) {
uint4 ret;
asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
CUTLASS_DEVICE float ld_shared(const float* ptr) {
float ret;
asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr)));
return ret;
}
CUTLASS_DEVICE void st_shared(const float* ptr, float val) {
asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val));
}
CUTLASS_DEVICE void st_shared(const float2* ptr, float2 val) {
asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y));
}
CUTLASS_DEVICE void st_shared(const uint32_t* ptr, uint32_t val) {
asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val));
}
CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y) {
asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y));
}
CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) {
asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w));
}
CUTLASS_DEVICE void st_shared(const __int128_t* ptr, __int128_t val) {
asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val));
}
CUTLASS_DEVICE void st_shared_bulk(void* smem_ptr, const uint32_t& num_bytes) {
// `size` must be 64-bit before PTX ISA 9.0
asm volatile("st.bulk.weak.shared::cta [%0], %1, 0;" ::
"l"(__cvta_generic_to_shared(smem_ptr)), "l"(static_cast<uint64_t>(num_bytes)));
}
/// Global memory
CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) {
uint64_t ret;
asm volatile("ld.volatile.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) {
uint32_t ret;
asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) {
uint64_t ret;
asm volatile("ld.acquire.sys.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) {
asm volatile("st.L1::no_allocate.relaxed.sys.u64 [%0], %1;" :: "l"(ptr), "l"(value));
}
/// Atomics
CUTLASS_DEVICE uint64_t atomic_add(const uint64_t* ptr, const uint64_t& value) {
uint64_t ret;
asm volatile("atom.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value));
return ret;
}
CUTLASS_DEVICE uint64_t atomic_add_sys(const uint64_t* ptr, const uint64_t& value) {
uint64_t ret;
asm volatile("atom.sys.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value));
return ret;
}
CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& value) {
uint32_t ret;
asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value));
return ret;
}
__forceinline__ __device__ void red_add(const uint32_t* ptr, const uint32_t& value) {
asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
}
CUTLASS_DEVICE void red_or_rel_sys(const uint64_t* ptr, const uint64_t& value) {
asm volatile("red.release.sys.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value));
}
CUTLASS_DEVICE void red_or_rel_gpu(uint64_t* ptr, const uint64_t& value) {
asm volatile("red.release.gpu.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value));
}
CUTLASS_DEVICE void red_add_rel(const uint32_t* ptr, const uint32_t& value) {
asm volatile("red.release.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value));
}
CUTLASS_DEVICE void red_add_rel_sys(const int* ptr, const int& value) {
asm volatile("red.release.sys.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value));
}
CUTLASS_DEVICE int ld_acq_sys(const int* ptr) {
int ret;
asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
CUTLASS_DEVICE uint32_t ld_acq_sys(const uint32_t* ptr) {
uint32_t ret;
asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return ret;
}
CUTLASS_DEVICE uint64_t ld_acq_gpu(const uint64_t* ptr) {
uint64_t ret;
asm volatile("ld.acquire.gpu.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr));
return ret;
}
/// Predicated loads
CUTLASS_DEVICE longlong4_t ld_gez_pred(const longlong4_t* ptr, const int& pred) {
longlong4_t ret = make_longlong4_t(0, 0, 0, 0);
asm volatile(
"{\n\t"
" .reg .pred p;\n\t"
" setp.ge.s32 p, %5, 0;\n\t"
" @p ld.global.L2::256B.v4.s64 {%0, %1, %2, %3}, [%4];\n\t"
"}"
: "+l"(ret.x), "+l"(ret.y), "+l"(ret.z), "+l"(ret.w)
: "l"(ptr), "r"(pred)
: "memory");
return ret;
}
/// Prefetch
CUTLASS_DEVICE void prefetch_l1(void *ptr) {
asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr));
}
} // namespace deep_gemm::ptx

View File

@@ -0,0 +1,168 @@
#pragma once
namespace deep_gemm::ptx {
/// UMMA versions with relaxed assertions
struct SM100_MMA_F16BF16_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_F16BF16_2x1SM_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_MXF8F6F4_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
struct SM100_MMA_MXF8F6F4_2x1SM_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
struct SM100_MMA_F8F6F4_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_F8F6F4_2x1SM_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
struct SM100_MMA_MXF4_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)
"tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t"
#else
"tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t"
#endif
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
struct SM100_MMA_F16BF16_WS_SS {
CUTLASS_DEVICE static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
/// Tensor memory operations
CUTLASS_DEVICE void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
CUTLASS_DEVICE void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
} // namespace deep_gemm::ptx

View File

@@ -0,0 +1,112 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cute/arch/copy_sm90_desc.hpp>
namespace deep_gemm::ptx {
// Tensor-map instructions
CUTLASS_DEVICE void tensor_map_release_gpu() {
asm volatile ("fence.proxy.tensormap::generic.release.gpu;" ::: "memory");
}
CUTLASS_DEVICE void tensor_map_acquire_gpu(const cute::TmaDescriptor* gmem_desc_ptr) {
auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
asm volatile ("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" :: "l"(gmem_int_desc) : "memory");
}
CUTLASS_DEVICE void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
}
CUTLASS_DEVICE void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3)))
asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
#else
DG_STATIC_ASSERT(false, "Invalid CUDA version");
#endif
}
/// TMA instructions
CUTLASS_DEVICE void mbarrier_arrive(
cutlass::arch::ClusterTransactionBarrier* ptr) {
asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" ::
"r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))));
}
CUTLASS_DEVICE void mbarrier_arrive_and_set_tx(
cutlass::arch::ClusterTransactionBarrier* ptr, const uint32_t& num_bytes) {
asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" ::
"r"(num_bytes), "r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))));
}
CUTLASS_DEVICE void mbarrier_wait_and_flip_phase(
cutlass::arch::ClusterTransactionBarrier* ptr, uint32_t& phase) {
asm volatile(
"{\n\t"
".reg .pred P1; \n\t"
"LAB_WAIT: \n\t"
"mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t"
"@P1 bra DONE; \n\t"
"bra LAB_WAIT; \n\t"
"DONE: \n\t"
"}" ::
"r"(static_cast<uint32_t>(__cvta_generic_to_shared(ptr))),
"r"(phase), "r"(0x989680));
phase ^= 1;
}
CUTLASS_DEVICE void tma_load_1d(
const void* dst_ptr, const void* src_ptr,
cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr,
const uint32_t& num_bytes,
const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_FIRST) {
// NOTES: normally, the loaded part will be evicted soon
asm volatile(
"cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" ::
"r"(static_cast<uint32_t>(__cvta_generic_to_shared(dst_ptr))),
"l"(src_ptr),
"r"(num_bytes),
"r"(static_cast<uint32_t>(__cvta_generic_to_shared(mbarrier_ptr))),
"l"(hint)
: "memory");
}
CUTLASS_DEVICE void tma_store_1d(
const void* dst_ptr, const void* src_ptr, const uint32_t& num_bytes,
const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_NORMAL) {
// NOTES: normally, the stored part will be used soon
asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" ::
"l"(dst_ptr),
"r"(static_cast<uint32_t>(__cvta_generic_to_shared(src_ptr))),
"r"(num_bytes),
"l"(hint)
: "memory");
}
template <int kNumRemainingWaits = 0>
__forceinline__ __device__ void tma_store_wait() {
// NOTES: this function does not have `.read`
asm volatile("cp.async.bulk.wait_group %0;" ::"n"(kNumRemainingWaits) : "memory");
}
CUTLASS_DEVICE
void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier& mbarrier,
void* smem_ptr, const uint32_t& col_idx, const int4& row_idxs, const uint64_t& cache_hint) {
const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
const auto mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbarrier_addr), "l"(cache_hint)
: "memory"
);
}
} // namespace deep_gemm::ptx

View File

@@ -0,0 +1,53 @@
#pragma once
#include <cuda/std/cstdint>
#include <cuda_bf16.h>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::ptx {
CUTLASS_DEVICE uint32_t get_sm_idx() {
uint32_t sm_idx;
asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx));
return sm_idx;
}
CUTLASS_DEVICE uint32_t get_lane_idx() {
uint32_t lane_id;
asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id));
return lane_id;
}
CUTLASS_DEVICE void sync_aligned(const uint32_t& num_threads, const uint32_t& barrier_idx) {
asm volatile("bar.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads));
}
CUTLASS_DEVICE void sync_unaligned(const uint32_t& num_threads, const uint32_t& barrier_idx) {
asm volatile("barrier.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads));
}
template <typename dtype_t>
CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) {
DG_STATIC_ASSERT(sizeof(dtype_t) % sizeof(uint32_t) == 0, "");
const auto send_int_values = reinterpret_cast<uint32_t*>(&ptr);
dtype_t recv_dtype;
auto recv_int_values = reinterpret_cast<uint32_t*>(&recv_dtype);
#pragma unroll
for (uint32_t i = 0; i < sizeof(dtype_t) / sizeof(uint32_t); ++ i)
recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], static_cast<int>(src_lane_idx));
return recv_dtype;
}
CUTLASS_DEVICE void accumulate(float2& a, nv_bfloat162 b) {
#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
// Use `add.rn.f32.bf16` instruction to perform fused (cast + add) operation on SM100
asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.x) : "h"(*reinterpret_cast<uint16_t*>(&b.x)));
asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.y) : "h"(*reinterpret_cast<uint16_t*>(&b.y)));
#else
const auto [x, y] = __bfloat1622float2(b);
a.x += x, a.y += y;
#endif
}
} // namespace deep_gemm::ptx

View File

@@ -0,0 +1,25 @@
#pragma once
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::ptx {
CUTLASS_DEVICE void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
CUTLASS_DEVICE void warpgroup_commit_batch() {
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
CUTLASS_DEVICE void warpgroup_fence_operand(float& reg) {
asm volatile("" : "+f"(reg) :: "memory");
}
template <int N>
CUTLASS_DEVICE void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
} // namespace deep_gemm::ptx

View File

@@ -0,0 +1,300 @@
#pragma once
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/types.cuh>
namespace deep_gemm::sched {
enum class IndexType {
MN,
K,
SF_K,
};
template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
static constexpr uint32_t get_num_1d_blocks_per_group() {
// Select the best from candidates
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
for (const auto candidate: {8u, 16u}) {
const auto usage = kIsMulticastOnA ?
candidate * BLOCK_N + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
candidate * BLOCK_M + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
if (usage < min_usage)
min_usage = usage, num_best_blocks = candidate;
}
return num_best_blocks;
}
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
template <GemmType kGemmType,
uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 on SM90 (float SF), gran_k * 4 on SM100 (packed UE8M0 SF)
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
struct Scheduler {
int current_iter = -1;
// Block configs
uint32_t num_blocks;
uint32_t num_m_blocks;
uint32_t num_n_blocks;
// For SM90 multicast checks
uint32_t num_blocks_in_group;
bool is_peer_cta_alive = true;
// For grouped GEMM
int* grouped_layout;
uint32_t current_group_idx = 0;
// Only used for masked layout
uint32_t current_m_cumsum = 0;
// Only used for contiguous psum layout
uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
// Only used for k-grouped layout
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
uint32_t next_group_idx, next_shape_k;
// Only used for k-grouped gemm
CUTLASS_DEVICE void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
for (; group_idx < kNumGroups; ++ group_idx) {
shape_k = grouped_layout[group_idx];
if (shape_k > 0)
break;
}
}
// ReSharper disable once CppPossiblyUninitializedMember
CUTLASS_DEVICE explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n,
const uint32_t& shape_k, int* grouped_layout = nullptr) {
num_m_blocks = math::ceil_div(shape_m, BLOCK_M);
num_n_blocks = math::ceil_div(shape_n, BLOCK_N);
current_shape_k = shape_k;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
num_blocks = num_m_blocks * num_n_blocks;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
num_blocks = num_m_blocks * num_n_blocks;
this->grouped_layout = grouped_layout;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
this->grouped_layout = grouped_layout;
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
this->grouped_layout = grouped_layout;
current_psum_m = grouped_layout[0];
num_m_blocks = math::ceil_div(current_psum_m, BLOCK_M);
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
num_blocks = num_m_blocks * num_n_blocks;
this->grouped_layout = grouped_layout;
get_next_k_group(current_group_idx, current_shape_k);
next_group_idx = current_group_idx + 1;
get_next_k_group(next_group_idx, next_shape_k);
}
}
CUTLASS_DEVICE void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
const auto primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
const auto secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
const auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
const auto group_idx = block_idx / num_blocks_per_group;
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
// Fix unaligned TMA multicast
// NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
// while SM100 uses 2-CTA, which can not be dynamically disabled
#if __CUDA_ARCH__ < 1000
if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
num_blocks_in_group = num_blocks_in_group ^ 1;
} else {
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
first_block_idx += num_blocks_in_group ^ 1;
num_blocks_in_group = 1;
}
}
#endif
// Convert to final M/N block indices
// `kIsMulticastOnA == true` leads to groups on N
if constexpr (kIsMulticastOnA) {
m_block_idx = in_group_idx / num_blocks_in_group;
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
} else {
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
n_block_idx = in_group_idx / num_blocks_in_group;
}
}
template <bool kWithGroupOffset, IndexType kIndexType = IndexType::MN>
CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
const auto offset = kWithGroupOffset ? current_group_idx : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
auto offset = 0;
if constexpr (kWithGroupOffset) {
if constexpr (kIndexType == IndexType::MN)
offset = current_group_idx * shape_dim;
else if constexpr (kIndexType == IndexType::K)
offset = current_k_cumsum;
else if constexpr (kIndexType == IndexType::SF_K)
offset = current_sf_k_cumsum;
}
return offset + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::Batched) {
// Ignore kWithGroupOffset, and apply offset for IndexType::SF_K
const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0;
return offset * shape_dim + block_idx * block_size;
}
}
// For swap A/B and psum layout only
CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const {
constexpr uint32_t UMMA_STEP_N = 16;
DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment");
if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout)
return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ? current_psum_m - m_block_idx * BLOCK_M : BLOCK_M, UMMA_STEP_N);
return BLOCK_M;
}
CUTLASS_DEVICE bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;
if constexpr (kGemmType == GemmType::MGroupedMasked) {
while (true) {
// End of the task
if (current_group_idx == kNumGroups)
return false;
// Within current group
num_m_blocks = math::ceil_div(static_cast<uint32_t>(grouped_layout[current_group_idx]), BLOCK_M);
const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * num_n_blocks)
break;
// Move to check the next group
current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
}
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
while (true) {
// Within current group
if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
break;
// Move to check the next group
if (++ current_group_idx == kNumGroups)
return false;
// NOTES: `num_m_blocks` varies with the increase of the group index
last_psum_m = math::align(current_psum_m, BLOCK_M);
current_psum_m = grouped_layout[current_group_idx];
current_m_block_cumsum += num_m_blocks;
num_m_blocks = math::ceil_div(current_psum_m - last_psum_m, BLOCK_M);
}
get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
// NOTES: `last_psum_m` is aligned with block M
m_block_idx += last_psum_m / BLOCK_M;
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
while (true) {
// End of the task
if (current_group_idx == kNumGroups)
return false;
// Within current group
if (next_block_idx < (current_num_valid_groups + 1) * num_blocks)
break;
// Move to check the next group
current_k_cumsum += current_shape_k;
current_sf_k_cumsum += math::ceil_div(current_shape_k, SF_K_ALIGNMENT);
current_num_valid_groups ++;
current_group_idx = next_group_idx ++;
current_shape_k = next_shape_k;
get_next_k_group(next_group_idx, next_shape_k);
}
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_blocks, m_block_idx, n_block_idx);
} else if constexpr (kGemmType == GemmType::Batched) {
if (next_block_idx >= num_blocks * kNumGroups)
return false;
current_group_idx = next_block_idx / num_blocks;
const auto block_idx = next_block_idx - current_group_idx * num_blocks;
if constexpr (kIsMulticastOnA) {
m_block_idx = block_idx / num_n_blocks;
n_block_idx = block_idx % num_n_blocks;
} else {
m_block_idx = block_idx % num_m_blocks;
n_block_idx = block_idx / num_m_blocks;
}
} else {
if (next_block_idx >= num_blocks)
return false;
// For SM90 only
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass)
num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass)
(next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
// For SM90 only
CUTLASS_DEVICE bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or
kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
return true;
} else {
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
if constexpr (kIsMulticastOnA) {
return true;
} else {
const auto group_idx = grouped_layout[m_block_idx * BLOCK_M];
const auto peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M];
return group_idx == peer_group_idx;
}
}
}
// For SM90 only
// ReSharper disable once CppNotAllPathsReturnValue
CUTLASS_DEVICE bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
return true;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
return m_offset + m_block_idx * BLOCK_M < grouped_layout[current_group_idx];
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
return m_offset + m_block_idx * BLOCK_M < current_psum_m;
} else {
// Unreachable
DG_TRAP_ONLY_DEVICE_ASSERT(false);
}
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm::sched

View File

@@ -0,0 +1,221 @@
#pragma once
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/layout/mega_moe.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm::sched {
// Computation phase for the current block
enum class BlockPhase {
None = 0,
Linear1 = 1,
Linear2 = 2
};
template <uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t L1_SHAPE_N, uint32_t L1_SHAPE_K,
uint32_t L2_SHAPE_N, uint32_t L2_SHAPE_K,
uint32_t kNumExpertsPerRank,
uint32_t kNumExpertsPerWave,
uint32_t kNumSMs, uint32_t kNumRanks,
uint32_t kNumExpertsPerLane = math::constexpr_ceil_div(kNumExpertsPerRank, 32u),
uint32_t kNumL1BlockNs = L1_SHAPE_N / BLOCK_N,
uint32_t kNumL2BlockNs = L2_SHAPE_N / BLOCK_N,
uint32_t kNumL1BlockKs = L1_SHAPE_K / BLOCK_K,
uint32_t kNumL2BlockKs = L2_SHAPE_K / BLOCK_K>
struct MegaMoEScheduler {
DG_STATIC_ASSERT(L1_SHAPE_N % BLOCK_N == 0, "Invalid shape");
DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid shape");
DG_STATIC_ASSERT(L1_SHAPE_K % BLOCK_K == 0, "Invalid shape");
DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid shape");
DG_STATIC_ASSERT(kNumExpertsPerRank % kNumExpertsPerWave == 0, "Invalid wave config");
// NOTES: N block counts must be even so that 2 adjacent CTAs in a cluster
// always land on the same m_block_idx with n_block_idx differing by 1
DG_STATIC_ASSERT(kNumSMs % 2 == 0, "Number of SMs must be even for 2-CTA cluster");
DG_STATIC_ASSERT(kNumL1BlockNs % 2 == 0, "L1 N block count must be even for 2-CTA cluster");
DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster");
// Arrival counts
const layout::Workspace& workspace;
// Scheduler state
BlockPhase next_phase = BlockPhase::Linear1;
// Current expert and block indices
uint32_t current_local_expert_idx = 0;
uint32_t current_num_tokens = 0;
uint32_t current_pool_block_offset = 0;
uint32_t block_idx = 0;
uint32_t m_block_idx = 0;
uint32_t n_block_idx = 0;
// Pre-cached per-expert token counts (filled during `for_each_block` init)
// Layout: `stored_num_tokens_per_expert[i]` holds expert (i * 32 + lane_idx)'s count
uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {};
CUTLASS_DEVICE explicit MegaMoEScheduler(const layout::Workspace& workspace): workspace(workspace) {
block_idx = blockIdx.x;
}
CUTLASS_DEVICE uint32_t get_wave_expert_end_idx() const {
return math::align(current_local_expert_idx + 1, kNumExpertsPerWave);
}
CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const {
uint32_t valid_value;
#pragma unroll
for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ?
stored_num_tokens_per_expert[i] : valid_value;
}
return ptx::exchange(valid_value, expert_idx % 32);
}
// Get pool block offset for a given expert index from a per-lane token count array
CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) {
uint32_t num_blocks = 0;
#pragma unroll
for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
if (i * 32 + ptx::get_lane_idx() < expert_idx)
num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M);
}
return __reduce_add_sync(0xffffffff, num_blocks);
}
CUTLASS_DEVICE void advance_expert_idx() {
current_pool_block_offset += get_current_num_m_blocks();
current_local_expert_idx += 1;
current_num_tokens = get_num_tokens(current_local_expert_idx);
}
CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) {
current_local_expert_idx = expert_idx;
current_num_tokens = get_num_tokens(expert_idx);
current_pool_block_offset = get_pool_block_offset(expert_idx);
}
CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const {
return current_pool_block_offset;
}
CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const {
return math::ceil_div(current_num_tokens, BLOCK_M);
}
template <bool kDoUMMAAligned = false>
CUTLASS_DEVICE uint32_t get_valid_m() const {
const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M);
return kDoUMMAAligned ? math::align(m, 16u) : m;
}
CUTLASS_DEVICE bool fetch_next_l1_block() {
const auto wave_end_expert_idx = get_wave_expert_end_idx();
while (current_local_expert_idx < wave_end_expert_idx) {
const auto num_m_blocks = get_current_num_m_blocks();
m_block_idx = block_idx / kNumL1BlockNs;
if (m_block_idx < num_m_blocks)
return true;
// Current expert is fully assigned, move to the next
block_idx -= num_m_blocks * kNumL1BlockNs;
advance_expert_idx();
}
return false;
}
CUTLASS_DEVICE bool fetch_next_l2_block() {
const auto wave_end_expert_idx = get_wave_expert_end_idx();
while (current_local_expert_idx < wave_end_expert_idx) {
const auto num_m_blocks = get_current_num_m_blocks();
if (block_idx < num_m_blocks * kNumL2BlockNs) {
m_block_idx = block_idx / kNumL2BlockNs;
return true;
}
// Current expert is fully assigned, move to the next
block_idx -= num_m_blocks * kNumL2BlockNs;
advance_expert_idx();
}
return false;
}
// Core state machine: assigns the next block
CUTLASS_DEVICE cute::tuple<BlockPhase, uint32_t, uint32_t, uint32_t> get_next_block() {
while (true) {
if (current_local_expert_idx >= kNumExpertsPerRank)
break;
if (next_phase == BlockPhase::Linear1) {
if (fetch_next_l1_block()) {
// Found a new L1 block
n_block_idx = block_idx - m_block_idx * kNumL1BlockNs;
// Jump to next block
block_idx += kNumSMs;
return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx};
} else {
// L1 for the current wave is complete, transition to L2
next_phase = BlockPhase::Linear2;
set_expert_idx(math::align<uint32_t, false>(current_local_expert_idx - 1, kNumExpertsPerWave));
}
} else {
if (fetch_next_l2_block()) {
// Found a new L2 block
n_block_idx = block_idx - m_block_idx * kNumL2BlockNs;
// Jump to next block
block_idx += kNumSMs;
return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx};
} else {
// Move to L1 of the next wave
next_phase = BlockPhase::Linear1;
}
}
}
// All waves and experts are fully processed
return {BlockPhase::None, 0, 0, 0};
}
CUTLASS_DEVICE void fetch_expert_recv_count() {
// NOTES: each lane caches experts at indices (i * 32 + lane_idx)
#pragma unroll
for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) {
const auto expert_idx = i * 32 + ptx::get_lane_idx();
uint64_t value = 0;
if (expert_idx < kNumExpertsPerRank) {
do {
value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx));
} while (static_cast<uint32_t>(value >> 32) != kNumSMs * kNumRanks);
}
stored_num_tokens_per_expert[i] = static_cast<uint32_t>(value);
}
__syncwarp();
}
template <typename Func>
CUTLASS_DEVICE void for_each_block(Func&& func) {
// Wait for all expert counters to be finalized
fetch_expert_recv_count();
// Initialize current expert with 0
set_expert_idx(0);
// Iterate over all blocks
// TODO: add swizzle within expert waves for better L2 cache utilization
while (true) {
CUTE_TIE_DECL(get_next_block(), block_phase, current_local_expert_idx, m_block_idx, n_block_idx);
if (block_phase == BlockPhase::None)
break;
func(block_phase, current_local_expert_idx,
block_phase == BlockPhase::Linear2 ? kNumL2BlockKs : kNumL1BlockKs,
m_block_idx, n_block_idx);
}
}
};
} // namespace deep_gemm::sched

View File

@@ -0,0 +1,114 @@
#pragma once
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm::sched {
template <uint32_t kAlignedBatchSize, uint32_t SPLIT_KV, uint32_t kNumSMs>
CUTLASS_GLOBAL __launch_bounds__(32, 1)
void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d,
const uint32_t* context_lens, uint32_t* schedule_metadata) {
DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size");
const uint32_t lane_idx = ptx::get_lane_idx();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
uint32_t num_segs[kAlignedBatchSize / 32];
#pragma unroll
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
const uint32_t q_idx = k * 32 + lane_idx;
const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx);
const uint32_t context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0);
num_segs[k] = math::ceil_div(context_len, SPLIT_KV);
}
__shared__ uint32_t prefix_sum[kAlignedBatchSize];
uint32_t sum = 0;
#pragma unroll
for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) {
uint32_t x = num_segs[k];
#pragma unroll
for (uint32_t offset = 1; offset < 32; offset <<= 1) {
const uint32_t y = __shfl_up_sync(0xffffffff, x, offset);
x += (lane_idx >= offset ? y : 0);
}
x += sum;
prefix_sum[k * 32 + lane_idx] = x;
sum = __shfl_sync(0xffffffff, x, 31);
}
const uint32_t num_next_n_atoms = next_n / ((next_n % 2 == 0) ? 2 : 1);
const uint32_t total = sum * num_next_n_atoms;
const uint32_t q = total / kNumSMs, r = total % kNumSMs;
for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) {
uint32_t seg_starts = sm_idx * q + min(sm_idx, r);
uint32_t q_idx = 0;
while (q_idx < batch_size and prefix_sum[q_idx] * num_next_n_atoms <= seg_starts)
++ q_idx;
const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms);
const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]);
const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0;
const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0;
const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx;
__syncwarp();
schedule_metadata[sm_idx * 2] = q_atom_idx;
schedule_metadata[sm_idx * 2 + 1] = kv_split_idx;
}
}
template <uint32_t kNextN, bool kIsContextLens2D,
uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit,
uint32_t kNumNextNAtoms>
struct PagedMQALogitsScheduler {
const uint32_t* context_lens;
uint32_t current_q_atom_idx, current_kv_idx;
uint32_t end_q_atom_idx, end_kv_idx;
uint32_t current_num_kv;
CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const {
const uint32_t q_idx = q_atom_idx / kNumNextNAtoms;
const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx);
return math::ceil_div(context_lens[lens_idx], BLOCK_KV);
}
CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t* context_lens, const uint32_t* schedule_meta) {
this->context_lens = context_lens;
const auto current_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx];
const auto end_pack = reinterpret_cast<const uint2*>(schedule_meta)[sm_idx + 1];
current_q_atom_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit;
end_q_atom_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit;
current_num_kv = get_num_kv(current_q_atom_idx);
}
CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) {
q_atom_idx = current_q_atom_idx;
kv_idx = current_kv_idx;
num_kv = current_num_kv;
if (current_q_atom_idx == end_q_atom_idx and current_kv_idx == end_kv_idx)
return false;
current_kv_idx += kNumBlocksPerSplit;
if (current_kv_idx >= current_num_kv) {
++ current_q_atom_idx;
current_kv_idx = 0;
if (current_q_atom_idx % kNumNextNAtoms == 0 and exist_q_atom_idx(current_q_atom_idx)) {
current_num_kv = get_num_kv(current_q_atom_idx);
}
}
return true;
}
CUTLASS_DEVICE bool exist_q_atom_idx(const uint32_t& q_atom_idx) const {
return q_atom_idx < end_q_atom_idx or (q_atom_idx == end_q_atom_idx and 0 < end_kv_idx);
}
};
} // namespace deep_gemm::sched

View File

@@ -47,7 +47,7 @@ def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr,
# Compute
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64)
k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
k_mask = k_range < K
a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :]
b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1)

View File

@@ -50,7 +50,7 @@ def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr,
# Compute
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(k_start, k_end, BLOCK_SIZE_K):
k_range = (k + tl.arange(0, BLOCK_SIZE_K)).to(tl.int64)
k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
rows = tl.load(k_indices_ptr + k_range).to(tl.int64)
a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M
b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :]

128
deep_gemm/mega/__init__.py Normal file
View File

@@ -0,0 +1,128 @@
import torch
from typing import Tuple, Optional
from ..utils.math import align
# noinspection PyBroadException
try:
# noinspection PyProtectedMember
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed as dist
except Exception as exception:
print(f'Failed to load mega kernels, please check your PyTorch version: {exception}')
from .. import _C
class SymmBuffer:
def __init__(self, group: dist.ProcessGroup,
# MoE arguments
num_experts: int,
num_max_tokens_per_rank: int, num_topk: int,
hidden: int, intermediate_hidden: int,
use_fp8_dispatch: bool = True,
activation: str = 'swiglu'):
self.group = group
self.num_experts = num_experts
self.num_max_tokens_per_rank = num_max_tokens_per_rank
self.num_topk = num_topk
self.hidden = hidden
self.intermediate_hidden = intermediate_hidden
# Allocate a symmetric buffer
num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe(
group.size(), num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden,
use_fp8_dispatch, activation
)
self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda')
self.handle = symm_mem.rendezvous(self.buffer, group=group)
self.buffer.zero_()
self.group.barrier()
torch.cuda.synchronize()
# Create input buffer views
(self.x, self.x_sf,
self.topk_idx, self.topk_weights,
self.l1_acts, self.l1_acts_sf,
self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer)
def destroy(self):
self.handle = None
self.buffer = None
self.group = None
self.x = None
self.x_sf = None
def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup,
num_experts: int,
num_max_tokens_per_rank: int, num_topk: int,
hidden: int, intermediate_hidden: int,
use_fp8_dispatch: bool = True,
activation: str = 'swiglu') -> SymmBuffer:
# Token count must be aligned to block m
num_ranks = group.size()
block_m = _C.get_block_m_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk)
num_max_tokens_per_rank = align(num_max_tokens_per_rank, block_m)
return SymmBuffer(
group, num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden,
use_fp8_dispatch, activation
)
def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
# [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up]
def interleave(t, gran: int = 8) -> torch.Tensor:
g, n, *rest = t.shape
half = n // 2
gate = t[:, :half].reshape(g, half // gran, gran, *rest)
up = t[:, half:].reshape(g, half // gran, gran, *rest)
return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest))
return interleave(l1_weights[0]), interleave(l1_weights[1])
def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor:
num_groups, mn, packed_sf_k = sf.shape
assert sf.dtype == torch.int and mn % 128 == 0
result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k)
.transpose(2, 3)
.reshape(num_groups, mn, packed_sf_k))
return torch.empty_like(sf).copy_(result)
def transform_weights_for_mega_moe(
l1_weights: Tuple[torch.Tensor, torch.Tensor],
l2_weights: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
# L1: interleave gate/up, then transpose SF for UTCCP
l1_interleaved = _interleave_l1_weights(l1_weights)
l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1]))
# L2: only transpose SF for UTCCP
l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1]))
return l1_weights, l2_weights
def fp8_fp4_mega_moe(y: torch.Tensor,
l1_weights: Tuple[torch.Tensor, torch.Tensor],
l2_weights: Tuple[torch.Tensor, torch.Tensor],
sym_buffer: SymmBuffer,
recipe: Tuple[int, int, int] = (1, 1, 32),
activation: str = 'swiglu',
activation_clamp: Optional[float] = None,
fast_math: bool = True):
_C.fp8_fp4_mega_moe(
y,
l1_weights, l2_weights,
sym_buffer.buffer,
sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(),
sym_buffer.num_max_tokens_per_rank,
sym_buffer.num_experts, sym_buffer.num_topk,
recipe,
activation, activation_clamp,
fast_math
)

View File

@@ -1,6 +1,7 @@
import os
import sys
import torch
from typing import Callable, Optional
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
@@ -78,7 +79,8 @@ class suppress_stdout_stderr:
def bench_kineto(fn, kernel_names, num_tests: int = 30,
suppress_kineto_output: bool = False,
trace_path: str = None, flush_l2: bool = True,
with_multiple_kernels: bool = False):
with_multiple_kernels: bool = False,
barrier: Optional[Callable] = None):
assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
is_tuple = isinstance(kernel_names, tuple)
@@ -96,14 +98,21 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
# Profile
suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
with suppress():
schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
profiler = torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule, acc_events=True)
with profiler:
for i in range(2):
for _ in range(num_tests):
if flush_l2:
torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
if barrier is not None:
# NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead
# noinspection PyProtectedMember
torch.cuda._sleep(int(2e7)) # ~10ms
barrier()
fn()
torch.cuda.synchronize()
profiler.step()
# Parse the profiling table
@@ -111,7 +120,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
if not with_multiple_kernels:
for name in kernel_names:
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}'
# Save chrome traces
if trace_path is not None:

View File

@@ -1,3 +1,4 @@
from . import math, layout
from .layout import *
from .math import *
from .dist import init_dist, uneven_all_gather

74
deep_gemm/utils/dist.py Normal file
View File

@@ -0,0 +1,74 @@
import inspect
import os
import torch
import torch.distributed as dist
from typing import Tuple
_local_rank = None
def init_dist(local_rank: int, num_local_ranks: int) -> Tuple[int, int, dist.ProcessGroup]:
# NOTES: you may rewrite this function with your own cluster settings
ip = os.getenv('MASTER_ADDR', '127.0.0.1')
port = int(os.getenv('MASTER_PORT', '8361'))
num_nodes = int(os.getenv('WORLD_SIZE', 1))
node_rank = int(os.getenv('RANK', 0))
# Set local rank
global _local_rank
_local_rank = local_rank
sig = inspect.signature(dist.init_process_group)
params = {
'backend': 'nccl',
'init_method': f'tcp://{ip}:{port}',
'world_size': num_nodes * num_local_ranks,
'rank': node_rank * num_local_ranks + local_rank,
}
if 'device_id' in sig.parameters:
# noinspection PyTypeChecker
params['device_id'] = torch.device(f'cuda:{local_rank}')
dist.init_process_group(**params)
torch.set_default_device('cuda')
torch.cuda.set_device(local_rank)
return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes)))
def uneven_all_gather(tensor: torch.Tensor, dim: int = 0, group: dist.ProcessGroup = None) -> torch.Tensor:
world_size = dist.get_world_size(group)
# Exchange sizes
local_dim_size = torch.tensor([tensor.shape[dim]], device=tensor.device, dtype=torch.long)
all_dim_sizes = [torch.zeros_like(local_dim_size) for _ in range(world_size)]
dist.all_gather(all_dim_sizes, local_dim_size, group=group)
all_dim_sizes = [s.item() for s in all_dim_sizes]
max_dim_size = max(all_dim_sizes)
# Pad
if tensor.shape[dim] < max_dim_size:
pad_shape = list(tensor.shape)
pad_shape[dim] = max_dim_size - tensor.shape[dim]
padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
tensor_padded = torch.cat([tensor, padding], dim=dim)
else:
tensor_padded = tensor.contiguous()
# All-gather
gathered = [torch.zeros_like(tensor_padded) for _ in range(world_size)]
dist.all_gather(gathered, tensor_padded, group=group)
# Remove padding
trimmed = [
torch.narrow(gathered[i], dim, 0, all_dim_sizes[i])
for i in range(world_size)
]
return torch.cat(trimmed, dim=dim)
def dist_print(s: str = '', once_in_node: bool = False) -> None:
global _local_rank
assert _local_rank is not None
if not once_in_node or _local_rank == 0:
print(s, flush=True)
dist.barrier()

View File

@@ -10,7 +10,11 @@ except ImportError:
pass
# Valid for all CUDA versions
from .._C import get_mk_alignment_for_contiguous_layout
from .._C import (
set_mk_alignment_for_contiguous_layout,
get_mk_alignment_for_contiguous_layout,
get_theoretical_mk_alignment_for_contiguous_layout,
)
# Some alias
get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout

View File

@@ -11,21 +11,30 @@ def align(x: int, y: int) -> int:
def ceil_to_ue8m0(x: torch.Tensor):
assert x.view(-1).amax().item() > 0
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
bits = x.abs().float().view(torch.int)
exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int()
return (exp.clamp(1, 254) << 23).view(torch.float)
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
def pack_ue8m0_to_int(x: torch.Tensor):
assert x.dtype == torch.float and x.size(-1) % 4 == 0
assert (x.view(torch.int) & ((1 << 23) - 1) == 0).all()
return (x.view(torch.int) >> 23).to(torch.uint8).view(torch.int)
def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128,
use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
padded_n = align(n, gran_k)
x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0)
x_padded[:, :n] = x
x_view = x_padded.view(m, -1, gran_k)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
x_view = x_padded.view(m, padded_n // gran_k, gran_k)
x_amax = x_view.abs().float().amax(dim=2).view(m, padded_n // gran_k).clamp(1e-4)
sf = x_amax / 448.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf
x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous()
return x_fp8, pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf
def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -70,13 +79,14 @@ def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
code = idx.to(torch.uint8)
sign = (x < 0) & (idx != 0)
code = code | (sign.to(torch.uint8) << 3)
return code # uint8, 0..15
return code.view(torch.int8)
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128,
use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
m, n = x.shape
assert n % 2 == 0
assert not use_packed_ue8m0 or use_ue8m0
padded_n = align(n, gran_k)
x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device)
x_padded[:, :n] = x
@@ -85,23 +95,49 @@ def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -
sf = x_amax / 6.0
sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
x_scaled = x_view * (1.0 / sf.unsqueeze(2))
codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n)
codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # int8, (m, padded_n)
codes2 = codes.view(m, padded_n // 2, 2)
packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8
return packed[:, :n // 2].contiguous(), sf
packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # int8
return packed[:, :n // 2].contiguous(), pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf
def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor:
assert a.dtype == torch.uint8
assert a.dtype == torch.int8
assert a.dim() == 2
m, n2 = a.shape
n = n2 * 2
assert (m % 2) == 0
lo = a & 0x0F
hi = (a >> 4) & 0x0F
codes = torch.empty((m, n), device=a.device, dtype=torch.uint8)
codes = torch.empty((m, n), device=a.device, dtype=torch.int8)
codes[:, 0::2], codes[:, 1::2] = lo, hi
codes_t = codes.transpose(0, 1).contiguous()
codes2 = codes_t.view(n, m // 2, 2)
out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4)
return out.contiguous()
return out.contiguous()
def _dequantize_from_fp4_e2m1(x: torch.Tensor) -> torch.Tensor:
fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=torch.float)
sign, value_idx = (x & 0x08) != 0, (x & 0x07).to(torch.int)
value = fp4_values[value_idx]
return torch.where(sign & (value_idx != 0), -value, value)
def unpack_ue8m0_from_int(packed_sf: torch.Tensor) -> torch.Tensor:
return (packed_sf.view(torch.uint8).to(torch.int) << 23).view(torch.float)
def cast_back_from_fp4(packed: torch.Tensor, sf: torch.Tensor, gran_k: int = 128,
use_packed_ue8m0: bool = False) -> torch.Tensor:
m, n2 = packed.shape
n = n2 * 2
if use_packed_ue8m0:
sf = unpack_ue8m0_from_int(sf)
unpacked = torch.zeros((m, n), dtype=torch.int8, device=packed.device)
unpacked[:, ::2] = packed & 0x0F
unpacked[:, 1::2] = (packed >> 4) & 0x0F
x_dequantized = _dequantize_from_fp4_e2m1(unpacked)
group_idx = torch.arange(n, device=packed.device) // gran_k
x_restored = x_dequantized * sf[:, group_idx]
return x_restored

View File

@@ -68,7 +68,7 @@ def get_package_version():
cmd = ['git', 'rev-parse', '--short', 'HEAD']
revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip()
except (subprocess.CalledProcessError, FileNotFoundError, OSError):
except Exception:
revision = '+local'
return f'{public_version}{revision}'
@@ -172,6 +172,7 @@ class CachedWheelsCommand(_bdist_wheel):
wheel_url, wheel_filename = get_wheel_url()
print(f'Try to download wheel from URL: {wheel_url}')
# noinspection PyBroadException
try:
with urllib.request.urlopen(wheel_url, timeout=1) as response:
with open(wheel_filename, 'wb') as out_file:

View File

@@ -8,7 +8,8 @@ from deep_gemm.utils import (
align, ceil_div,
per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8,
per_token_cast_to_fp4, transpose_packed_fp4,
get_mk_alignment_for_contiguous_layout
get_mk_alignment_for_contiguous_layout,
set_mk_alignment_for_contiguous_layout
)
@@ -107,7 +108,7 @@ def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator:
def get_psum_layout_usage() -> tuple:
return (False, True) if get_arch_major() == 10 else (False, )
return True, False
def enumerate_normal(dtype: torch.dtype) -> Generator:
@@ -168,7 +169,7 @@ def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator:
def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
quant_config_list = QuantConfig.get_list_from_dtype(dtype)
max_m = 4096
m_group_list = [(6, 1024), (32, 192), (32, 50)]
m_group_list = [(32, 192), (6, 1024), (32, 20), (6, 20)]
n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)]
for kernel_type in get_kernel_types(dtype):
for quant_config in quant_config_list:
@@ -182,6 +183,7 @@ def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator:
def enumerate_k_grouped_contiguous(dtype: torch.dtype):
gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128)
# Only K-major is supported for SM90 FP8
major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 and dtype == torch.float8_e4m3fn \
else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
@@ -189,26 +191,36 @@ def enumerate_k_grouped_contiguous(dtype: torch.dtype):
for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64
( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32
(16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16
ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group
if dtype == torch.bfloat16:
ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)]
yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group
else:
for gran_k in gran_k_list:
set_mk_alignment_for_contiguous_layout(gran_k)
ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), gran_k) for _ in range(num_groups)]
yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group, gran_k
def enumerate_sf_layout():
gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128)
for use_ue8m0 in (False, True):
for with_transpose in (True, False):
for mn in (4096, 4097, 8192):
for k in (128, 7168, 7296):
for num_groups in (1, 2, 4):
yield mn, k, with_transpose, use_ue8m0, num_groups
for gran_k in gran_k_list:
set_mk_alignment_for_contiguous_layout(gran_k)
yield mn, k, with_transpose, use_ue8m0, num_groups, gran_k
def enumerate_k_grouped_sf_layout():
alignment = get_mk_alignment_for_contiguous_layout()
assert alignment % 128 == 0
gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128)
for mn in (4096, 7168):
for num_groups, avg_k in ((16, 2048), (8, 4096), (72, 384), (128, 256)):
ks = [align(int(random.uniform(0.7, 1.3) * avg_k), alignment) for _ in range(num_groups)]
yield mn, ks, num_groups
for gran_k in gran_k_list:
set_mk_alignment_for_contiguous_layout(gran_k)
ks = [align(int(random.uniform(0.7, 1.3) * avg_k), gran_k) for _ in range(num_groups)]
yield mn, ks, num_groups, gran_k
def enumerate_transpose():
@@ -222,25 +234,24 @@ def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is
use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
if is_fp4:
x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
x = x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1])
return x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1])
else:
x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \
else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k)
x = x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1])
return x
return x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1])
def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool,
use_ue8m0: bool, use_block_cast_for_fp8: bool = False):
num_groups, mn, k = x.size()
if is_fp4:
x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.uint8) if major.is_k_major() else \
torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.uint8),
x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.int8) if major.is_k_major() else \
torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.int8),
torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float))
for i in range(num_groups):
x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1])
x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1])
return x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1])
else:
x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn),
torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8 \
@@ -248,8 +259,7 @@ def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k:
for i in range(num_groups):
x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \
else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k)
x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1])
return x
return x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1])
def generate_normal(m: int, n: int, k: int,
@@ -325,7 +335,7 @@ def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor):
last_psum_m = 0
for i in range(num_groups):
x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m]
last_psum_m = align(psum_m[i], 128)
last_psum_m = align(psum_m[i], get_mk_alignment_for_contiguous_layout())
return x_psum
@@ -342,7 +352,7 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int)
for j in range(num_groups):
masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j]
psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout())) + masked_m[j]
assert masked_m.amax().item() <= max_m
if use_bf16:
@@ -356,8 +366,8 @@ def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group:
def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int],
use_ue8m0: bool = False, use_bf16: bool = False):
assert get_mk_alignment_for_contiguous_layout() % 128 == 0
use_ue8m0: bool = False, use_bf16: bool = False, gran_k = 128):
assert get_mk_alignment_for_contiguous_layout() % gran_k == 0
k = sum(ks)
a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16)
@@ -376,8 +386,8 @@ def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: Majo
assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor)
return k, a, b, c, d, ref_d
a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0)
b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0)
a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0, gran_k=gran_k)
b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0, gran_k=gran_k)
# Transpose for K Major A/B
if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor):

View File

@@ -10,9 +10,9 @@ from deep_gemm.testing import (
ignore_env, get_arch_major,
test_filter
)
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8
from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8, per_token_cast_to_fp4, cast_back_from_fp4
from generators import generate_normal, get_ue8m0_usage, get_kernel_types, MajorTypeAB
from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, reset_seed, MajorTypeAB
def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]):
@@ -53,40 +53,14 @@ def test_gemm_skip_head_mid() -> None:
assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}'
t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast),
'fp8_gemm', suppress_kineto_output=True)
'gemm_', suppress_kineto_output=True)
print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): '
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s')
f'{t * 1e6:4.0f} us | '
f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | '
f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s')
print()
def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor:
num_blocks, block_size, num_heads, head_dim = x.shape
assert num_heads == 1
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8)
x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(dtype=torch.uint8)
x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(dtype=torch.uint8)
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4)
def generate_cp_test_data(seq_len, seq_len_kv):
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
chunk_size = seq_len // 2
cp_size = seq_len_kv // seq_len
# Select an arbitrary CP rank
cp_id = cp_size // 3
ks = torch.zeros(seq_len, dtype=torch.int, device='cuda')
ke = torch.zeros(seq_len, dtype=torch.int, device='cuda')
for i in range(chunk_size):
ke[i] = cp_id * chunk_size + i
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
return ks, ke
def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False):
seq_len_kv = kv.shape[0]
@@ -113,92 +87,137 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
return logits, cost
@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 10)
def test_mqa_logits():
# Helper functions
def generate_ks_ke_tests(seq_len: int, seq_len_kv: int, disable_cp: bool):
if disable_cp:
ks = torch.zeros(seq_len, dtype=torch.int, device='cuda')
ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len)
return ks, ke
assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0
chunk_size = seq_len // 2
cp_size = seq_len_kv // seq_len
# Select an arbitrary CP rank
cp_id = cp_size // 3
ks = torch.zeros(seq_len, dtype=torch.int, device='cuda')
ke = torch.zeros(seq_len, dtype=torch.int, device='cuda')
for i in range(chunk_size):
ke[i] = cp_id * chunk_size + i
ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i
return ks, ke
def enumerate_mqa_logits():
for is_fp4 in ((True, False) if get_arch_major() == 10 else (False, )):
for logits_dtype in (torch.float, torch.bfloat16):
for compressed_logits, clean_logits in [(False, True), (True, False)]:
for seq_len in (2048, 4096):
for seq_len_kv in (4096, 8192):
for num_heads, head_dim in [(64, 128)]:
for disable_cp in (False, True):
yield is_fp4, logits_dtype, compressed_logits, clean_logits, seq_len, seq_len_kv, num_heads, head_dim, disable_cp
print('Testing FP8 MQA Logits:')
num_heads, head_dim = 64, 128
for seq_len in (2048, 4096):
for compressed_logits in (False, True):
for seq_len_kv in (4096, 8192):
for disable_cp in (False, True):
q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16)
kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16)
weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32)
for is_fp4, logits_dtype, compressed_logits, clean_logits, seq_len, seq_len_kv, num_heads, head_dim, disable_cp in enumerate_mqa_logits():
# Generate random inputs
q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16)
kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16)
weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32)
ks, ke = generate_ks_ke_tests(seq_len, seq_len_kv, disable_cp)
if disable_cp:
ks = torch.zeros(seq_len, dtype=torch.int, device='cuda')
ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len)
else:
ks, ke = generate_cp_test_data(seq_len, seq_len_kv)
# Calculate reference logits
ref_logits, ref_cost = ref_fp8_mqa_logits(q, kv, weights, ks, ke)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0, ), False)
# Quantize Q and KV to FP4 / FP8
if is_fp4:
q_fp4 = per_token_cast_to_fp4(q.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
q_in = (q_fp4[0].view(seq_len, num_heads, head_dim // 2), q_fp4[1].view(seq_len, num_heads))
q_simulated = cast_back_from_fp4(q_fp4[0], q_fp4[1], gran_k=32, use_packed_ue8m0=True).view(seq_len, num_heads, head_dim).to(torch.bfloat16)
if compressed_logits:
max_seqlen_k = (ke - ks).max().item()
logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False)
assert logits.size() == (seq_len, max_seqlen_k)
tmp = torch.full((seq_len, seq_len_kv), float('-inf'), device='cuda')
for i in range(seq_len):
tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
logits = tmp
else:
logits = deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke)
kv_fp4 = per_token_cast_to_fp4(kv.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
kv_in = (kv_fp4[0].view(seq_len_kv, head_dim // 2), kv_fp4[1].view(seq_len_kv))
kv_simulated = cast_back_from_fp4(kv_fp4[0], kv_fp4[1], gran_k=32, use_packed_ue8m0=True).view(seq_len_kv, head_dim).to(torch.bfloat16)
else:
q_in = q.to(torch.float8_e4m3fn), None
q_simulated = q_in[0].to(torch.bfloat16)
kv_in = per_custom_dims_cast_to_fp8(kv, (0, ), False)
kv_simulated = (kv_in[0].float() * kv_in[1].unsqueeze(1)).to(torch.bfloat16)
do_check = (seq_len_kv < 32768)
if do_check:
ref_logits, ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
# Calculate reference logits
simulated_logits, _ = ref_fp8_mqa_logits(q_simulated, kv_simulated, weights, ks, ke)
ref_neginf_mask = (ref_logits == float('-inf'))
neginf_mask = (logits == float('-inf'))
assert torch.equal(neginf_mask, ref_neginf_mask)
# Prepare kwargs
kernel_kwargs = dict(
q=q_in, kv=kv_in, weights=weights,
cu_seq_len_k_start=ks, cu_seq_len_k_end=ke,
clean_logits=clean_logits, max_seqlen_k=0,
logits_dtype=logits_dtype
)
if compressed_logits:
max_seqlen_k = (ke - ks).max().item()
kernel_kwargs['max_seqlen_k'] = max_seqlen_k
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f'{diff=}'
else:
ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True)
# Run kernel
logits = deep_gemm.fp8_fp4_mqa_logits(**kernel_kwargs)
tflops = 2 * ref_cost * num_heads * head_dim / 1e12
if compressed_logits:
t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke, max_seqlen_k=max_seqlen_k, clean_logits=False), 'fp8_mqa_logits')
else:
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke), ('fp8_mqa_logits', 'clean_logits'))
clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke)
print(f' > S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: '
f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, '
f'{(count_bytes(q_fp8, kv_fp8, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s', end='')
# noinspection PyUnboundLocalVariable
print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not compressed_logits else '')
# Post process for compressed logits
if compressed_logits:
assert logits.size() == (seq_len, max_seqlen_k)
tmp = torch.full((seq_len, seq_len_kv), float('-inf'), device='cuda')
for i in range(seq_len):
tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]]
logits = tmp
# Validation
ref_neginf_mask = (ref_logits == float('-inf'))
neginf_mask = (logits == float('-inf'))
assert torch.equal(neginf_mask, ref_neginf_mask)
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
simulated_logits = simulated_logits.masked_fill(ref_neginf_mask, 0)
logits = logits.masked_fill(ref_neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
simulated_diff = calc_diff(logits, simulated_logits)
assert diff < 0.02 if is_fp4 else 1e-3, f"Diff: {diff}"
assert simulated_diff < 5e-6, f"Simulated Diff: {simulated_diff}"
# Profiling
tflops = 2 * ref_cost * num_heads * head_dim / 1e12
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_mqa_logits(**kernel_kwargs), ('mqa_logits', 'clean_logits'))
clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke)
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: '
f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, '
f'{(count_bytes(q_in, kv_in, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s', end='')
print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if clean_logits else '')
print()
def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor,
max_model_len: int, is_context_lens_2d: bool):
batch_size, next_n, heads, dim = q.size()
def ref_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor,
max_model_len: int, use_2d_context_lens: bool):
batch_size, next_n, num_heads, dim = q.size()
num_block, block_size, _, dim = kv_cache.size()
logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32)
context_lens = context_lens.tolist()
for i in range(batch_size):
context_len = context_lens[i]
q_offsets = torch.full((next_n, ), context_len, device='cuda', dtype=torch.int32) if is_context_lens_2d \
else torch.arange(context_len - next_n, context_len, device='cuda')
q_offsets = torch.full((next_n, ), context_len, device='cuda', dtype=torch.int32) if use_2d_context_lens \
else torch.arange(context_len - next_n, context_len, device='cuda')
weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous()
num_blocks = (context_len + block_size - 1) // block_size
block_idxs = block_tables[i][:num_blocks]
kv_slice = kv_cache[block_idxs] # [num_blocks, block_size, kv_heads, dim]
kx = kv_slice.permute(2, 3, 0, 1).reshape(kv_slice.size(2), dim, -1) # [kv_heads, dim, total_tokens]
qx = q[i].transpose(0, 1) # q[i]: [next_n, heads, dim] -> [heads, next_n, dim]
s = torch.matmul(qx, kx).to(logits.dtype) # [heads, next_n, dim] @ [1, dim, total_tokens] -> [heads, next_n, total_tokens]
qx = q[i].transpose(0, 1) # q[i]: [next_n, num_heads, dim] -> [num_heads, next_n, dim]
s = torch.matmul(qx, kx).to(logits.dtype) # [num_heads, next_n, dim] @ [1, dim, total_tokens] -> [num_heads, next_n, total_tokens]
total_len = num_blocks * block_size
k_offsets = torch.arange(0, total_len, device=q.device)
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None])
s = torch.where(mask[None, :, :], s, float('-inf')) # mask shape: [1, next_n, total_tokens]
s = torch.relu(s) * weight_slice[..., None] # weight_slice: [heads, next_n] -> [heads, next_n, 1]
s = torch.relu(s) * weight_slice[..., None] # weight_slice: [num_heads, next_n] -> [num_heads, next_n, 1]
s = s.sum(dim=0) # [next_n, total_tokens]
logits[i * next_n:(i + 1) * next_n, :total_len] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf'))
@@ -206,70 +225,129 @@ def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor,
def test_paged_mqa_logits():
print('Testing FP8 Paged MQA Logits:')
max_model_len = 111 * 1000
for is_context_lens_2d in (False, True):
for batch_size, next_n in [(64, 1), (64, 2), (128, 1)]:
for heads, index_dim in [(64, 128)]:
for avg_kv in (8192, 32768):
num_blocks, blocksize = max_model_len * 3, 64
q = torch.randn((batch_size, next_n, heads, index_dim), device='cuda', dtype=torch.bfloat16)
kv_cache = torch.randn((num_blocks, blocksize, 1, index_dim), device='cuda', dtype=torch.bfloat16)
weights = torch.randn((batch_size * next_n, heads), device='cuda', dtype=torch.float32)
q_fp8 = q.to(torch.float8_e4m3fn)
kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache)
# Helper functions
def kv_cache_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
num_blocks, block_size, num_heads, head_dim = x.shape
assert num_heads == 1
x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4)
sf = x_amax / 448.0
x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
x_cast_back = x_scaled.float() * sf
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size, )).cuda().to(torch.int32)
context_lens_list = context_lens.tolist()
max_block_len = (max(context_lens_list) + blocksize - 1) // blocksize * blocksize
block_tables = torch.zeros((batch_size, max_block_len), device='cuda', dtype=torch.int32)
x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8)
x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(torch.uint8)
x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(torch.uint8)
return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4), x_cast_back.to(x.dtype)
counter, block_idx_pool = 0, torch.randperm(num_blocks, device='cuda', dtype=torch.int32)
for i in range(batch_size):
num_blocks = ceil_div(context_lens_list[i], blocksize)
block_tables[i][:num_blocks] = block_idx_pool[counter: counter+num_blocks]
counter += num_blocks
def kv_cache_cast_to_fp4(x: torch.Tensor) -> torch.Tensor:
num_blocks, block_size, num_heads, head_dim = x.shape
assert num_heads == 1 and head_dim == 128
x_scaled, sf = per_token_cast_to_fp4(x.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
x_cast_back = cast_back_from_fp4(x_scaled, sf, gran_k=32, use_packed_ue8m0=True).view(num_blocks, block_size, 1, head_dim)
ref_logits = ref_fp8_paged_mqa_logits(q, kv_cache, weights, context_lens, block_tables, max_model_len, is_context_lens_2d)
positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1)
x_fp4 = torch.empty((num_blocks, block_size * (head_dim // 2 + 4)), device=x.device, dtype=torch.uint8)
x_fp4[ :, : block_size * head_dim // 2] = x_scaled.view(num_blocks, block_size * head_dim // 2).view(torch.uint8)
x_fp4[ :, block_size * head_dim // 2 :] = sf.view(num_blocks, block_size).view(torch.uint8)
return x_fp4.view(num_blocks, block_size, num_heads, head_dim // 2 + 4), x_cast_back.to(x.dtype)
if is_context_lens_2d:
context_lens_2d = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
context_lens_2d[:, next_n-1] = context_lens
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens_2d, blocksize, deep_gemm.get_num_sms())
logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False)
ref_neginf_mask = ~(positions < context_lens_2d.view(-1).unsqueeze(1))
else:
schedule_metadata = deep_gemm.get_paged_mqa_logits_metadata(context_lens, blocksize, deep_gemm.get_num_sms())
logits = deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True)
row_indices = torch.arange(batch_size * next_n, device='cuda') // next_n
next_n_offset = torch.arange(batch_size * next_n, device='cuda') % next_n
ref_neginf_mask = ~(positions <= (context_lens[row_indices] - next_n + next_n_offset).unsqueeze(1))
neginf_mask = (logits == float('-inf'))
assert torch.equal(neginf_mask, ref_neginf_mask)
def enumerate_paged_mqa_logits():
arch_major = get_arch_major()
for is_fp4 in ((True, False) if arch_major == 10 else (False, )):
for logits_dtype in (torch.float, torch.bfloat16):
for block_kv in ((32, 64) if arch_major == 10 else (64, )):
for use_2d_context_lens, clean_logits in [(True, False)]:
for batch_size in (256, ):
for next_n in (1, 2, 4, 5, 6) if arch_major == 10 else (1, 2):
for num_heads, head_dim in [(64, 128)]:
for avg_kv in (8192, 32768):
yield is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv
logits = logits.masked_fill(ref_neginf_mask, 0)
ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0)
diff = calc_diff(logits, ref_logits)
assert diff < 1e-3, f"{diff=}"
sum_lens = sum(context_lens.to(torch.int64))
tflops = 2 * sum_lens * next_n * heads * index_dim / 1e12
input_bytes = count_bytes(q_fp8, weights, context_lens) + sum_lens * (index_dim + 4) + (sum_lens / blocksize) * 4
output_bytes = sum_lens * next_n * 4
if is_context_lens_2d:
t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens_2d, block_tables, schedule_metadata, max_model_len, clean_logits=False),
'fp8_paged_mqa_logits')
else:
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True),
('fp8_paged_mqa_logits', 'clean_logits'))
clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens)
print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: '
f'{tflops / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, '
f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s', end='')
# noinspection PyUnboundLocalVariable
print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if not is_context_lens_2d else '')
print('Testing FP8/FP4 Paged MQA Logits:')
max_model_len = 111 * 1024
num_total_blocks = max_model_len * 5
for is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits():
# Generate random inputs
q = torch.randn((batch_size, next_n, num_heads, head_dim), device='cuda', dtype=torch.bfloat16)
kv_cache = torch.randn((num_total_blocks, block_kv, 1, head_dim), device='cuda', dtype=torch.bfloat16)
weights = torch.randn((batch_size * next_n, num_heads), device='cuda', dtype=torch.float)
context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (batch_size,), device='cuda', dtype=torch.int)
# Assign block tables
num_blocks_per_query = ceil_div(context_lens, block_kv)
block_table = torch.empty((batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int)
block_idx_pool = torch.randperm(num_total_blocks, device='cuda', dtype=torch.int)
offset = 0
for i, num_blocks in enumerate(num_blocks_per_query.tolist()):
block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks]
offset += num_blocks
# Calculate reference logits
ref_logits = ref_paged_mqa_logits(q, kv_cache, weights, context_lens, block_table, max_model_len, use_2d_context_lens)
# Quantize Q and KV cache to FP4 / FP8
if is_fp4:
q_fp4 = per_token_cast_to_fp4(q.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True)
q_in = (q_fp4[0].view(batch_size, next_n, num_heads, head_dim // 2), q_fp4[1].view(batch_size, next_n, num_heads))
q_simulated = cast_back_from_fp4(q_fp4[0], q_fp4[1], gran_k=32, use_packed_ue8m0=True).view(batch_size, next_n, num_heads, head_dim).to(torch.bfloat16)
kv_in, kv_simulated = kv_cache_cast_to_fp4(kv_cache)
else:
q_in = q.to(torch.float8_e4m3fn), None
q_simulated = q_in[0].to(torch.bfloat16)
kv_in, kv_simulated = kv_cache_cast_to_fp8(kv_cache)
# Calculate simulated reference logits
simulated_logits = ref_paged_mqa_logits(q_simulated, kv_simulated, weights, context_lens, block_table, max_model_len, use_2d_context_lens)
# Prepare masks and context lengths with NextN
positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1)
if use_2d_context_lens:
context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int()
# Ensure last token matches actual length
context_lens_nextn[:, -1] = context_lens
ref_neginf_mask = ~(positions < context_lens_nextn.view(-1, 1))
else:
context_lens_nextn = context_lens
offsets = torch.arange(batch_size * next_n, device='cuda')
limits = (context_lens[offsets // next_n] - next_n + offsets % next_n).unsqueeze(1)
ref_neginf_mask = ~(positions <= limits)
# Run Kernel
kernel_kwargs = dict(
q=q_in, kv_cache=kv_in, weights=weights,
context_lens=context_lens_nextn, block_table=block_table,
schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms()),
max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype
)
logits = deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs)
# Validation
assert logits.dtype == logits_dtype
logits = logits.to(torch.float)
if clean_logits:
assert torch.equal(logits == float('-inf'), ref_neginf_mask), "Mask mismatch"
logits_masked = logits.masked_fill(ref_neginf_mask, 0)
ref_masked = ref_logits.masked_fill(ref_neginf_mask, 0)
simulated_masked = simulated_logits.masked_fill(ref_neginf_mask, 0)
diff = calc_diff(logits_masked, ref_masked)
simulated_diff = calc_diff(logits_masked, simulated_masked)
assert diff < 0.02 if is_fp4 else 1e-3, f"Diff: {diff}"
assert simulated_diff < 5e-6, f"Simulated Diff: {simulated_diff}"
# Profiling
sum_lens = context_lens.sum().item()
tflops_calc = 2 * sum_lens * next_n * num_heads * head_dim / 1e12
kv_bytes_per_token = head_dim / (2 if is_fp4 else 1) + 4
total_bytes = count_bytes(q, weights) + sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize)
t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs), ('paged_mqa_logits', 'clean_logits'))
print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={batch_size:3}, NextN={next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: '
f'{tflops_calc / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, {total_bytes / t / 1e9:4.0f} GB/s', end='')
print(f' | clean: {clean_t*1e6:3.0f} us' if clean_logits else '')
print()
@@ -280,6 +358,5 @@ if __name__ == '__main__':
random.seed(0)
test_gemm_skip_head_mid()
test_mqa_logits()
test_paged_mqa_logits()

View File

@@ -11,7 +11,8 @@ from deep_gemm.testing import (
from generators import (
get_arch_major, layout_masked_to_psum, align,
enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous,
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous
generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous,
get_mk_alignment_for_contiguous_layout
)
@@ -56,6 +57,10 @@ def test_m_grouped_gemm_contiguous() -> None:
major_opt = 'N' if major_a.is_k_major() else 'T'
major_opt += 'T' if major_b.is_k_major() else 'N'
# Select best alignment
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
for test_alias in (False, True):
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
use_bf16=True, use_psum_layout=use_psum_layout)
@@ -65,8 +70,15 @@ def test_m_grouped_gemm_contiguous() -> None:
b = b if major_b.is_k_major() else b.mT
assert a[0].is_contiguous() and b[0].is_contiguous()
getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout)
diff = calc_diff(d, ref_d)
assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
if use_psum_layout:
for j in range(num_groups):
start = 0 if j == 0 else align(grouped_layout[j - 1], get_mk_alignment_for_contiguous_layout())
end = grouped_layout[j]
diff = calc_diff(d[start : end], ref_d[start : end])
assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
else:
diff = calc_diff(d, ref_d)
assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}'
m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b,
use_bf16=True, use_psum_layout=use_psum_layout)
@@ -91,6 +103,10 @@ def test_m_grouped_gemm_masked() -> None:
sum_t, max_t = 0, 0
sum_ops, sum_bytes = 0, 0
# Select best alignment
alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout(int(expected_m_per_group * 1.2))
deep_gemm.set_mk_alignment_for_contiguous_layout(alignment)
for i in range(num_tests):
a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k,
use_bf16=True, use_psum_layout=use_psum_layout)
@@ -111,7 +127,7 @@ def test_m_grouped_gemm_masked() -> None:
if masked_m[j].item() == 0:
continue
if use_psum_layout:
d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]]
d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout()): psum_m[j]]
else:
d_slice = d[j, :masked_m[j].item()]
diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()])
@@ -137,6 +153,9 @@ def test_m_grouped_gemm_masked() -> None:
def test_k_grouped_gemm_contiguous() -> None:
print('Testing k-grouped contiguous GEMM:')
# TODO: Support arbitrary alignment
deep_gemm.set_mk_alignment_for_contiguous_layout(128)
for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16):
for test_empty_groups in (False, True):

Some files were not shown because too many files have changed in this diff Show More