Embed DeepGEMM source (not submodule) for SM100 raw CUDA GEMM primitives

This commit is contained in:
2026-06-01 07:39:40 +00:00
parent dae83723a3
commit e3ea609ddd
145 changed files with 27360 additions and 1 deletions

Submodule third_party/DeepGEMM deleted from 714dd1a4a9

View File

@@ -0,0 +1,227 @@
name: ~Build wheel template
on:
workflow_call:
inputs:
runs-on:
description: "The runner to use for the build"
required: true
type: string
python-version:
description: "The Python version to use for the build"
required: true
type: string
cuda-version:
description: "The CUDA version to use for the build"
required: true
type: string
torch-version:
description: "The PyTorch version to use for the build"
required: true
type: string
cxx11_abi:
description: "The C++11 ABI to use for the build"
required: true
type: string
upload-to-release:
description: "Whether to upload the built wheel to the release"
required: false
type: boolean
default: false
release-version:
description: "Release tag to check out, build, and upload the wheel to"
required: false
type: string
use-local-version:
description: "Use local version"
required: false
type: boolean
default: false
defaults:
run:
shell: bash -x -e -u -o pipefail {0}
jobs:
build-wheel:
runs-on: ${{ inputs.runs-on }}
name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }})
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ref: ${{ inputs.release-version }}
submodules: recursive
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: ${{ inputs.python-version }}
- name: Set CUDA and PyTorch versions
run: |
echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV
echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
# https://github.com/easimon/maximize-build-space/blob/master/action.yml
# https://github.com/easimon/maximize-build-space/tree/test-report
run: |
sudo rm -rf /usr/share/dotnet
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
- name: Set up swap space
if: runner.os == 'Linux'
uses: pierotofy/set-swap-space@v1.0
with:
swap-size-gb: 10
- name: Install CUDA ${{ inputs.cuda-version }}
if: ${{ inputs.cuda-version != 'cpu' }}
uses: Jimver/cuda-toolkit@v0.2.26
id: cuda-toolkit
with:
cuda: ${{ inputs.cuda-version }}
linux-local-args: '["--toolkit"]'
# default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1
# method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }}
method: "network"
- name: Install additional CUDA libraries
run: |
CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 "-" $2'})
sudo apt-get update
sudo apt-get install -y libcusparse-$CUDA_VERSION libcusolver-$CUDA_VERSION
sudo apt-get clean
- name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }}
run: |
pip install --upgrade pip
# With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error
# AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable
pip install typing-extensions==4.12.2
# We want to figure out the CUDA version to download pytorch
# e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# This code is ugly, maybe there's a better way to do this.
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \
maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
)
if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then
# pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
# Can't use --no-deps because we need cudnn etc.
# Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001
pip install jinja2
pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
else
pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
fi
nvcc --version
python --version
python -c "import torch; print('PyTorch:', torch.__version__)"
python -c "import torch; print('CUDA:', torch.version.cuda)"
python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)"
- name: Restore build cache
uses: actions/cache/restore@v4
with:
path: build.tar
key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
restore-keys: |
build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-
- name: Unpack build cache
run: |
echo ::group::Adjust timestamps
sudo find / -exec touch -t 197001010000 {} + || true
echo ::endgroup::
if [ -f build.tar ]; then
find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} +
tar -xpvf build.tar -C .
else
echo "No build.tar found, skipping"
fi
ls -al ./
ls -al build/ || true
ls -al csrc/ || true
- name: Build wheel
id: build_wheel
run: |
# We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810
# However this still fails so I'm using a newer version of setuptools
pip install setuptools==75.8.0
pip install ninja packaging wheel
export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH
export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH
# Limit MAX_JOBS otherwise the github runner goes OOM
# nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM
export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2)
export NVCC_THREADS=2
export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX"
export DG_USE_LOCAL_VERSION=${{ inputs.use-local-version && '1' || '0' }}
# 5h timeout since GH allows max 6h and we want some buffer
EXIT_CODE=0
timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$?
if [ $EXIT_CODE -eq 0 ]; then
tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }}
wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2")
ls dist/*whl |xargs -I {} mv {} dist/${wheel_name}
echo "wheel_name=${wheel_name}" >> $GITHUB_ENV
fi
# Store exit code as a step output for later steps
echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT"
# Propagate the build's exit code; the cache-saving steps below still run via `always()` when it is 124 (timeout)
exit $EXIT_CODE
- name: Log build logs after timeout
if: always() && steps.build_wheel.outputs.build_exit_code == 124
run: |
ls -al ./
tar -cvf build.tar . --atime-preserve=replace
- name: Save build cache timeout
if: always() && steps.build_wheel.outputs.build_exit_code == 124
uses: actions/cache/save@v4
with:
key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }}
path: build.tar
- name: Log Built Wheels
run: |
ls dist
# Look up the release for the given tag; its `upload_url` output is consumed by the upload step.
# Only needed when uploading: `release-version` is optional, so skip the lookup for build-only runs.
- name: Get Release with tag
  id: get_current_release
  if: inputs.upload-to-release
  uses: joutvhu/get-release@v1
  with:
    tag_name: ${{ inputs.release-version }}
  env:
    GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload Release Asset
id: upload_release_asset
if: inputs.upload-to-release
uses: actions/upload-release-asset@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
upload_url: ${{ steps.get_current_release.outputs.upload_url }}
asset_path: ./dist/${{env.wheel_name}}
asset_name: ${{env.wheel_name}}
asset_content_type: application/*

View File

@@ -0,0 +1,53 @@
name: Build wheels
on:
workflow_dispatch:
inputs:
runs-on:
description: "The runner to use for the build"
required: true
type: string
default: ubuntu-22.04
python-version:
description: "The Python version to use for the build"
required: true
type: string
cuda-version:
description: "The CUDA version to use for the build"
required: true
type: string
torch-version:
description: "The PyTorch version to use for the build"
required: true
type: string
cxx11_abi:
description: "Enable torch flag C++11 ABI (TRUE/FALSE)"
required: true
type: string
upload-to-release:
description: "Whether to upload the built wheel to the release"
required: false
type: boolean
default: false
release-version:
description: "Release tag to check out, build, and upload the wheel to"
required: false
type: string
use-local-version:
description: "Use local version"
required: false
type: boolean
default: false
jobs:
build-wheels:
uses: ./.github/workflows/_build.yml
with:
runs-on: ${{ inputs.runs-on }}
python-version: ${{ inputs.python-version }}
cuda-version: ${{ inputs.cuda-version }}
torch-version: ${{ inputs.torch-version }}
cxx11_abi: ${{ inputs.cxx11_abi }}
upload-to-release: ${{ inputs.upload-to-release }}
release-version: ${{ inputs.release-version }}
use-local-version: ${{ inputs.use-local-version }}

View File

@@ -0,0 +1,95 @@
# This workflow will:
# - Create a new Github release
# - Build wheels for supported architectures
# - Deploy the wheels to the Github release
# - Release the static code to PyPi
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
name: Build wheels and deploy
# Run only when a version tag (v*) is pushed.
# NOTE: the `create` event does not support `tags` filters, so it can also fire on
# branch creation; `push` with a `tags` filter is the supported way to match tags.
on:
  push:
    tags:
      - v*
jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
outputs:
release-version: ${{ steps.extract_branch.outputs.branch }}
steps:
# Strip the `refs/tags/` prefix from the ref and expose the result as the `branch` step output.
# Uses the `$GITHUB_OUTPUT` file instead of the deprecated `::set-output` workflow command.
- name: Get the tag version
  id: extract_branch
  run: echo "branch=${GITHUB_REF#refs/tags/}" >> "$GITHUB_OUTPUT"
  shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
build_wheels:
name: Build Wheel
needs: setup_release
strategy:
fail-fast: false
matrix:
# Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the
# manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
os: [ubuntu-22.04]
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"]
cuda-version: ["12.9.1"]
# We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
# Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
# Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs)
# when building without C++11 ABI and using it on nvcr images.
cxx11_abi: ["FALSE", "TRUE"]
exclude:
# see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix
# Pytorch < 2.5 does not support Python 3.13
- torch-version: "2.4.0"
python-version: "3.13"
uses: ./.github/workflows/_build.yml
with:
runs-on: ${{ matrix.os }}
python-version: ${{ matrix.python-version }}
cuda-version: ${{ matrix.cuda-version }}
torch-version: ${{ matrix.torch-version }}
cxx11_abi: ${{ matrix.cxx11_abi }}
release-version: ${{ needs.setup_release.outputs.release-version }}
upload-to-release: true
use-local-version: false
publish_package:
name: Publish package
needs: [build_wheels]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip install ninja packaging wheel twine
# Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv)
pip install setuptools==75.8.0
# We don't want to download anything CUDA-related here
pip install torch --index-url https://download.pytorch.org/whl/cpu
- name: Build core package
env:
DG_USE_LOCAL_VERSION: "0"
DG_SKIP_CUDA_BUILD: "1"
run: |
python setup.py sdist --dist-dir=dist
- name: Deploy
env:
TWINE_USERNAME: "__token__"
TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
run: |
python -m twine upload dist/*

24
third_party/DeepGEMM/.gitignore vendored Normal file
View File

@@ -0,0 +1,24 @@
cmake-build-*
.idea
.DS_Store
build
dist
*.egg-info
*.pyc
# Third-party links created by `setup.py develop`
deep_gemm/include/cute
deep_gemm/include/cutlass
# VS Code settings
/.vscode
# clangd settings
/.clang*
/.cache
# Generated stub files
stubs/
# Symlinks to compiled extensions
deep_gemm/*.so

6
third_party/DeepGEMM/.gitmodules vendored Normal file
View File

@@ -0,0 +1,6 @@
[submodule "third-party/cutlass"]
path = third-party/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "third-party/fmt"]
path = third-party/fmt
url = https://github.com/fmtlib/fmt.git

32
third_party/DeepGEMM/CMakeLists.txt vendored Normal file
View File

@@ -0,0 +1,32 @@
# NOTES: currently only for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT
# `find_package(CUDAToolkit)` below relies on the FindCUDAToolkit module, which ships with CMake >= 3.17
cmake_minimum_required(VERSION 3.17)
project(deep_gemm LANGUAGES CXX CUDA)
set(CMAKE_VERBOSE_MAKEFILE ON)

# Host and device compilation flags
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi -Wno-deprecated-declarations")
set(CUDA_SEPARABLE_COMPILATION ON)
list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG")
list(APPEND CUDA_NVCC_FLAGS "-O3")
list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage")

# Target architectures (overridable from the cache), forwarded to the Torch package configuration
set(USE_SYSTEM_NVTX on)
set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile")
set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}")

# Required packages
find_package(CUDAToolkit REQUIRED)
find_package(pybind11 REQUIRED)
find_package(Torch REQUIRED)

set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CUDA_STANDARD 20)

# Header and library search paths (project headers, vendored CUTLASS/fmt, CUDA, Torch, Python)
include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include)
include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS})
link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs)

# The main Python API entrance
pybind11_add_module(_C csrc/python_api.cpp)
target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES} torch_python)

# Enable kernel code indexing with CMake-based IDEs
cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu)

21
third_party/DeepGEMM/LICENSE vendored Normal file
View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 DeepSeek
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

207
third_party/DeepGEMM/README.md vendored Normal file
View File

@@ -0,0 +1,207 @@
# DeepGEMM
DeepGEMM is a unified, high-performance tensor core kernel library that brings together the key computation primitives of modern large language models — GEMMs (FP8, FP4, BF16), fused MoE with overlapped communication (Mega MoE), MQA scoring for the lightning indexer, HyperConnection (HC), and more — into a single, cohesive CUDA codebase. All kernels are compiled at runtime via a lightweight Just-In-Time (JIT) module, requiring no CUDA compilation during installation.
DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), but avoids heavy reliance on their templates or algebras. The library is designed for simplicity, with only a limited number of core kernel functions, making it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques.
Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes.
## News
- 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more.
- Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details.
- For Mega MoE benchmarks, refer to [#316](https://github.com/deepseek-ai/DeepGEMM/pull/316).
- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2.
- Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details.
- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module.
- NVRTC and post-compilation SASS optimization are all disabled.
- NVRTC will be supported later.
- As NVCC 12.9 automatically performs FFMA interleaving, post-compilation optimizations are no longer supported.
- Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details.
- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details.
- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases).
- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details.
## Quick start
### Requirements
- NVIDIA SM90 or SM100 architecture GPU
- Python 3.8 or higher
- Compilers with C++20 support
- CUDA Toolkit:
- CUDA 12.3 or higher for SM90
- **We highly recommend 12.9 or higher for the best performance**
- CUDA 12.9 or higher for SM100
- PyTorch 2.1 or higher
- CUTLASS 4.0 or higher (could be cloned by Git submodule)
- `{fmt}` library (could be cloned by Git submodule)
### Development
```bash
# Submodule must be cloned
git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git
cd DeepGEMM
# Link some essential includes and build the CPP JIT module
cat develop.sh
./develop.sh
```
### Installation
```bash
cat install.sh
./install.sh
```
Then, import `deep_gemm` in your Python project, and enjoy!
## Interfaces
#### Notices
This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The default input layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major, col-major), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` computes `D = C + A @ B.T`.
For both architectures, the LHS scaling factor is required to have a TMA-aligned and transposed layout. And the data format for the scaling factor of SM90 and SM100 is different:
- SM90 requires scaling factors in FP32 format.
- SM100 requires scaling factors in packed [UE8M0](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) format, which packs 4 UE8M0 into a single `torch.int`.
Please note that operations like input transposition or FP8 casting must be handled separately by the user; please implement them independently or fuse them into prior kernels. The library provides some simple PyTorch utility functions for these operations, but they may be slow, as our primary focus is on optimizing the GEMM kernels themselves.
#### Normal dense GEMMs (non-grouped)
To perform a basic non-grouped FP8 GEMM, call the `fp8_gemm_{nt, nn, tn, tt}` function. For more details, please refer to the function documentation.
#### Grouped GEMMs (contiguous layout)
Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation.
We also provide a K-axis-grouped API for MoE weight backward (where M and N must remain fixed); please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information.
#### Grouped GEMMs (masked layout)
During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions.
Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input.
#### V3.2 MQA kernels for the indexer
The kernel family has two versions, non-paged (for prefilling) and paged (for decoding).
Take the non-paged version `fp8_mqa_logits` as an example. It has 6 inputs:
- `q`, E4M3 tensor with shape `[seq_len, num_heads, head_dim]`
- `kv`, E4M3 tensor (shaped as `[seq_len_kv, head_dim]`) with float SF (shaped as `[seq_len_kv]`)
- `weights`, float tensor with shape `[seq_len, num_heads]`
- `cu_seq_len_k_start` and `cu_seq_len_k_end`, int tensor with shape `[seq_len]`
- `clean_logits`, whether to clean the unfilled logits into `-inf`
The output tensor is shaped as `[seq_len, seq_len_kv]`, indicating token-to-token logits.
For each token `i` in `q`, it will iterate all tokens `j` from `[cu_seq_len_k_start[i], cu_seq_len_k_end[i])`,
and calculate the logit `out[i, j]` as:
```python
kv_j = kv[0][j, :] * kv[1][j] # [head_dim]
out_ij = q[i, :, :] @ kv_j # [num_heads]
out_ij = out_ij.relu() * weights[i, :] # [num_heads]
out_ij = out_ij.sum() # Scalar
```
For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`.
#### Mega MoE
Mega MoE fuses and overlaps EP dispatch, linear 1 (FP8xFP4), SwiGLU, linear 2 (FP8xFP4), and EP combine into a single mega-kernel, overlapping NVLink communication and tensor core computation. It requires multi-process launch with symmetric memory. Usage:
```python
# Allocate symmetric memory buffer
# NOTES: requires PyTorch >= 2.9
buffer = deep_gemm.get_symm_buffer_for_mega_moe(
group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden
)
# Transform weights (FP4 with UE8M0 SF) into the required layout
transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights)
# Copy inputs into the buffer before each call
# You may fuse these into previous kernels
buffer.x[:num_tokens].copy_(x_fp8)
buffer.x_sf[:num_tokens].copy_(x_sf)
buffer.topk_idx[:num_tokens].copy_(topk_idx)
buffer.topk_weights[:num_tokens].copy_(topk_weights)
# Run the fused mega MoE kernel
y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda')
deep_gemm.fp8_fp4_mega_moe(y, transformed_l1, transformed_l2, buffer)
```
For the full example with multi-process setup and benchmarking, please refer to `tests/test_mega_moe.py`.
#### Utilities
The library provides some utility functions besides the above kernels:
- `deep_gemm.set_num_sms` / `get_num_sms`: set/get the maximum SM count to use
- `deep_gemm.set_tc_util` / `get_tc_util`: set/get an approximated tensor core utilization ratio
- `deep_gemm.set_pdl` / `get_pdl`: enable/disable Programmatic Dependent Launch (PDL)
- `deep_gemm.set_mk_alignment_for_contiguous_layout` / `get_mk_alignment_for_contiguous_layout`: set/get the group-level M/K alignment for contiguous layout
- `deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout`: get the theoretical minimum M/K alignment
- `deep_gemm.set_ignore_compile_dims`: configure dimensions to ignore during JIT compilation
- `deep_gemm.set_block_size_multiple_of`: constrain block sizes to be multiples of a given value
- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into the required layout
- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size
- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor
- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0)
- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel
The library also provides some environment variables, which may be useful:
- General
- `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default
- `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default
- JIT cache
- `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default
- Compiler selection
- `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default
- `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME`
- `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default
- Compiler output
- `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default
- `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default
- `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default
- `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default
- Debug and profiling
- `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default
- `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default
- `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default
- `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default
- `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default
- `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default
- Build options
- `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default
- `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default
- `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default
For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.
## Acknowledgement
DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers!
## License
This code repository is released under [the MIT License](LICENSE).
## Citation
```bibtex
@misc{deepgemm2025,
title={DeepGEMM: clean and efficient BLAS kernel library on GPU},
author={Chenggang Zhao and Zhean Xu and Liang Zhao and Jiashi Li and Chenhao Xu and Anyi Xu and Shengyu Liu and Kexing Zhou and Kuai Yu},
year={2025},
publisher = {GitHub},
howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}},
}
```

12
third_party/DeepGEMM/build.sh vendored Executable file
View File

@@ -0,0 +1,12 @@
#!/usr/bin/env bash
# Build a wheel for this project from a clean state, regardless of the caller's working directory.
# Exits with the status of the wheel build so that callers can detect a failed build.

# Change current directory into project root
original_dir=$(pwd)
script_dir=$(realpath "$(dirname "$0")")
# Abort if the project root cannot be entered, so the `rm -rf` below never runs in another directory
cd "$script_dir" || exit 1

# Remove old dist files and build files, then build the wheel
rm -rf build dist
rm -rf *.egg-info
python setup.py bdist_wheel
build_status=$?

# Return to the user's original directory and report the build result
cd "$original_dir" || exit 1
exit "$build_status"

View File

@@ -0,0 +1,453 @@
#pragma once
#include "../utils/compatibility.hpp"
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp"
#include "../jit_kernels/impls/smxx_clean_logits.hpp"
#endif
#include "layout.hpp"
namespace deep_gemm::attention {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8 GEMM `[M, K] @ [N, K].T` whose epilogue is `EpilogueHeadSplits<left, mid, right>`.
//
// Arguments:
//   - `a`, `b`: pairs of (E4M3 data tensor, scaling factor tensor), with 2D data shapes `[M, K]` and `[N, K]`
//   - `d`: 2D BF16 or FP32 output with `M` rows and `N + N / (left + right) * mid` columns
//   - `head_splits`: `(left, mid, right)`; `left + right` must be positive and divide `N`
//   - `recipe`: optional scaling recipe, forwarded to the scaling factor layout transform
//   - `compiled_dims`: forwarded to the kernel implementations
//   - `disable_ue8m0_cast`: forwarded to the scaling factor layout transform
//
// All violated preconditions are reported via `DG_HOST_ASSERT`; returns without launching anything if `M == 0`.
static void fp8_gemm_nt_skip_head_mid(const std::pair<torch::Tensor, torch::Tensor>& a,
                                      const std::pair<torch::Tensor, torch::Tensor>& b,
                                      const torch::Tensor& d,
                                      const std::tuple<int, int, int>& head_splits,
                                      std::optional<std::tuple<int, int, int>> recipe,
                                      const std::string& compiled_dims,
                                      const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto [m , k ] = get_shape<2>(a.first);
    const auto [n , k_] = get_shape<2>(b.first);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0);
    DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Check head splits and N
    const auto [left, mid, right] = head_splits;
    // NOTES: `left + right` is used as a divisor below, reject zero before dividing
    DG_HOST_ASSERT(left + right > 0);
    DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid);

    // Do nothing if the problem is empty
    if (m == 0)
        return;

    // Transform SFA and SFB into compute-required layout
    const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
        a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt,
        std::nullopt, std::nullopt, disable_ue8m0_cast);
    DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);

    // Dispatch into different implements
    const auto arch_major = device_runtime->get_arch_major();
    const auto epilogue_type = fmt::format("epilogue::transform::EpilogueHeadSplits<{}, {}, {}>", left, mid, right);
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) {
        // SM90 path: FP32 scaling factors, only taken when the second recipe entry is not 1
        const auto major_sfb = get_major_type_ab(sfb);
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type);
    } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        // SM100 path: integer-typed scaling factors
        // NOTES: Only granularity 128 and FP8 are exposed in the API
        sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k,
                                128, 128, major_a, major_b, compiled_dims, epilogue_type);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
    }
}
// Computes MQA logits of a query batch against a flat (non-paged) KV sequence.
//
// The presence of the scaling-factor (SF) tensor in `q` selects the path: when it is present,
// Q/KV are validated as packed FP4 (`kPackedFP4` values, `torch::kInt32` SF); when it is absent,
// they are validated as FP8 (`torch::kFloat8_e4m3fn` values, `torch::kFloat` KV SF).
//
// Arguments:
//   q:                  (values `[seq_len, num_heads, head_dim]`, optional SF `[seq_len, num_heads]`);
//                       on the FP4 path the stored last dimension is `head_dim / 2`
//   kv:                 (values `[seq_len_kv, head_dim]`, SF `[seq_len_kv]`);
//                       on the FP4 path the stored last dimension is `head_dim / 2`
//   weights:            `[seq_len, num_heads]`, `torch::kFloat`, unit stride on the last dimension
//   cu_seq_len_k_start: `torch::kInt`, contiguous, `seq_len` entries along dimension 0
//   cu_seq_len_k_end:   same requirements as `cu_seq_len_k_start`
//   clean_logits:       whether to run `smxx_clean_logits` on the result; must be false if `max_seqlen_k != 0`
//   max_seqlen_k:       0 makes the output `seq_len_kv` wide, any other value is used as the output width
//   logits_dtype:       element type of the allocated output
//
// Returns the `[seq_len, seq_len_kv]` (or `[seq_len, max_seqlen_k]`) slice of a padded
// `[aligned_seq_len, stride_logits]` allocation.
static torch::Tensor fp8_fp4_mqa_logits(const std::tuple<torch::Tensor, std::optional<torch::Tensor>>& q,
const std::tuple<torch::Tensor, torch::Tensor>& kv,
const torch::Tensor& weights,
const torch::Tensor& cu_seq_len_k_start,
const torch::Tensor& cu_seq_len_k_end,
const bool& clean_logits,
const int& max_seqlen_k,
const at::ScalarType& logits_dtype) {
const auto [q_fp, q_sf] = q;
const auto [kv_fp, kv_sf] = kv;
// Q scaling factors are only supplied on the FP4 path
const bool is_fp4 = q_sf.has_value();
int seq_len, seq_len_kv, num_heads, head_dim;
if (is_fp4) {
// Check FP4 Q
std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp);
// The logical head dimension is twice the stored one
// (presumably two FP4 values per `kPackedFP4` element - TODO confirm)
head_dim *= 2;
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);
// Check SF Q
auto [_seq_len, _num_heads] = get_shape<2>(q_sf.value());
DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads);
DG_HOST_ASSERT(q_sf.value().is_contiguous());
DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32);
// Check FP4 KV
int _head_dim;
std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp);
// Same doubling as for Q before comparing the logical head dimensions
_head_dim *= 2;
DG_HOST_ASSERT(head_dim == _head_dim);
DG_HOST_ASSERT(kv_fp.is_contiguous());
DG_HOST_ASSERT(kv_fp.scalar_type() == kPackedFP4);
// Check SF KV
auto [_seq_len_kv] = get_shape<1>(kv_sf);
DG_HOST_ASSERT(seq_len_kv == _seq_len_kv);
DG_HOST_ASSERT(kv_sf.is_contiguous());
DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kInt32);
} else {
// Check FP8 Q
std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);
// Check FP8 KV
int _head_dim;
std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp);
DG_HOST_ASSERT(head_dim == _head_dim);
DG_HOST_ASSERT(kv_fp.is_contiguous());
DG_HOST_ASSERT(kv_fp.scalar_type() == torch::kFloat8_e4m3fn);
// Check SF KV
auto [_seq_len_kv] = get_shape<1>(kv_sf);
DG_HOST_ASSERT(seq_len_kv == _seq_len_kv);
DG_HOST_ASSERT(kv_sf.is_contiguous());
DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kFloat);
}
// Check weights
auto [_seq_len, _num_heads] = get_shape<2>(weights);
DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads);
DG_HOST_ASSERT(weights.stride(1) == 1);
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
// Check cu_seq_len_k_start
DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len);
DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous());
DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt);
// Check cu_seq_len_k_end
DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len);
DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous());
DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt);
// Allocate output
constexpr int block_qh = 128;
constexpr int block_kv = 256;
// `num_heads` was restricted to 32 or 64 above, so this division is exact (re-checked right below)
const int block_q = block_qh / num_heads;
DG_HOST_ASSERT(block_qh % num_heads == 0);
torch::Tensor logits;
// Rows are padded up to a multiple of `block_q`; the caller only gets the first `seq_len` rows
int aligned_seq_len = align(seq_len, block_q), stride_logits;
if (max_seqlen_k == 0) {
// Logits stride must be 16-byte aligned
stride_logits = align(seq_len_kv + block_kv, 8);
logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype));
logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)});
} else {
// Explicit width: the row stride is `max_seqlen_k` rounded up to a multiple of `block_kv`
stride_logits = align(max_seqlen_k, block_kv);
logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype));
logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, max_seqlen_k)});
// Cleaning is rejected whenever an explicit `max_seqlen_k` is given
DG_HOST_ASSERT(not clean_logits);
}
// Dispatch implementation
// The FP4 kernel is only dispatched for arch major 10; the FP8 kernel for arch major 9 or 10
const auto arch_major = device_runtime->get_arch_major();
if (is_fp4 and arch_major == 10) {
sm100_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
smxx_fp8_mqa_logits(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype,
seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
// Clean unfilled logits
if (clean_logits)
smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, stride_logits);
return logits;
}
// Allocates and fills the scheduling metadata used by the paged MQA logits entry points.
//
// Arguments:
//   context_lens: 2D, contiguous, `torch::kInt`; its two sizes are used as `batch_size` and `next_n`
//   block_kv:     KV block size; 64, or 32 when the device arch major is 10
//   num_sms:      the metadata tensor gets `num_sms + 1` rows
//   indices:      optional 1D contiguous `torch::kInt` tensor with `batch_size` entries;
//                 supplying it requires arch major 10 and `next_n == 1`
//
// Returns a `[num_sms + 1, 2]` tensor created with the options of `context_lens`.
static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
                                                   int block_kv,
                                                   int num_sms,
                                                   const std::optional<torch::Tensor>& indices) {
    // NOTES: Only 2D context lens is supported for now
    DG_HOST_ASSERT(context_lens.dim() == 2);
    const bool is_context_lens_2d = true;
    const int batch_size = context_lens.size(0);
    const int next_n = context_lens.size(1);
    DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
    DG_HOST_ASSERT(context_lens.is_contiguous());

    // Output buffer, filled by `smxx_paged_mqa_logits_metadata`
    auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options());

    const auto arch_major = device_runtime->get_arch_major();

    // Without indices: plain scheduling, no index pointer is forwarded
    if (not indices.has_value()) {
        if (arch_major == 9 or arch_major == 10) {
            DG_HOST_ASSERT(block_kv == 64 or (arch_major == 10 and block_kv == 32));
            smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms,
                                           is_context_lens_2d, false, nullptr);
        } else {
            DG_HOST_UNREACHABLE("Unsupported architecture");
        }
        return schedule_metadata;
    }

    // With indices: validate them, then forward their raw pointer
    const auto& indices_tensor = indices.value();
    DG_HOST_ASSERT(arch_major == 10 and next_n == 1 and (block_kv == 64 or block_kv == 32));
    DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
    DG_HOST_ASSERT(indices_tensor.is_contiguous());
    DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
    smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms,
                                   is_context_lens_2d, true, indices_tensor.data_ptr<int>());
    return schedule_metadata;
}
// Computes MQA logits of a query batch against a paged, fused KV cache.
//
// The presence of the scaling-factor (SF) tensor in `q` selects the path: when it is present,
// Q is validated as packed FP4 (`kPackedFP4` values, `torch::kInt32` SF); when it is absent,
// Q is validated as FP8 (`torch::kFloat8_e4m3fn`).
//
// Arguments:
//   q:               (values `[batch_size, next_n, num_heads, head_dim]`, optional SF
//                    `[batch_size, next_n, num_heads]`); on the FP4 path the stored last dimension is `head_dim / 2`
//   fused_kv_cache:  `torch::kByte`, `[num_kv_blocks, block_kv, 1, bytes]`, where `bytes` is the value bytes
//                    of one row plus 4 bytes of scaling factor; `block_kv` must be 32 or 64
//   weights:         `[batch_size * next_n, num_heads]`, `torch::kFloat`, unit stride on the last dimension
//                    (must additionally be contiguous on the FP8 path)
//   context_lens:    `[batch_size, next_n]`, contiguous, `torch::kInt`
//   block_table:     `[batch_size, *]`, `torch::kInt`, unit stride on the last dimension
//   schedule_meta:   `[num_sms + 1, 2]`, contiguous, `torch::kInt`
//   max_context_len: width of the returned logits
//   clean_logits:    requests `smxx_clean_logits`; see the note at the end of the function
//   logits_dtype:    `torch::kFloat32` or `torch::kBFloat16`
//   indices:         optional 1D contiguous `torch::kInt` tensor with `batch_size` entries;
//                    supplying it requires arch major 10 and `next_n == 1`
//
// Returns the first `max_context_len` columns of a `[batch_size * next_n, aligned_max_context_len]` allocation.
static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple<torch::Tensor, std::optional<torch::Tensor>>& q,
const torch::Tensor& fused_kv_cache,
const torch::Tensor& weights,
const torch::Tensor& context_lens,
const torch::Tensor& block_table,
const torch::Tensor& schedule_meta,
const int& max_context_len,
const bool& clean_logits,
const at::ScalarType& logits_dtype,
const std::optional<torch::Tensor>& indices) {
const auto [q_fp, q_sf] = q;
// Q scaling factors are only supplied on the FP4 path
const bool is_fp4 = q_sf.has_value();
torch::Tensor kv_cache, kv_cache_sf;
int batch_size, next_n, num_heads, head_dim;
int num_kv_blocks, block_kv;
int kv_cache_stride_bytes;
int block_table_stride = block_table.stride(0);
int num_sms = device_runtime->get_num_sms();
if (is_fp4) {
// Check FP4 Q
std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
// The logical head dimension is twice the stored one
// (presumably two FP4 values per `kPackedFP4` element - TODO confirm)
head_dim *= 2;
DG_HOST_ASSERT(next_n >= 1);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4);
// Check SF Q
auto [_batch_size, _next_n, _num_heads] = get_shape<3>(q_sf.value());
DG_HOST_ASSERT(batch_size == _batch_size and next_n == _next_n and num_heads == _num_heads);
DG_HOST_ASSERT(q_sf.value().is_contiguous());
DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32);
// Check fused KV cache
int num_heads_kv, fp4_with_sf_bytes;
std::tie(num_kv_blocks, block_kv, num_heads_kv, fp4_with_sf_bytes) = get_shape<4>(fused_kv_cache);
DG_HOST_ASSERT(block_kv == 32 or block_kv == 64);
// Each row holds `head_dim / 2` value bytes plus one 4-byte scaling factor
DG_HOST_ASSERT(num_heads_kv == 1 and fp4_with_sf_bytes == head_dim / 2 + static_cast<int>(sizeof(int)));
DG_HOST_ASSERT(fused_kv_cache.stride(1) == fp4_with_sf_bytes and fused_kv_cache.stride(3) == 1);
DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);
// Derive FP4 values and SF tensor
// Both views alias the memory of `fused_kv_cache`; nothing is copied.
// NOTE(review): the options below set only a dtype, no device - presumably only the pointer,
// sizes and strides of these views are consumed downstream; confirm.
kv_cache_stride_bytes = fused_kv_cache.stride(0);
// The SF view expresses its block stride in 4-byte elements, hence the divisibility requirement
DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(int) == 0);
kv_cache = torch::from_blob(
fused_kv_cache.data_ptr(),
{num_kv_blocks, block_kv, head_dim / 2},
{kv_cache_stride_bytes, head_dim / 2, 1},
torch::TensorOptions().dtype(kPackedFP4)
);
// The SF view starts `block_kv * head_dim / 2` bytes past the base pointer, i.e. right after
// `block_kv` rows of packed values (presumably each block stores all values first and all
// scaling factors afterwards - confirm against the cache writer)
kv_cache_sf = torch::from_blob(
fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim / 2,
{num_kv_blocks, block_kv},
{kv_cache_stride_bytes / static_cast<int>(sizeof(int)), 1},
torch::TensorOptions().dtype(torch::kInt32)
);
} else {
// Check FP8 Q
std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp);
DG_HOST_ASSERT(next_n >= 1);
DG_HOST_ASSERT(num_heads == 32 or num_heads == 64);
DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
DG_HOST_ASSERT(q_fp.is_contiguous());
DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn);
// Check fused KV cache
int num_heads_kv, head_dim_with_sf;
std::tie(num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf) = get_shape<4>(fused_kv_cache);
DG_HOST_ASSERT(block_kv == 32 or block_kv == 64);
// Each row holds `head_dim` value bytes plus one 4-byte scaling factor
DG_HOST_ASSERT(num_heads_kv == 1 and head_dim_with_sf == head_dim + static_cast<int>(sizeof(float)));
DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf and fused_kv_cache.stride(3) == 1);
DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte);
// Derive FP8 values and SF tensor
// Both views alias the memory of `fused_kv_cache`; nothing is copied.
// NOTE(review): the options below set only a dtype, no device - presumably only the pointer,
// sizes and strides of these views are consumed downstream; confirm.
kv_cache_stride_bytes = fused_kv_cache.stride(0);
// The SF view expresses its block stride in 4-byte elements, hence the divisibility requirement
DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0);
kv_cache = torch::from_blob(
fused_kv_cache.data_ptr(),
{num_kv_blocks, block_kv, head_dim},
{kv_cache_stride_bytes, head_dim, 1},
torch::TensorOptions().dtype(torch::kFloat8_e4m3fn)
);
// The SF view starts `block_kv * head_dim` bytes past the base pointer, i.e. right after
// `block_kv` rows of FP8 values (presumably each block stores all values first and all
// scaling factors afterwards - confirm against the cache writer)
kv_cache_sf = torch::from_blob(
fused_kv_cache.data_ptr<uint8_t>() + block_kv * head_dim,
{num_kv_blocks, block_kv},
{kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 1},
torch::TensorOptions().dtype(torch::kFloat32)
);
// Weights must be contiguous for FP8
DG_HOST_ASSERT(weights.is_contiguous());
}
// Check weights
auto [_batch_size_next_n, _num_heads] = get_shape<2>(weights);
DG_HOST_ASSERT(_batch_size_next_n == batch_size * next_n and _num_heads == num_heads);
DG_HOST_ASSERT(weights.stride(1) == 1);
DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat);
// Check block table
auto [_batch_size, _max_block_len] = get_shape<2>(block_table);
DG_HOST_ASSERT(_batch_size == batch_size);
DG_HOST_ASSERT(block_table.stride(1) == 1);
DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt);
// Check indices
const bool is_varlen = indices.has_value();
const auto arch_major = device_runtime->get_arch_major();
// A default-constructed tensor is forwarded to the kernels when no indices are given
const auto indices_tensor = indices.value_or(torch::Tensor());
if (is_varlen) {
DG_HOST_ASSERT(arch_major == 10 and next_n == 1);
DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size);
DG_HOST_ASSERT(indices_tensor.is_contiguous());
DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt);
}
// Check schedule metadata
auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta);
DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2);
DG_HOST_ASSERT(schedule_meta.is_contiguous());
DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt);
// Check context lengths
// NOTES: Only 2D context lens is supported for now
DG_HOST_ASSERT(context_lens.dim() == 2);
const bool is_context_lens_2d = true;
const auto [__batch_size, _next_n] = get_shape<2>(context_lens);
DG_HOST_ASSERT(batch_size == __batch_size and next_n == _next_n);
DG_HOST_ASSERT(context_lens.is_contiguous());
DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt);
// Allocate output
constexpr int split_kv = 256;
// The allocation is padded to a multiple of `split_kv` columns; the caller only sees `max_context_len` of them
const auto aligned_max_context_len = align(max_context_len, split_kv);
auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q_fp.options().dtype(logits_dtype));
logits = logits.slice(-1, 0, max_context_len);
DG_HOST_ASSERT(logits_dtype == torch::kFloat32 or logits_dtype == torch::kBFloat16);
// Dispatch implementation
// The FP4 kernel is only dispatched for arch major 10; the FP8 kernel for arch major 9 or 10
if (is_fp4 and arch_major == 10) {
sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) {
smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta,
logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d,
is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
// Clean unfilled logits
// NOTE: `is_context_lens_2d` is fixed to true above, so the assertion below rejects every call
// that passes `clean_logits == true`; the cleaning call itself is currently never reached.
if (clean_logits) {
DG_HOST_ASSERT(not is_context_lens_2d);
smxx_clean_logits(logits, std::nullopt, context_lens, next_n, batch_size * next_n, max_context_len, aligned_max_context_len);
}
return logits;
}
// Legacy API wrappers
// Legacy FP8-only entry point: forwards to `fp8_fp4_mqa_logits` with no Q scaling
// factors and `torch::kFloat` logits. All other arguments are passed through unchanged.
static torch::Tensor fp8_mqa_logits(const torch::Tensor& q,
                                    const std::tuple<torch::Tensor, torch::Tensor>& kv,
                                    const torch::Tensor& weights,
                                    const torch::Tensor& cu_seq_len_k_start,
                                    const torch::Tensor& cu_seq_len_k_end,
                                    const bool& clean_logits,
                                    const int& max_seqlen_k) {
    // An empty scaling-factor slot selects the FP8 path of the unified implementation
    const std::tuple<torch::Tensor, std::optional<torch::Tensor>> q_without_sf(q, std::nullopt);
    const auto logits_dtype = torch::kFloat;
    return fp8_fp4_mqa_logits(q_without_sf, kv, weights,
                              cu_seq_len_k_start, cu_seq_len_k_end,
                              clean_logits, max_seqlen_k, logits_dtype);
}
// Legacy FP8-only paged entry point: forwards to `fp8_fp4_paged_mqa_logits` with no Q
// scaling factors and `torch::kFloat` logits. All other arguments are passed through unchanged.
static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
                                          const torch::Tensor& fused_kv_cache,
                                          const torch::Tensor& weights,
                                          const torch::Tensor& context_lens,
                                          const torch::Tensor& block_table,
                                          const torch::Tensor& schedule_meta,
                                          const int& max_context_len,
                                          const bool& clean_logits,
                                          const std::optional<torch::Tensor>& indices) {
    // An empty scaling-factor slot selects the FP8 path of the unified implementation
    const std::tuple<torch::Tensor, std::optional<torch::Tensor>> q_without_sf(q, std::nullopt);
    const auto logits_dtype = torch::kFloat;
    return fp8_fp4_paged_mqa_logits(q_without_sf, fused_kv_cache, weights,
                                    context_lens, block_table, schedule_meta,
                                    max_context_len, clean_logits, logits_dtype, indices);
}
#endif
// Registers the attention entry points of this namespace on the Python module `m`.
// Nothing is registered unless both `DG_FP8_COMPATIBLE` and `DG_TENSORMAP_COMPATIBLE` hold.
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
    // `def` hands back the module, so all registrations are chained in declaration order
    m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("head_splits"),
          py::arg("recipe") = std::nullopt,
          py::arg("compiled_dims") = "nk",
          py::arg("disable_ue8m0_cast") = false)
     .def("fp8_fp4_mqa_logits", &fp8_fp4_mqa_logits,
          py::arg("q"),
          py::arg("kv"),
          py::arg("weights"),
          py::arg("cu_seq_len_k_start"),
          py::arg("cu_seq_len_k_end"),
          py::arg("clean_logits") = true,
          py::arg("max_seqlen_k") = 0,
          py::arg("logits_dtype") = torch::kFloat32)
     .def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
          py::arg("context_lens"),
          py::arg("block_kv"),
          py::arg("num_sms"),
          py::arg("indices") = std::nullopt)
     .def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits,
          py::arg("q"),
          py::arg("kv_cache"),
          py::arg("weights"),
          py::arg("context_lens"),
          py::arg("block_table"),
          py::arg("schedule_meta"),
          py::arg("max_context_len"),
          py::arg("clean_logits") = false,
          py::arg("logits_dtype") = torch::kFloat32,
          py::arg("indices") = std::nullopt)
     // Legacy API
     .def("fp8_mqa_logits", &fp8_mqa_logits,
          py::arg("q"),
          py::arg("kv"),
          py::arg("weights"),
          py::arg("cu_seq_len_k_start"),
          py::arg("cu_seq_len_k_end"),
          py::arg("clean_logits") = true,
          py::arg("max_seqlen_k") = 0)
     .def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
          py::arg("q"),
          py::arg("kv_cache"),
          py::arg("weights"),
          py::arg("context_lens"),
          py::arg("block_table"),
          py::arg("schedule_meta"),
          py::arg("max_context_len"),
          py::arg("clean_logits") = false,
          py::arg("indices") = std::nullopt);
#endif
}
} // namespace deep_gemm::attention

View File

@@ -0,0 +1,231 @@
#pragma once
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/layout.hpp"
#include "../utils/compatibility.hpp"
#include "gemm.hpp"
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
#endif
namespace deep_gemm::einsum {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Hard-coded kernel dispatch for the einsum expression "bmk,bnk->mn".
//
// Arguments:
//   a: contiguous 3D tensor, read as `[s, m, k]`
//   b: contiguous 3D tensor, read as `[s, n, k]`; `s` and `k` must match those of `a`
//   d: contiguous output, `torch::kFloat` or `torch::kBFloat16`
//   c: accumulator; for a `torch::kFloat` output it is required and must alias `d`
//      (same data pointer, sizes and strides); for a `torch::kBFloat16` output it must be absent
//
// For a BF16 output the work is done on a zero-initialized FP32 workspace (via a recursive
// call with the workspace as both `d` and `c`) and the result is then copied into `d`.
static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d,
                       const std::optional<torch::Tensor>& c) {
    // Currently FP32 only support the accumulated expression
    if (d.scalar_type() == torch::kFloat) {
        // `c` is optional at the API boundary: reject a missing accumulator explicitly
        // instead of dereferencing an empty optional in the aliasing check below
        DG_HOST_ASSERT(c.has_value());
        DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides());
    } else {
        DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
        DG_HOST_ASSERT(not c.has_value());

        // Zeroed FP32 workspace with the shape of `d`, cleared on the current CUDA stream
        const auto workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32));
        DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(),
                                              c10::cuda::getCurrentCUDAStream()));
        bmk_bnk_mn(a, b, workspace, workspace);

        // This line has an implicit FP32-to-BF16 casting
        d.copy_(workspace);
        return;
    }

    DG_HOST_ASSERT(a.is_contiguous());
    DG_HOST_ASSERT(b.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());

    const auto [s , m, k ] = get_shape<3>(a);
    const auto [s_, n, k_] = get_shape<3>(b);
    DG_HOST_ASSERT(s == s_ and k == k_);

    // Dispatch implementation
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else if (arch_major == 10) {
        sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
// Hard-coded kernel dispatch for the BF16 einsum expression "bhr,hdr->bhd".
//
// `A` is read as `[b, h, r]`, `B` as `[h, d, r]` and `D` as `[b, h, d]`; all three must be
// `torch::kBFloat16` with unit stride on the last dimension. `use_cublaslt` takes precedence
// over the architecture-specific kernels.
static void bhr_hdr_bhd(const torch::Tensor& A,
                        const torch::Tensor& B,
                        const torch::Tensor& D,
                        const bool& use_cublaslt) {
    // Shapes, cross-checked between the three operands
    const auto [b , h  , r ] = get_shape<3>(A);
    const auto [h_, d  , r_] = get_shape<3>(B);
    const auto [b_, h__, d_] = get_shape<3>(D);
    DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);

    // Types and innermost strides
    DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
    DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
    DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);

    // Pick the backend
    const auto arch_major = device_runtime->get_arch_major();
    if (use_cublaslt) {
        cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 9) {
        sm90_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 10) {
        sm100_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// Hard-coded kernel dispatch for the BF16 einsum expression "bhd,hdr->bhr".
//
// `A` is read as `[b, h, d]`, `B` as `[h, d, r]` and `D` as `[b, h, r]`; all three must be
// `torch::kBFloat16` with unit stride on the last dimension. `use_cublaslt` takes precedence
// over the architecture-specific kernels.
static void bhd_hdr_bhr(const torch::Tensor& A,
                        const torch::Tensor& B,
                        const torch::Tensor& D,
                        const bool& use_cublaslt) {
    // Shapes, cross-checked between the three operands
    const auto [b , h  , d ] = get_shape<3>(A);
    const auto [h_, d_ , r ] = get_shape<3>(B);
    const auto [b_, h__, r_] = get_shape<3>(D);
    DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__);

    // Types and innermost strides
    DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1);
    DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1);
    DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1);

    // Pick the backend
    const auto arch_major = device_runtime->get_arch_major();
    if (use_cublaslt) {
        cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 9) {
        sm90_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d);
        return;
    }
    if (arch_major == 10) {
        sm100_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// BF16 einsum front end: validates operand types and routes `expr` to one of the
// hard-coded kernels. Any expression other than the three listed below is rejected.
//
// Arguments:
//   expr:         "bmk,bnk->mn", "bhr,hdr->bhd" or "bhd,hdr->bhr"
//   a, b:         `torch::kBFloat16` inputs
//   d:            `torch::kBFloat16` or `torch::kFloat` output
//   c:            optional `torch::kFloat` accumulator (requires a `torch::kFloat` `d`);
//                 only accepted by "bmk,bnk->mn"
//   use_cublaslt: request the cuBLASLt backend; not accepted by "bmk,bnk->mn"
static void einsum(const std::string& expr,
                   const torch::Tensor& a,
                   const torch::Tensor& b,
                   const torch::Tensor& d,
                   const std::optional<torch::Tensor>& c,
                   const bool& use_cublaslt) {
    // Operand types
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);
    if (c.has_value()) {
        DG_HOST_ASSERT(c->scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
    }

    // Some hardcoded Einstein sum kernels
    // TODO: support any expression
    // TODO: canonicalize expression
    if (expr == "bmk,bnk->mn") {
        DG_HOST_ASSERT(not use_cublaslt);
        bmk_bnk_mn(a, b, d, c);
        return;
    }
    if (expr == "bhr,hdr->bhd") {
        DG_HOST_ASSERT(not c.has_value());
        bhr_hdr_bhd(a, b, d, use_cublaslt);
        return;
    }
    if (expr == "bhd,hdr->bhr") {
        DG_HOST_ASSERT(not c.has_value());
        bhd_hdr_bhr(a, b, d, use_cublaslt);
        return;
    }
    DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
}
// Batched FP8 GEMM `[B, M, K] @ [B, N, K].T -> [B, M, N]` with scaling factors.
//
// Arguments:
//   a, b:          `torch::kFloat8_e4m3fn`, 3D; unit stride on either of the last two dimensions
//   sfa, sfb:      scaling factors of `a` / `b`, transformed into the required layout before dispatch
//   d:             `torch::kBFloat16` or `torch::kFloat` output `[B, M, N]`, unit stride on the last dimension
//   c:             optional accumulator, handled by `gemm::early_return` and forwarded to the kernel
//   recipe:        optional scaling recipe, forwarded to the scaling-factor transform
//   compiled_dims: forwarded to the kernel
//
// Does nothing when the batch is empty or `gemm::early_return` reports a trivial problem.
static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
                    const torch::Tensor& b, const torch::Tensor& sfb,
                    const torch::Tensor& d,
                    const std::optional<torch::Tensor>& c,
                    std::optional<std::tuple<int, int, int>> recipe,
                    const std::string& compiled_dims) {
    // Shape must be `[B, M, K] @ [B, N, K].T`
    const auto major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
    const auto major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN;
    DG_HOST_ASSERT(a.stride(-1) == 1 or a.stride(-2) == 1);
    DG_HOST_ASSERT(b.stride(-1) == 1 or b.stride(-2) == 1);
    DG_HOST_ASSERT(d.stride(-1) == 1);

    // Type and shape checks
    const auto [batch_size  , m , k ] = get_shape<3>(a);
    const auto [batch_size_ , n , k_] = get_shape<3>(b);
    const auto [batch_size__, m_, n_] = get_shape<3>(d);
    // All three operands must agree on the batch dimension, including the output `d`
    DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size__);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(b.scalar_type() == torch::kFloat8_e4m3fn);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Early return for trivial cases
    if (batch_size == 0 or gemm::early_return(m, n, k, d, c))
        return;

    // Transform scaling factors
    const auto [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
        sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false);

    // Dispatch implementation
    // NOTE: every arch major other than 10 takes the SM90 path; there is no explicit arch check here
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 10) {
        sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims);
    } else {
        // NOTE(review): the major type is taken from the untransformed `sfb` - confirm this is intended
        const auto major_sfb = get_major_type_ab(sfb);
        DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128);
        sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims);
    }
}
// FP8 einsum front end: permutes the operands of a supported expression into the
// `(batch_size, m, n, k)` order expected by `fp8_bmm` and forwards to it.
//
// Supported expressions:
//   "bhr,hdr->bhd"  (any architecture accepted by `fp8_bmm`)
//   "bhd,hdr->bhr"  (arch major 10 only)
//   "bhd,bhr->hdr"  (arch major 10 only)
//
// `a` and `b` are (values, scaling factors) pairs; `c` is an optional accumulator.
static void fp8_einsum(const std::string& expr,
                       const std::pair<torch::Tensor, torch::Tensor>& a,
                       const std::pair<torch::Tensor, torch::Tensor>& b,
                       const torch::Tensor& d,
                       const std::optional<torch::Tensor>& c,
                       const std::tuple<int, int, int>& recipe) {
    // Swaps the two leading dimensions of a 3D tensor
    const auto swap_leading = [](const torch::Tensor& t) {
        return t.permute({1, 0, 2});
    };
    // Same swap for an optional tensor; an empty optional stays empty
    const auto swap_leading_optional = [](const std::optional<torch::Tensor>& t) -> std::optional<torch::Tensor> {
        if (not t.has_value())
            return std::nullopt;
        return t.value().permute({1, 0, 2});
    };

    // Some hardcoded Einstein sum kernels
    const auto arch_major = device_runtime->get_arch_major();

    if (expr == "bhr,hdr->bhd") {
        // Permute dims to satisfy the order of (batch_size, m, n, k)
        // (batch_size, m, n, k): (h, b, d, r)
        const auto a_hbr = swap_leading(a.first);
        const auto sfa_hbr = swap_leading(a.second);
        const auto d_hbd = swap_leading(d);
        const auto c_hbd = swap_leading_optional(c);
        fp8_bmm(a_hbr, sfa_hbr, b.first, b.second, d_hbd, c_hbd, recipe, "nk");
        return;
    }

    if (expr == "bhd,hdr->bhr" and arch_major == 10) {
        // (batch_size, m, n, k): (h, b, r, d)
        const auto a_hbd = swap_leading(a.first);
        const auto sfa_hbd = swap_leading(a.second);
        const auto b_hrd = b.first.permute({0, 2, 1});
        const auto sfb_hrd = b.second.permute({0, 2, 1});
        const auto d_hbr = swap_leading(d);
        const auto c_hbr = swap_leading_optional(c);
        fp8_bmm(a_hbd, sfa_hbd, b_hrd, sfb_hrd, d_hbr, c_hbr, recipe, "nk");
        return;
    }

    if (expr == "bhd,bhr->hdr" and arch_major == 10) {
        // (batch_size, m, n, k): (h, d, r, b)
        const auto a_hdb = a.first.permute({1, 2, 0});
        const auto sfa_hdb = a.second.permute({1, 2, 0});
        const auto b_hrb = b.first.permute({1, 2, 0});
        const auto sfb_hrb = b.second.permute({1, 2, 0});
        fp8_bmm(a_hdb, sfa_hdb, b_hrb, sfb_hrb, d, c, recipe, "mn");
        return;
    }

    DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr));
}
#endif
// Registers the einsum entry points of this namespace on the Python module `m`.
// Nothing is registered unless both `DG_FP8_COMPATIBLE` and `DG_TENSORMAP_COMPATIBLE` hold.
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
    // `def` hands back the module, so both registrations are chained in declaration order
    m.def("einsum", &einsum,
          py::arg("expr"),
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("use_cublaslt") = false)
     .def("fp8_einsum", &fp8_einsum,
          py::arg("expr"),
          py::arg("a"),
          py::arg("b"),
          py::arg("d"),
          py::arg("c") = std::nullopt,
          py::arg("recipe") = std::make_tuple(1, 128, 128));
#endif
}
} // namespace deep_gemm::einsum

715
third_party/DeepGEMM/csrc/apis/gemm.hpp vendored Normal file
View File

@@ -0,0 +1,715 @@
#pragma once
#include "../utils/compatibility.hpp"
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp"
#include "../jit_kernels/impls/sm90_bf16_gemm.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp"
#include "../jit_kernels/impls/sm100_bf16_gemm.hpp"
#endif
#include "../jit_kernels/impls/smxx_cublaslt.hpp"
#include "layout.hpp"
namespace deep_gemm::gemm {
// Handles the trivial GEMM cases and prepares `d` for accumulation.
//
// Returns true when the caller has nothing left to do:
//   - `m == 0` or `n == 0`: the problem is empty, nothing is validated or written;
//   - `k == 0`: `d` receives `c` (or zeros when there is no `c`), unless `c` already aliases `d`.
// Returns false otherwise; in that case a `c` that does not alias `d` has been copied into `d`.
//
// When `c` is given, it is checked with `check_major_type_cd`, and both `c` and `d` must be
// `torch::kFloat`. A `c` sharing the data pointer of `d` must also share its sizes and strides.
static bool early_return(const int& m, const int &n, const int& k,
                         const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    // Do nothing if the problem is empty
    if (m == 0 or n == 0)
        return true;

    // Checks
    const bool has_accumulator = c.has_value();
    const bool is_cd_same = has_accumulator and c->data_ptr() == d.data_ptr();
    if (is_cd_same) {
        DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides());
    }
    if (has_accumulator) {
        check_major_type_cd(c.value());
        DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
        DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat);
    }

    // Seed `d` unless it already holds the accumulator:
    //  - a distinct `c` is copied in (assuming the GEMM kernel does not support different C/D);
    //  - with no `c`, `d` is only zeroed when there is no accumulation at all (`k == 0`)
    if (not is_cd_same) {
        if (has_accumulator) {
            d.copy_(c.value());
        } else if (k == 0) {
            d.zero_();
        }
    }

    // With `k == 0` the result is complete; otherwise the GEMM still has to run
    return k == 0;
}
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8/FP4 GEMM in NT layout: `[M, K] @ [N, K].T -> [M, N]`.
//
// Arguments:
//   a, b:               (values, scaling factors) pairs
//   d:                  `torch::kBFloat16` or `torch::kFloat` output `[M, N]`, checked with `check_major_type_cd`
//   c:                  optional accumulator, handled by `early_return` and forwarded to the kernel
//   recipe / recipe_a / recipe_b: scaling recipes, forwarded to the scaling-factor transform
//   compiled_dims:      forwarded to the kernel
//   disable_ue8m0_cast: forwarded to the scaling-factor transform
//
// The kernel is chosen from the arch major and the dtype of the transformed `a` scaling factors:
// arch 9 with `torch::kFloat`, or arch 10 with `torch::kInt`; anything else is rejected.
static void fp8_fp4_gemm_nt(const std::pair<torch::Tensor, torch::Tensor>& a,
                            const std::pair<torch::Tensor, torch::Tensor>& b,
                            const torch::Tensor& d,
                            const std::optional<torch::Tensor>& c,
                            std::optional<std::tuple<int, int, int>> recipe,
                            std::optional<std::tuple<int, int>> recipe_a,
                            std::optional<std::tuple<int, int>> recipe_b,
                            const std::string& compiled_dims,
                            const bool& disable_ue8m0_cast) {
    // Shape must be `[M, K] @ [N, K].T`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }

    // C/D must be N-major
    check_major_type_cd(d);

    // Type and shape checks
    const auto arch_major = device_runtime->get_arch_major();
    const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
    const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Early return for trivial cases
    if (early_return(m, n, k, d, c))
        return;

    // Transform SFA and SFB into compute-required layout
    const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
        a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast);

    // SM90 with FP32 scaling factors: 1D1D or 1D2D, depending on the N granularity
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
        int gran_n;
        if (recipe.has_value()) {
            gran_n = std::get<1>(recipe.value());
        } else {
            gran_n = std::get<0>(recipe_b.value());
        }

        if (gran_n == 1) {
            sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims);
            return;
        }
        const auto major_sfb = get_major_type_ab(sfb);
        sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims);
        return;
    }

    // SM100 with integer scaling factors
    if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b,
                                major_a, major_b, compiled_dims);
        return;
    }

    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
// NN-layout variant of `fp8_fp4_gemm_nt`: transposes dimensions 0 and 1 of `b`
// (values and scaling factors) and forwards everything else unchanged.
static void fp8_fp4_gemm_nn(const std::pair<torch::Tensor, torch::Tensor>& a,
                            const std::pair<torch::Tensor, torch::Tensor>& b,
                            const torch::Tensor& d,
                            const std::optional<torch::Tensor>& c,
                            const std::optional<std::tuple<int, int, int>>& recipe,
                            const std::optional<std::tuple<int, int>>& recipe_a,
                            const std::optional<std::tuple<int, int>>& recipe_b,
                            const std::string& compiled_dims,
                            const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> b_transposed{b.first.transpose(0, 1),
                                                               b.second.transpose(0, 1)};
    fp8_fp4_gemm_nt(a, b_transposed, d, c,
                    recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
}
// TN-layout variant of `fp8_fp4_gemm_nt`: transposes dimensions 0 and 1 of both `a` and `b`
// (values and scaling factors) and forwards everything else unchanged.
static void fp8_fp4_gemm_tn(const std::pair<torch::Tensor, torch::Tensor>& a,
                            const std::pair<torch::Tensor, torch::Tensor>& b,
                            const torch::Tensor& d,
                            const std::optional<torch::Tensor>& c,
                            const std::optional<std::tuple<int, int, int>>& recipe,
                            const std::optional<std::tuple<int, int>>& recipe_a,
                            const std::optional<std::tuple<int, int>>& recipe_b,
                            const std::string& compiled_dims,
                            const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> a_transposed{a.first.transpose(0, 1),
                                                               a.second.transpose(0, 1)};
    const std::pair<torch::Tensor, torch::Tensor> b_transposed{b.first.transpose(0, 1),
                                                               b.second.transpose(0, 1)};
    fp8_fp4_gemm_nt(a_transposed, b_transposed, d, c,
                    recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
}
// TT-layout variant of `fp8_fp4_gemm_nt`: transposes dimensions 0 and 1 of `a`
// (values and scaling factors), passes `b` as is and forwards everything else unchanged.
static void fp8_fp4_gemm_tt(const std::pair<torch::Tensor, torch::Tensor>& a,
                            const std::pair<torch::Tensor, torch::Tensor>& b,
                            const torch::Tensor& d,
                            const std::optional<torch::Tensor>& c,
                            const std::optional<std::tuple<int, int, int>>& recipe,
                            const std::optional<std::tuple<int, int>>& recipe_a,
                            const std::optional<std::tuple<int, int>>& recipe_b,
                            const std::string& compiled_dims,
                            const bool& disable_ue8m0_cast) {
    const std::pair<torch::Tensor, torch::Tensor> a_transposed{a.first.transpose(0, 1),
                                                               a.second.transpose(0, 1)};
    fp8_fp4_gemm_nt(a_transposed, b, d, c,
                    recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast);
}
// M-grouped FP8/FP4 GEMM in NT layout with contiguous groups: `[M, K] @ [G, N, K].mT -> [M, N]`.
//
// Arguments:
//   a:                  (values `[M, K]`, scaling factors); must be K-major
//   b:                  (values `[G, N, K]`, scaling factors); must be K-major when `fp8_requires_k_major()`
//   d:                  `torch::kBFloat16` output `[M, N]`, checked with `check_major_type_cd`
//   grouped_layout:     contiguous 1D `torch::kInt` tensor; `G` entries with the psum layout, `M` entries otherwise
//   recipe / recipe_a / recipe_b: scaling recipes, forwarded to the scaling-factor transform
//   compiled_dims:      forwarded to the kernel
//   disable_ue8m0_cast: forwarded to the scaling-factor transform
//   use_psum_layout:    selects how `grouped_layout` is interpreted (see above)
//   expected_m_for_psum_layout: only allowed together with `use_psum_layout`
//
// Does nothing when `M == 0`. The kernel is chosen from the arch major and the dtype of the
// transformed `a` scaling factors: arch 9 with `torch::kFloat`, or arch 10 with `torch::kInt`.
static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                                 const std::pair<torch::Tensor, torch::Tensor>& b,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& grouped_layout,
                                                 std::optional<std::tuple<int, int, int>> recipe,
                                                 std::optional<std::tuple<int, int>> recipe_a,
                                                 std::optional<std::tuple<int, int>> recipe_b,
                                                 const std::string& compiled_dims,
                                                 const bool& disable_ue8m0_cast,
                                                 const bool& use_psum_layout,
                                                 const std::optional<int>& expected_m_for_psum_layout) {
    // Shape must be `[M, K] @ [G, N, K].mT`
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    if (fp8_requires_k_major()) {
        DG_HOST_ASSERT(major_b == cute::UMMA::Major::K);
    }
    DG_HOST_ASSERT(grouped_layout.is_contiguous());

    // Type and shape checks
    const auto arch_major = device_runtime->get_arch_major();
    const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major);
    const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);

    // Layout checks: one entry per group with the psum layout, one entry per row otherwise
    if (not use_psum_layout) {
        const auto [m__] = get_shape<1>(grouped_layout);
        DG_HOST_ASSERT(m == m__);
        DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
    } else {
        const auto [num_groups_] = get_shape<1>(grouped_layout);
        DG_HOST_ASSERT(num_groups == num_groups_);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Do nothing if empty
    if (m == 0)
        return;

    // Transform SFA and SFB into compute-required layout
    const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
        a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, num_groups, disable_ue8m0_cast);

    // SM90 with FP32 scaling factors
    if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) {
        const auto major_sfb = get_major_type_ab(sfb);
        sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout,
                                                num_groups, m, n, k, major_a, major_b, major_sfb,
                                                compiled_dims, use_psum_layout, expected_m_for_psum_layout);
        return;
    }

    // SM100 with integer scaling factors
    if (arch_major == 10 and sfa.scalar_type() == torch::kInt) {
        sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout,
                                                     num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b,
                                                     compiled_dims, use_psum_layout, expected_m_for_psum_layout);
        return;
    }

    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
// NN flavour of the M-grouped contiguous FP8/FP4 GEMM.
// Swaps dims 1 and 2 of both tensors in `b` and forwards to the NT implementation;
// `expected_m_for_psum_layout` is always passed as `std::nullopt`.
static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                                 const std::pair<torch::Tensor, torch::Tensor>& b,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& grouped_layout,
                                                 const std::optional<std::tuple<int, int, int>>& recipe,
                                                 const std::optional<std::tuple<int, int>>& recipe_a,
                                                 const std::optional<std::tuple<int, int>>& recipe_b,
                                                 const std::string& compiled_dims,
                                                 const bool& disable_ue8m0_cast,
                                                 const bool& use_psum_layout) {
    // Transposed views of the B operand and its scaling factors
    const std::pair<torch::Tensor, torch::Tensor> b_transposed{b.first.transpose(1, 2), b.second.transpose(1, 2)};
    m_grouped_fp8_fp4_gemm_nt_contiguous(a, b_transposed, d, grouped_layout,
                                         recipe, recipe_a, recipe_b,
                                         compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt);
}
// M-grouped masked FP8/FP4 GEMM, NT layout: `[G, M, K] @ [G, N, K].mT`, written into a BF16 `d` of shape `[G, M, N]`.
// `masked_m` must be a contiguous `torch::kInt` tensor with one element per group; `expected_m` must be positive
// and is forwarded unchanged to the kernel.
// NOTES: `recipe` is a by-value copy because `layout::transform_sf_pair_into_required_layout` takes it by mutable reference.
static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const torch::Tensor& masked_m,
                                             const int& expected_m,
                                             std::optional<std::tuple<int, int, int>> recipe,
                                             std::optional<std::tuple<int, int>> recipe_a,
                                             std::optional<std::tuple<int, int>> recipe_b,
                                             const std::string& compiled_dims,
                                             const bool& disable_ue8m0_cast) {
    // Both operands have to be K-major here
    const auto major_a = get_major_type_ab(a.first);
    const auto major_b = get_major_type_ab(b.first);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());

    // Group counts and per-group shapes must agree across A, B, D and the mask
    const auto arch_major = device_runtime->get_arch_major();
    const auto [num_groups  , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major);
    const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major);
    const auto [num_groups__, m_, n_] = get_shape<3>(d);
    const auto num_groups___ = static_cast<int>(masked_m.numel());
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

    // D must be N-major
    check_major_type_cd(d);

    // Both SF tensors are grouped in the masked variant
    const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout(
        a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast);

    // Pick the kernel from the architecture and the transformed SF dtype
    const auto sfa_dtype = sfa.scalar_type();
    if (arch_major == 9 and sfa_dtype == torch::kFloat) {
        const auto major_sfb = get_major_type_ab(sfb);
        sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m,
                                            num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims);
        return;
    }
    if (arch_major == 10 and sfa_dtype == torch::kInt) {
        sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m,
                                                 num_groups, m, n, k, expected_m, gran_k_a, gran_k_b,
                                                 major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types");
}
// K-grouped contiguous FP8 GEMM, TN layout; only dispatched on `arch_major == 10`.
// `a.first` is `[sum(ks), M]`, `b.first` is `[sum(ks), N]` and `d` is `[G, M, N]`, all contiguous.
// `recipe` must be `(1, 1, gran_k)` with `gran_k` being 32 or 128.
// NOTES: although `c` is an optional, the contiguity assertion below rejects `std::nullopt`.
static void k_grouped_fp8_gemm_tn_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const std::vector<int>& ks,
                                             const torch::Tensor& ks_tensor,
                                             const std::optional<torch::Tensor>& c,
                                             const std::tuple<int, int, int>& recipe,
                                             const std::string& compiled_dims) {
    // Must be 1D1D kernel
    DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1);
    const int gran_k = std::get<2>(recipe);
    DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);

    // The leading dimension of A and B must equal the sum of all group sizes
    const auto [num_groups, m, n] = get_shape<3>(d);
    const auto [sum_k_ , m_] = get_shape<2>(a.first);
    const auto [sum_k__, n_] = get_shape<2>(b.first);
    const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
    DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);

    // Everything has to be densely packed
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous());

    // Trivial problems are resolved by `early_return`
    if (early_return(m, n, sum_k, d, c)) {
        return;
    }

    // Transform SF with padding
    const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Only the SM100 kernel exists for this variant
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 10) {
        sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, gran_k,
                                      cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// K-grouped contiguous FP8 GEMM, NT layout; only dispatched on `arch_major == 9`.
// `d` is `[G, M, N]`; `a.first` / `b.first` are validated by element count only
// (`sum(ks) * M` and `sum(ks) * N`) and must be contiguous. `recipe` must be exactly `(1, 1, 128)`.
// NOTES: although `c` is an optional, the contiguity assertion below rejects `std::nullopt`.
static void k_grouped_fp8_gemm_nt_contiguous(const std::pair<torch::Tensor, torch::Tensor>& a,
                                             const std::pair<torch::Tensor, torch::Tensor>& b,
                                             const torch::Tensor& d,
                                             const std::vector<int>& ks,
                                             const torch::Tensor& ks_tensor,
                                             const std::optional<torch::Tensor>& c,
                                             const std::tuple<int, int, int>& recipe,
                                             const std::string& compiled_dims) {
    // Must be 1D1D kernel
    DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128));

    // Shape checks
    const auto [num_groups, m, n] = get_shape<3>(d);
    const auto sum_mk = a.first.numel();
    const auto sum_nk = b.first.numel();
    const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
    // Widen before multiplying so the product is compared in 64 bits against `numel()`
    DG_HOST_ASSERT(sum_mk == static_cast<int64_t>(sum_k) * m);
    DG_HOST_ASSERT(sum_nk == static_cast<int64_t>(sum_k) * n);

    // Contiguity checks
    DG_HOST_ASSERT(a.first.is_contiguous());
    DG_HOST_ASSERT(b.first.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous());

    // Early return for trivial cases
    // NOTES: pass the already-computed `sum_k` rather than summing `ks` a second time through an
    // unqualified `accumulate`, which would only resolve via argument-dependent lookup
    if (early_return(m, n, sum_k, d, c))
        return;

    // Transform SF with padding
    const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe);
    const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe);

    // Allocate tensormap buffer
    // `4` means the double buffering for both A and B operands (2 * 2)
    const auto num_sms = device_runtime->get_num_sms();
    const auto tensor_map_buffer = torch::empty({num_sms * 4 * static_cast<int>(sizeof(CUtensorMap))},
                                                a.first.options().dtype(torch::kByte));

    // Dispatch implementation
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer,
                                     cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims);
    } else {
        DG_HOST_UNREACHABLE("Unsupported architecture");
    }
}
#endif
#if DG_TENSORMAP_COMPATIBLE
// BF16 GEMM, NT layout: `[M, K] @ [N, K].T` into `d` of shape `[M, N]`.
// `a` and `b` must be BF16; `d` may be BF16 or FP32 and must pass `check_major_type_cd`.
// `c` is only forwarded (to `early_return` and to the kernels); it is not validated here.
static void bf16_gemm_nt(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    // Majorness of the operands is detected and passed through to the kernel
    const auto major_a = get_major_type_ab(a);
    const auto major_b = get_major_type_ab(b);

    // C/D must be N-major
    check_major_type_cd(d);

    // Shapes and dtypes
    const auto [m , k ] = get_shape<2>(a);
    const auto [n , k_] = get_shape<2>(b);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat);

    // Trivial problems are resolved by `early_return`
    if (early_return(m, n, k, d, c)) {
        return;
    }

    // One implementation per supported architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10) {
        sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// NN flavour: hands a transposed view of `b` to `bf16_gemm_nt`.
static void bf16_gemm_nn(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    const auto b_transposed = b.transpose(0, 1);
    bf16_gemm_nt(a, b_transposed, d, c, compiled_dims);
}
// TN flavour: hands transposed views of both `a` and `b` to `bf16_gemm_nt`.
static void bf16_gemm_tn(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    const auto a_transposed = a.transpose(0, 1);
    const auto b_transposed = b.transpose(0, 1);
    bf16_gemm_nt(a_transposed, b_transposed, d, c, compiled_dims);
}
// TT flavour: hands a transposed view of `a` (and `b` as-is) to `bf16_gemm_nt`.
static void bf16_gemm_tt(const torch::Tensor& a,
                         const torch::Tensor& b,
                         const torch::Tensor& d,
                         const std::optional<torch::Tensor>& c,
                         const std::string& compiled_dims) {
    const auto a_transposed = a.transpose(0, 1);
    bf16_gemm_nt(a_transposed, b, d, c, compiled_dims);
}
// M-grouped contiguous BF16 GEMM, NT layout: `[M, K] @ [G, N, K].mT` into a BF16 `d` of shape `[M, N]`.
// `a` must be K-major; the majorness of `b` is detected and forwarded without being constrained here.
// `grouped_layout` must be a contiguous 1D `torch::kInt` tensor holding `num_groups` entries when
// `use_psum_layout` is set, and `M` entries otherwise (in which case `expected_m_for_psum_layout` must be empty).
static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b,
                                              const torch::Tensor& d, const torch::Tensor& grouped_layout,
                                              const std::string& compiled_dims,
                                              const bool& use_psum_layout,
                                              const std::optional<int>& expected_m_for_psum_layout) {
    // Majorness of the operands
    const auto major_a = get_major_type_ab(a);
    const auto major_b = get_major_type_ab(b);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    DG_HOST_ASSERT(grouped_layout.is_contiguous());

    // Shapes and dtypes
    const auto [m, k] = get_shape<2>(a);
    const auto [num_groups, n, k_] = get_shape<3>(b);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt);

    // Length of `grouped_layout` depends on the layout flavour
    if (not use_psum_layout) {
        const auto [m__] = get_shape<1>(grouped_layout);
        DG_HOST_ASSERT(m == m__);
        DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value());
    } else {
        const auto [num_groups_] = get_shape<1>(grouped_layout);
        DG_HOST_ASSERT(num_groups == num_groups_);
    }

    // D must be N-major
    check_major_type_cd(d);

    // Nothing to compute for an empty problem
    if (m == 0) {
        return;
    }

    // One implementation per supported architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
                                            num_groups, m, n, k, major_a, major_b, compiled_dims,
                                            use_psum_layout, expected_m_for_psum_layout);
        return;
    }
    if (arch_major == 10) {
        sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout,
                                             num_groups, m, n, k, major_a, major_b, compiled_dims,
                                             use_psum_layout, expected_m_for_psum_layout);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// NN flavour of the M-grouped contiguous BF16 GEMM.
// Swaps dims 1 and 2 of `b` and forwards to the NT implementation;
// `expected_m_for_psum_layout` is always passed as `std::nullopt`.
static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b,
                                              const torch::Tensor& d, const torch::Tensor& grouped_layout,
                                              const std::string& compiled_dims,
                                              const bool& use_psum_layout) {
    const auto b_transposed = b.transpose(1, 2);
    m_grouped_bf16_gemm_nt_contiguous(a, b_transposed, d, grouped_layout,
                                      compiled_dims, use_psum_layout, std::nullopt);
}
// M-grouped masked BF16 GEMM, NT layout: `[G, M, K] @ [G, N, K].mT` into a BF16 `d` of shape `[G, M, N]`.
// Both operands must be K-major. `masked_m` must be a contiguous `torch::kInt` tensor with one element
// per group; `expected_m` must be positive and is forwarded unchanged to the kernel.
static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b,
                                          const torch::Tensor& d, const torch::Tensor& masked_m,
                                          const int& expected_m, const std::string& compiled_dims) {
    // Majorness of the operands
    const auto major_a = get_major_type_ab(a);
    const auto major_b = get_major_type_ab(b);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(masked_m.is_contiguous());

    // Group counts and per-group shapes must agree across A, B, D and the mask
    const auto [num_groups, m, k] = get_shape<3>(a);
    const auto [num_groups_, n, k_] = get_shape<3>(b);
    const auto [num_groups__, m_, n_] = get_shape<3>(d);
    const auto num_groups___ = static_cast<int>(masked_m.numel());
    DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);
    DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0);
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt);

    // D must be N-major
    check_major_type_cd(d);

    // One implementation per supported architecture
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m,
                                        num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    if (arch_major == 10) {
        sm100_m_grouped_bf16_gemm_masked(a, b, d, masked_m,
                                         num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
// K-grouped contiguous BF16 GEMM, TN layout.
// `a` is `[sum(ks), M]`, `b` is `[sum(ks), N]` and `d` is `[G, M, N]`, all contiguous.
// NOTES: although `c` is an optional, the contiguity assertion below rejects `std::nullopt`.
static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a,
                                              const torch::Tensor& b,
                                              const torch::Tensor& d,
                                              const std::vector<int>& ks,
                                              const torch::Tensor& ks_tensor,
                                              const std::optional<torch::Tensor>& c,
                                              const std::string& compiled_dims) {
    // The leading dimension of A and B must equal the sum of all group sizes
    const auto [num_groups, m, n] = get_shape<3>(d);
    const auto [sum_k_ , m_] = get_shape<2>(a);
    const auto [sum_k__, n_] = get_shape<2>(b);
    const int sum_k = std::accumulate(ks.begin(), ks.end(), 0);
    DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__);

    // Everything has to be densely packed
    DG_HOST_ASSERT(a.is_contiguous());
    DG_HOST_ASSERT(b.is_contiguous());
    DG_HOST_ASSERT(d.is_contiguous());
    DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous());

    // Trivial problems are resolved by `early_return`
    if (early_return(m, n, sum_k, d, c)) {
        return;
    }

    // One implementation per supported architecture; both operands are passed as MN-major
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        sm90_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor,
                                 cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
        return;
    }
    if (arch_major == 10) {
        sm100_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor,
                                  cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
#endif
// cuBLASLt GEMM wrapper, NT layout: `[M, K] @ [N, K].T` into `d` of shape `[M, N]`.
// Only shapes are validated here; the last argument of `cublaslt_gemm` is whether `c` was supplied.
static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    // Majorness of the operands is detected and passed through
    const auto major_a = get_major_type_ab(a);
    const auto major_b = get_major_type_ab(b);

    // Shapes must agree
    const auto [m , k ] = get_shape<2>(a);
    const auto [n , k_] = get_shape<2>(b);
    const auto [m_, n_] = get_shape<2>(d);
    DG_HOST_ASSERT(m == m_ and n == n_ and k == k_);

    // Trivial problems are resolved by `early_return`
    if (early_return(m, n, k, d, c)) {
        return;
    }

    const bool has_c = c.has_value();
    cublaslt_gemm(a, b, d, m, n, k, major_a, major_b, has_c);
}
// NN flavour: hands a transposed view of `b` to `cublaslt_gemm_nt`.
static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a, b_transposed, d, c);
}
// TN flavour: hands transposed views of both `a` and `b` to `cublaslt_gemm_nt`.
static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto a_transposed = a.transpose(0, 1);
    const auto b_transposed = b.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b_transposed, d, c);
}
// TT flavour: hands a transposed view of `a` (and `b` as-is) to `cublaslt_gemm_nt`.
static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
                             const torch::Tensor& d, const std::optional<torch::Tensor>& c) {
    const auto a_transposed = a.transpose(0, 1);
    cublaslt_gemm_nt(a_transposed, b, d, c);
}
// Registers the GEMM entry points of this namespace on the pybind11 module `m`.
// The FP8/FP4 bindings require both `DG_FP8_COMPATIBLE` and `DG_TENSORMAP_COMPATIBLE`, the BF16 bindings
// require `DG_TENSORMAP_COMPATIBLE`, and the cuBLASLt bindings are registered unconditionally.
// Every `py::arg(...)` below fixes the Python keyword name and, where given, the default value;
// the `nt`/`nn` variants default `compiled_dims` to "nk" while the `tn`/`tt` and k-grouped variants default to "mn".
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// FP8 FP4 GEMMs
m.def("fp8_fp4_gemm_nt", &fp8_fp4_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_fp4_gemm_nn", &fp8_fp4_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_fp4_gemm_tn", &fp8_fp4_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
m.def("fp8_fp4_gemm_tt", &fp8_fp4_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "mn",
py::arg("disable_ue8m0_cast") = false);
// M-grouped FP8/FP4 variants (contiguous NT/NN and masked NT)
m.def("m_grouped_fp8_fp4_gemm_nt_contiguous", &m_grouped_fp8_fp4_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false,
py::arg("use_psum_layout") = false,
py::arg("expected_m_for_psum_layout") = std::nullopt);
// NOTES: the NN variant does not expose `expected_m_for_psum_layout`
m.def("m_grouped_fp8_fp4_gemm_nn_contiguous", &m_grouped_fp8_fp4_gemm_nn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk",
py::arg("disable_ue8m0_cast") = false,
py::arg("use_psum_layout") = false);
m.def("m_grouped_fp8_fp4_gemm_nt_masked", &m_grouped_fp8_fp4_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt,
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
// K-grouped FP8 variants
// NOTES: `c` defaults to `std::nullopt` here, but both implementations assert `c.has_value()`
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("recipe") = std::make_tuple(1, 1, 128),
py::arg("compiled_dims") = "mn");
// FP8 GEMM alias names
// NOTES: each alias re-exports the attribute registered above, so these must stay after the `m.def` calls
m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt");
m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn");
m.attr("fp8_gemm_tn") = m.attr("fp8_fp4_gemm_tn");
m.attr("fp8_gemm_tt") = m.attr("fp8_fp4_gemm_tt");
m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous");
m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous");
m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked");
#endif
#if DG_TENSORMAP_COMPATIBLE
// BF16 GEMMs
m.def("bf16_gemm_nt", &bf16_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt,
py::arg("compiled_dims") = "nk");
m.def("bf16_gemm_nn", &bf16_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt,
py::arg("compiled_dims") = "nk");
m.def("bf16_gemm_tn", &bf16_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt,
py::arg("compiled_dims") = "mn");
m.def("bf16_gemm_tt", &bf16_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"),
py::arg("c") = std::nullopt,
py::arg("compiled_dims") = "mn");
// M-grouped and K-grouped BF16 variants
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("compiled_dims") = "nk",
py::arg("use_psum_layout") = false,
py::arg("expected_m_for_psum_layout") = std::nullopt);
m.def("m_grouped_bf16_gemm_nn_contiguous", &m_grouped_bf16_gemm_nn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"),
py::arg("compiled_dims") = "nk",
py::arg("use_psum_layout") = false);
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
m.def("k_grouped_bf16_gemm_tn_contiguous", &k_grouped_bf16_gemm_tn_contiguous,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
py::arg("compiled_dims") = "mn");
#endif
// cuBLASLt GEMMs
m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
}
} // namespace deep_gemm::gemm

View File

@@ -0,0 +1,70 @@
#pragma once
#include "../utils/compatibility.hpp"
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp"
#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp"
#endif
namespace deep_gemm::hyperconnection {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// Hyperconnection pre-norm GEMM wrapper.
// `a` is BF16 `[M, K]` and `b` is FP32 `[N, K]`, both K-major; `d` and `sqr_sum` are FP32.
// Without `num_splits`, `d` is `[M, N]` and `sqr_sum` is `[M]`; with it, both gain a leading
// dimension of size `num_splits` (which must be at least 1).
// The kernels receive `num_splits`, or 1 when it is not provided.
static void tf32_hc_prenorm_gemm(const torch::Tensor& a,
                                 const torch::Tensor& b,
                                 const torch::Tensor& d,
                                 const torch::Tensor& sqr_sum,
                                 const std::optional<int>& num_splits) {
    // A and B must be K-major, D must be N-major
    DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K);
    DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K);
    check_major_type_cd(d);

    // S must be contiguous
    DG_HOST_ASSERT(sqr_sum.is_contiguous());

    // Shapes: the output rank depends on whether a split count was given
    const auto [m, k ] = get_shape<2>(a);
    const auto [n, k_] = get_shape<2>(b);
    if (not num_splits.has_value()) {
        const auto [m_, n_] = get_shape<2>(d);
        const auto [m__] = get_shape<1>(sqr_sum);
        DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    } else {
        const auto [num_splits_, m_, n_] = get_shape<3>(d);
        const auto [num_splits__, m__] = get_shape<2>(sqr_sum);
        DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1);
        DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_);
    }
    DG_HOST_ASSERT(n > 0 and k > 0);

    // Dtypes
    DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(b.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat);

    // Nothing to compute for an empty problem
    if (m == 0) {
        return;
    }

    // One implementation per supported architecture
    const auto arch_major = device_runtime->get_arch_major();
    const int effective_num_splits = num_splits.value_or(1);
    if (arch_major == 9) {
        sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, effective_num_splits);
        return;
    }
    if (arch_major == 10) {
        sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, effective_num_splits);
        return;
    }
    DG_HOST_UNREACHABLE("Unsupported architecture");
}
#endif
// Registers the hyperconnection entry points on the pybind11 module `m`.
// The function body is empty unless both `DG_FP8_COMPATIBLE` and `DG_TENSORMAP_COMPATIBLE` are true.
static void register_apis(pybind11::module_& m) {
#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE
// `num_splits` is optional on the Python side and defaults to `std::nullopt`
m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm,
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"),
py::arg("num_splits") = std::nullopt);
#endif
}
} // namespace deep_gemm::hyperconnection

View File

@@ -0,0 +1,143 @@
#pragma once
#include "../jit_kernels/heuristics/runtime.hpp"
#include "../utils/layout.hpp"
#include "../utils/compatibility.hpp"
#if DG_TENSORMAP_COMPATIBLE
#include "../jit_kernels/impls/smxx_layout.hpp"
#endif
namespace deep_gemm::layout {
#if DG_TENSORMAP_COMPATIBLE
// Converts one scaling-factor (SF) tensor into the layout selected by its dtype, its granularity
// and the device architecture.
// `recipe` is either `(gran_m, gran_n, gran_k)`, in which case `is_sfa` must be set and selects the first
// (SFA) or second (SFB) entry as the MN granularity, or `(gran_mn, gran_k)`, in which case `is_sfa` must be empty.
// Any combination not handled below ends in `DG_HOST_UNREACHABLE`.
static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf,
                                                       const int& mn, const int& k,
                                                       const std::variant<std::tuple<int, int, int>,
                                                                          std::tuple<int, int>>& recipe,
                                                       const std::optional<int>& num_groups,
                                                       const std::optional<bool>& is_sfa,
                                                       const bool& disable_ue8m0_cast) {
    const auto arch_major = device_runtime->get_arch_major();

    // Resolve the MN/K granularity from whichever recipe alternative is held
    int gran_mn, gran_k;
    if (const auto* joint_recipe = std::get_if<std::tuple<int, int, int>>(&recipe)) {
        DG_HOST_ASSERT(is_sfa.has_value());
        gran_mn = is_sfa.value() ? std::get<0>(*joint_recipe) : std::get<1>(*joint_recipe);
        gran_k = std::get<2>(*joint_recipe);
    } else if (const auto* single_recipe = std::get_if<std::tuple<int, int>>(&recipe)) {
        DG_HOST_ASSERT(not is_sfa.has_value());
        gran_mn = std::get<0>(*single_recipe);
        gran_k = std::get<1>(*single_recipe);
    } else {
        DG_HOST_UNREACHABLE("Invalid recipe");
    }

    // Pre-transform checks
    check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups);

    const bool is_fp32 = sf.scalar_type() == torch::kFloat;
    const bool is_int = sf.scalar_type() == torch::kInt;
    const bool gran_k_is_32_or_128 = gran_k == 32 or gran_k == 128;

    // FP32 SF with `gran_k == 128`, on SM90 or whenever the UE8M0 cast is disabled
    if (is_fp32 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) {
        // (FP32, 1, 128): transform to TMA-aligned and MN-major
        if (gran_mn == 1)
            return get_mn_major_tma_aligned_tensor(sf);
        // (FP32, 128, 128): no need to transform, check SFB requirements
        if (gran_mn == 128)
            return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat);
    }

    // (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major
    if (is_fp32 and gran_k_is_32_or_128 and arch_major == 10) {
        DG_HOST_ASSERT(not disable_ue8m0_cast);
        if (gran_mn == 1)
            return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf);
        // Repeat each SF row `gran_mn` times along dim -2: output row `i` reads source row `i / gran_mn`
        const auto source_rows = torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn);
        return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf.index_select(-2, source_rows));
    }

    // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major
    if (is_int and gran_mn == 1 and gran_k_is_32_or_128 and arch_major == 10)
        return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt);

    DG_HOST_UNREACHABLE("Unknown SF transformation");
}
// Transforms the SFA/SFB pair and returns `(sfa, sfb, gran_k_a, gran_k_b)`.
// Callers supply either the joint `recipe` or the `recipe_a` + `recipe_b` pair, never both.
// NOTES: `recipe` is a mutable reference; when neither `recipe` nor `recipe_a` is set it is overwritten
// with `get_default_recipe(...)`, which is visible to the caller.
static std::tuple<torch::Tensor, torch::Tensor, int, int> transform_sf_pair_into_required_layout(
        const torch::Tensor& sfa, const torch::Tensor& sfb,
        const int& m, const int& n, const int& k,
        std::optional<std::tuple<int, int, int>>& recipe,
        const std::optional<std::tuple<int, int>>& recipe_a,
        const std::optional<std::tuple<int, int>>& recipe_b,
        const std::optional<int>& num_groups_a,
        const std::optional<int>& num_groups_b,
        const bool& disable_ue8m0_cast = false) {
    // Use default recipe, if none is specified
    if (not recipe_a.has_value() and not recipe.has_value()) {
        recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type());
    }

    // Must be either 'recipe' or the 'recipe_a' + 'recipe_b' pair.
    DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value());
    DG_HOST_ASSERT(recipe_a.has_value() != recipe.has_value());

    // Transform SFA first, then SFB, using the joint recipe when it is present
    torch::Tensor transformed_sfa;
    if (recipe.has_value()) {
        transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe.value(), num_groups_a, true, disable_ue8m0_cast);
    } else {
        transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe_a.value(), num_groups_a, std::nullopt, disable_ue8m0_cast);
    }
    torch::Tensor transformed_sfb;
    if (recipe.has_value()) {
        transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe.value(), num_groups_b, false, disable_ue8m0_cast);
    } else {
        transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe_b.value(), num_groups_b, std::nullopt, disable_ue8m0_cast);
    }

    // K granularity: second entry of a per-operand recipe, third entry of the joint one
    const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value());
    const int gran_k_b = recipe_b.has_value() ? std::get<1>(recipe_b.value()) : std::get<2>(recipe.value());
    return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b);
}
// Transforms a 2D scaling-factor tensor for the K-grouped kernels.
// `recipe` must be `(1, 1, gran_k)` with `gran_k` being 32 or 128. Only FP32 input is handled
// (on SM90 and SM100); every other dtype/architecture combination ends in `DG_HOST_UNREACHABLE`.
static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf,
                                                                 const std::vector<int>& ks,
                                                                 const torch::Tensor& ks_tensor,
                                                                 const std::tuple<int, int, int>& recipe) {
    DG_HOST_ASSERT(sf.dim() == 2);
    DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1);
    const int gran_k = std::get<2>(recipe);
    DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);

    const auto arch_major = device_runtime->get_arch_major();
    const auto sf_dtype = sf.scalar_type();

    if (sf_dtype == torch::kFloat) {
        // FP32 on SM90
        if (arch_major == 9)
            return get_mn_major_tma_aligned_tensor(sf);
        // FP32 on SM100
        if (arch_major == 10)
            return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks, gran_k);
    }

    // INT on SM100
    if (sf_dtype == torch::kInt and arch_major == 10)
        DG_HOST_UNREACHABLE("Unimplemented");

    DG_HOST_UNREACHABLE("Unknown cases");
}
#endif
// Registers the layout utilities on the pybind11 module `m`.
// The SF/TMA helpers are only bound when `DG_TENSORMAP_COMPATIBLE` is true; the three
// `mk_alignment` accessors forward to `heuristics_runtime` and are always bound.
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
py::arg("num_groups") = std::nullopt,
py::arg("is_sfa") = std::nullopt,
py::arg("disable_ue8m0_cast") = false);
// Bound without `py::arg` names, so these take positional arguments only
m.def("get_tma_aligned_size", &get_tma_aligned_size);
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
#endif
// NOTES: the lambdas use `[&]` but reference no local of this function
m.def("set_mk_alignment_for_contiguous_layout", [&](const int& new_value) {
heuristics_runtime->set_mk_alignment_for_contiguous_layout(new_value);
});
m.def("get_mk_alignment_for_contiguous_layout", [&]() {
return heuristics_runtime->get_mk_alignment_for_contiguous_layout();
});
m.def("get_theoretical_mk_alignment_for_contiguous_layout", [&](const std::optional<int>& expected_m) {
return heuristics_runtime->get_theoretical_mk_alignment_for_contiguous_layout(expected_m);
}, py::arg("expected_m") = std::nullopt);
}
} // namespace deep_gemm::layout

235
third_party/DeepGEMM/csrc/apis/mega.hpp vendored Normal file
View File

@@ -0,0 +1,235 @@
#pragma once
#include <functional>
#include <pybind11/functional.h>
#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp"
namespace deep_gemm::mega {
// Returns the token alignment used by mega MoE, i.e. the value of `layout::kLCMCandidateBlockM`.
static int get_token_alignment_for_mega_moe() {
    const int token_alignment = layout::kLCMCandidateBlockM;
    return token_alignment;
}
// Computes the number of bytes required for the mega MoE symmetric buffer and returns,
// together with that size, a function slicing such a buffer into tensor views.
//
// All sub-buffers are laid out back to back starting from a null base pointer, so every
// `.base` / `get_end_ptr()` value below is a byte offset relative to the start of the buffer.
//
// Returns `(num_bytes, slice)`, where `slice(buffer)` yields
// `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)`,
// all of them non-owning views created by `torch::from_blob` on `buffer`'s memory.
//
// NOTES: `use_fp8_dispatch` and `activation` are accepted but not read by this function.
static std::tuple<int64_t, std::function<std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor>(const torch::Tensor&)>>
get_symm_buffer_size_for_mega_moe(
    const int& num_ranks, const int& num_experts,
    const int& num_max_tokens_per_rank, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const bool& use_fp8_dispatch, const std::string& activation) {
    // Reject a zero (or negative) rank count before it is used as a divisor:
    // an integer modulo by zero would crash the process instead of raising an assertion
    DG_HOST_ASSERT(num_ranks > 0);
    DG_HOST_ASSERT(num_experts % num_ranks == 0);

    // Workspace bytes
    const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk);

    // Layouts
    // NOTE(review): the first argument looks like a per-token byte count (the views below use
    // `hidden / 128` int32 per token for a `hidden / 32` layout); the meaning of the trailing
    // `false` is defined by `layout::Data` and not visible here
    const auto fp8_token_layout = layout::Data(hidden);
    const auto bf16_token_layout = layout::Data(hidden * 2);
    const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden);
    const auto fp8_sf_layout = layout::Data(hidden / 32);
    const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32);
    const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false);
    const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false);
    const auto l1_topk_weights_layout = layout::Data(sizeof(float), false);

    // Input buffers, placed right after the workspace
    const auto input_token_buffer = layout::Buffer(
        fp8_token_layout, 1, num_max_tokens_per_rank,
        workspace.get_end_ptr());
    const auto input_sf_buffer = layout::Buffer(
        fp8_sf_layout, 1, num_max_tokens_per_rank,
        input_token_buffer.get_end_ptr());
    const auto input_topk_idx_buffer = layout::Buffer(
        input_topk_idx_layout, 1, num_max_tokens_per_rank,
        input_sf_buffer.get_end_ptr());
    const auto input_topk_weights_buffer = layout::Buffer(
        input_topk_weights_layout, 1, num_max_tokens_per_rank,
        input_topk_idx_buffer.get_end_ptr());

    // Buffer configs
    // The SF pool must be large enough for whichever candidate block M needs the most padding
    const auto num_max_pool_tokens = static_cast<int>(workspace.num_max_pool_tokens);
    int num_max_padded_sf_pool_tokens = 0;
    for (int block_m: layout::kCandidateBlockM) {
        num_max_padded_sf_pool_tokens = std::max(
            num_max_padded_sf_pool_tokens,
            layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m)
        );
    }

    // L1 input buffer
    const auto l1_token_buffer = layout::Buffer(
        fp8_token_layout, 1, num_max_pool_tokens,
        input_topk_weights_buffer.get_end_ptr());
    const auto l1_sf_buffer = layout::Buffer(
        fp8_sf_layout, 1, num_max_padded_sf_pool_tokens,
        l1_token_buffer.get_end_ptr());
    const auto l1_topk_weights_buffer = layout::Buffer(
        l1_topk_weights_layout, 1, num_max_pool_tokens,
        l1_sf_buffer.get_end_ptr());

    // L2 input buffer
    const auto l2_token_buffer = layout::Buffer(
        fp8_intermediate_token_layout, 1, num_max_pool_tokens,
        l1_topk_weights_buffer.get_end_ptr());
    const auto l2_sf_buffer = layout::Buffer(
        fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens,
        l2_token_buffer.get_end_ptr());

    // Combine input buffer: BF16 tokens for cross-rank combine
    const auto combine_token_buffer = layout::Buffer(
        bf16_token_layout, num_topk, num_max_tokens_per_rank,
        l2_sf_buffer.get_end_ptr());

    // Check SF buffer requirements
    DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0);
    DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0);

    // Slice function: creates `(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer
    // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major
    // NOTES: everything is captured by value, so the returned function stays valid after this call returns
    auto slice_input_buffers = [=](const torch::Tensor& buffer) {
        auto x = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_token_buffer.base)),
            {num_max_tokens_per_rank, hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        auto x_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_sf_buffer.base)),
            {num_max_tokens_per_rank, hidden / 128},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        auto topk_idx = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_idx_buffer.base)),
            {num_max_tokens_per_rank, num_topk},
            torch::TensorOptions().dtype(torch::kInt64).device(buffer.device()));
        auto topk_weights = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(input_topk_weights_buffer.base)),
            {num_max_tokens_per_rank, num_topk},
            torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device()));
        auto l1_acts = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_token_buffer.base)),
            {num_max_pool_tokens, hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        // Strides `{1, num_max_padded_sf_pool_tokens}` make the token dimension contiguous (M-major)
        auto l1_acts_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l1_sf_buffer.base)),
            {num_max_padded_sf_pool_tokens, hidden / 128},
            {1, num_max_padded_sf_pool_tokens},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        auto l2_acts = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_token_buffer.base)),
            {num_max_pool_tokens, intermediate_hidden},
            torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device()));
        auto l2_acts_sf = torch::from_blob(
            math::advance_ptr(buffer.data_ptr(), reinterpret_cast<int64_t>(l2_sf_buffer.base)),
            {num_max_padded_sf_pool_tokens, intermediate_hidden / 128},
            {1, num_max_padded_sf_pool_tokens},
            torch::TensorOptions().dtype(torch::kInt).device(buffer.device()));
        return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf);
    };

    // The end offset of the last sub-buffer is the total number of required bytes
    return {reinterpret_cast<int64_t>(combine_token_buffer.get_end_ptr()), slice_input_buffers};
}
// Host entry point of the mega MoE kernel: validates the arguments, slices the symmetric
// buffer into activation views and dispatches to the SM100 implementation.
//
// - `y`: output tensor; its first dimension is taken as the number of tokens
// - `l1_weights_tuple` / `l2_weights_tuple`: `(weights, weights_sf)` pairs of the two layers
// - `cumulative_local_expert_recv_stats`: optional contiguous int32 counter with one entry per local expert
// - `sym_buffer`: the symmetric buffer, at least as large as `get_symm_buffer_size_for_mega_moe` reports
// - `sym_buffer_ptrs`: one entry per rank (its size is used as the number of ranks)
// - `recipe`: must be `(1, 1, 32)`; `activation`: must be `"swiglu"`
// - `activation_clamp_opt`: non-negative clamp, treated as +infinity when absent
//
// Any violated requirement trips a `DG_HOST_ASSERT`; architectures other than major 10 are rejected.
static void fp8_fp4_mega_moe(
const torch::Tensor& y,
const std::tuple<torch::Tensor, torch::Tensor>& l1_weights_tuple,
const std::tuple<torch::Tensor, torch::Tensor>& l2_weights_tuple,
const std::optional<torch::Tensor>& cumulative_local_expert_recv_stats,
const torch::Tensor& sym_buffer,
const std::vector<int64_t>& sym_buffer_ptrs, const int& rank_idx,
const int& num_max_tokens_per_rank,
const int& num_experts, const int& num_topk,
const std::tuple<int, int, int>& recipe,
const std::string& activation,
const std::optional<float>& activation_clamp_opt,
const bool& fast_math
) {
const auto [l1_weights, l1_weights_sf] = l1_weights_tuple;
const auto [l2_weights, l2_weights_sf] = l2_weights_tuple;
// Config checks
const auto num_tokens = static_cast<int>(y.size(0));
const auto [rm, rn, rk] = recipe;
DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32);
DG_HOST_ASSERT(activation == "swiglu");
// Activation checks
// A missing clamp becomes +infinity, which also passes the non-negativity check below
const auto activation_clamp =
activation_clamp_opt.value_or(std::numeric_limits<float>::infinity());
DG_HOST_ASSERT(activation_clamp >= 0);
// Tensor checks
DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K);
DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K);
const auto arch_major = device_runtime->get_arch_major();
// Both weight tensors yield `(num_experts_per_rank, N, K)`; the two layers must agree with each other,
// with L1 producing twice the intermediate width consumed by L2
const auto [num_experts_per_rank, intermediate_hidden_2, hidden] =
check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major);
const auto [num_experts_per_rank_, hidden_, intermediate_hidden] =
check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major);
DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank);
DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_);
DG_HOST_ASSERT(hidden == hidden_);
DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden);
DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous());
// Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment
constexpr int kGranMN = 1, kGranK = 32;
check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK,
num_experts_per_rank, true, false, torch::kInt);
// Check stats counter
if (cumulative_local_expert_recv_stats.has_value()) {
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt);
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank);
DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous());
}
// Check buffer bytes
// The number of ranks is the number of symmetric buffer pointers; the caller-provided
// `num_experts` is passed to the size computation and only afterwards cross-checked
// against `num_experts_per_rank * num_ranks`
const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
const auto num_experts_ = num_experts_per_rank * num_ranks;
const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe(
num_ranks, num_experts,
num_max_tokens_per_rank, num_topk,
hidden, intermediate_hidden,
true, activation);
DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast<size_t>(num_required_bytes));
DG_HOST_ASSERT(num_experts == num_experts_);
// Already registered tensors
// NOTES: only the `l1_acts*` and `l2_acts*` views are handed to the kernel call below;
// `x`, `x_sf`, `topk_idx` and `topk_weights` are not passed explicitly
const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer);
// Dispatch into different architectures
if (arch_major == 10) {
sm100_fp8_fp4_mega_moe(y,
l1_acts, l1_acts_sf,
l2_acts, l2_acts_sf,
l1_weights, l2_weights,
l1_weights_sf, l2_weights_sf,
cumulative_local_expert_recv_stats,
sym_buffer_ptrs,
rank_idx, num_max_tokens_per_rank,
num_experts_per_rank,
num_tokens, num_topk,
hidden, intermediate_hidden,
activation_clamp, fast_math);
} else {
DG_HOST_UNREACHABLE("Unsupported architecture");
}
// Zero the entire symmetric buffer for debug mode
// NOTES: caller must re-copy inputs into the buffer before each kernel call
if (get_env<int>("DG_COMM_KERNEL_DEBUG"))
sym_buffer.zero_();
}
// Registers the mega MoE entry points on the Python module `m`.
// Nothing is registered unless `DG_TENSORMAP_COMPATIBLE` is true at preprocessing time.
static void register_apis(pybind11::module_& m) {
#if DG_TENSORMAP_COMPATIBLE
    // `def` returns the module itself, so the registrations can be chained
    m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe)
     .def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe)
     .def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe);
#endif
}
} // namespace deep_gemm::mega

View File

@@ -0,0 +1,51 @@
#pragma once
#if DG_TENSORMAP_COMPATIBLE
#include "../jit/compiler.hpp"
#endif
#include "../jit/device_runtime.hpp"
#include "../jit_kernels/heuristics/runtime.hpp"
namespace deep_gemm::runtime {
// Registers the runtime configuration API on the Python module `m`:
// getters/setters forwarding to `device_runtime` and `heuristics_runtime`, plus `init`.
static void register_apis(pybind11::module_& m) {
    // Device runtime knobs: SM count, tensor core utilization and PDL
    m.def("set_num_sms", [&](const int& value) {
        device_runtime->set_num_sms(value);
    })
    .def("get_num_sms", [&]() {
        return device_runtime->get_num_sms();
    })
    .def("set_tc_util", [&](const int& value) {
        device_runtime->set_tc_util(value);
    })
    .def("get_tc_util", [&]() {
        return device_runtime->get_tc_util();
    })
    .def("set_pdl", [&](const bool& value) {
        device_runtime->set_pdl(value);
    })
    .def("get_pdl", [&]() {
        return device_runtime->get_pdl();
    });

    // Heuristics knobs
    m.def("set_ignore_compile_dims", [&](const bool& value) {
        heuristics_runtime->set_ignore_compile_dims(value);
    })
    .def("set_block_size_multiple_of", [&](const std::variant<int, std::tuple<int, int>>& new_value) {
        // A single integer applies to both values; a pair sets them separately
        if (const auto* single = std::get_if<int>(&new_value)) {
            auto x = *single;
            heuristics_runtime->set_block_size_multiple_of(x, x);
            return;
        }
        auto [x, y] = std::get<std::tuple<int, int>>(new_value);
        heuristics_runtime->set_block_size_multiple_of(x, y);
    });

    // JIT initialization: a no-op unless `DG_TENSORMAP_COMPATIBLE` is true
    m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
#if DG_TENSORMAP_COMPATIBLE
        Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
        KernelRuntime::prepare_init(cuda_home_path_by_python);
        IncludeParser::prepare_init(library_root_path);
#endif
    });
}
} // namespace deep_gemm::runtime

View File

@@ -0,0 +1,35 @@
// GEMM kernels
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
#include <deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh>
// Attention kernels
#include <deep_gemm/impls/sm90_fp8_mqa_logits.cuh>
#include <deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp4_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp8_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh>
#include <deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh>
// Einsum kernels
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
// Hyperconnection kernels
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh>
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh>
// Layout kernels
#include <deep_gemm/impls/smxx_layout.cuh>
#include <deep_gemm/impls/smxx_clean_logits.cuh>
// Mega kernels
#include <deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh>
using namespace deep_gemm;
// Empty entry point: the program does nothing at run time
// (presumably this file only exists so that all kernel headers above get compiled; confirm with the build scripts)
// NOTES: flowing off the end of `main` is equivalent to `return 0;`
int main() {}

31
third_party/DeepGEMM/csrc/jit/cache.hpp vendored Normal file
View File

@@ -0,0 +1,31 @@
#pragma once
#include <filesystem>
#include <memory>
#include <unordered_map>
#include "kernel_runtime.hpp"
namespace deep_gemm {
// In-memory cache of loaded kernel runtimes, keyed by the kernel's cache directory path
class KernelRuntimeCache {
    std::unordered_map<std::string, std::shared_ptr<KernelRuntime>> cache;

public:
    // TODO: consider cache capacity
    KernelRuntimeCache() = default;

    // Returns the runtime for `dir_path`:
    //  - the cached instance if there is one,
    //  - otherwise a newly created (and cached) one if `KernelRuntime::check_validity` accepts the directory,
    //  - otherwise `nullptr`
    std::shared_ptr<KernelRuntime> get(const std::filesystem::path& dir_path) {
        // Hit the runtime cache
        const auto iterator = cache.find(dir_path);
        if (iterator != cache.end())
            return iterator->second;

        // Nothing usable on disk
        if (not KernelRuntime::check_validity(dir_path))
            return nullptr;

        // Create first, then insert, so a throwing constructor leaves no entry behind
        auto runtime = std::make_shared<KernelRuntime>(dir_path);
        cache[dir_path] = runtime;
        return runtime;
    }
};

// Cache instance used by the compilers
static auto kernel_runtime_cache = std::make_shared<KernelRuntimeCache>();
} // namespace deep_gemm

View File

@@ -0,0 +1,362 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <fcntl.h>
#include <filesystem>
#include <fstream>
#include <nvrtc.h>
#include <regex>
#include <string>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/hash.hpp"
#include "../utils/lazy_init.hpp"
#include "../utils/system.hpp"
#include "cache.hpp"
#include "device_runtime.hpp"
#include "include_parser.hpp"
namespace deep_gemm {
// Base class of the JIT compilers.
// It owns the on-disk kernel cache (`<cache_dir_path>/cache/kernel.<name>.<digest>`), the common
// compiler flags, and the build pipeline; derived classes implement `compile`.
class Compiler {
public:
    // Toolchain locations shared by all compilers, filled in by `prepare_init`
    static std::filesystem::path library_root_path;
    static std::filesystem::path library_include_path;
    static std::filesystem::path cuda_home;
    static std::filesystem::path cuobjdump_path;

    // Records the library root and the CUDA home, and derives `<root>/include` and `<cuda_home>/bin/cuobjdump`
    // NOTES: must be called before constructing any compiler, as the constructor asserts on these paths
    static void prepare_init(const std::string& library_root_path,
                             const std::string& cuda_home_path_by_python) {
        Compiler::library_root_path = library_root_path;
        Compiler::library_include_path = Compiler::library_root_path / "include";
        Compiler::cuda_home = cuda_home_path_by_python;
        Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump";
    }

    // `signature` identifies the compiler and `flags` its options; both are part of the cache key
    std::string signature, flags;
    std::filesystem::path cache_dir_path;

    Compiler() {
        // Check `prepare_init`
        DG_HOST_ASSERT(not library_root_path.empty());
        DG_HOST_ASSERT(not library_include_path.empty());
        DG_HOST_ASSERT(not cuda_home.empty());
        DG_HOST_ASSERT(not cuobjdump_path.empty());

        // Cache settings: `$HOME/.deep_gemm` unless `DG_JIT_CACHE_DIR` is set
        cache_dir_path = std::filesystem::path(get_env<std::string>("HOME")) / ".deep_gemm";
        if (const auto env_cache_dir_path = get_env<std::string>("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty())
            cache_dir_path = env_cache_dir_path;

        // The compiler flags applied to all derived compilers
        signature = "unknown-compiler";
        flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 "
                            "--ptxas-options=--register-usage-level=10",
                            get_env<int>("DG_JIT_CPP_STANDARD", 20));
        if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0) or get_env("DG_JIT_PTXAS_CHECK", 0))
            flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage";
        if (get_env("DG_JIT_WITH_LINEINFO", 0))
            flags += " -Xcompiler -rdynamic -lineinfo";
    }

    virtual ~Compiler() = default;

    // Returns `<cache_dir_path>/tmp` as produced by `make_dirs`
    std::filesystem::path make_tmp_dir() const {
        return make_dirs(cache_dir_path / "tmp");
    }

    // Best-effort fsync of a single file or directory: a path that cannot be opened is silently skipped
    static void fsync_path(const std::filesystem::path& path) {
        const auto fd = ::open(path.c_str(), O_RDONLY);
        if (fd >= 0) {
            ::fsync(fd);
            ::close(fd);
        }
    }

    // Recursively fsync a directory: files and subdirectories first (bottom-up), then the directory itself
    // NOTES: ensures data and directory entries are visible on other nodes in distributed filesystems
    static void fsync_dir(const std::filesystem::path& dir_path) { // NOLINT(*-no-recursion)
        for (const auto& entry: std::filesystem::directory_iterator(dir_path)) {
            if (entry.is_directory())
                fsync_dir(entry.path());
            else if (entry.is_regular_file())
                fsync_path(entry.path());
        }
        fsync_path(dir_path);
    }

    // Writes `data` to `path` in binary mode and fsyncs the file
    // Asserts if the data cannot be written completely
    static void put(const std::filesystem::path& path, const std::string& data) {
        std::ofstream out(path, std::ios::binary);
        DG_HOST_ASSERT(out.write(data.data(), data.size()));

        // NOTES: `write` may only fill the stream buffer, so I/O errors (e.g., a full disk) can
        // surface as late as the flush done by `close`; without this check a truncated file
        // would be synced and later renamed into the kernel cache
        out.close();
        DG_HOST_ASSERT(not out.fail() and "Failed to flush the file");

        // NOTES: fsync to ensure the data is visible to other processes (e.g., NVCC)
        // on distributed filesystems, where `close()` alone does not guarantee persistence
        fsync_path(path);
    }

    // Returns the runtime of kernel `name` with source `code`, compiling it on a cache miss
    // The cache key is a digest over the kernel name, the compiler signature, the flags and the code
    std::shared_ptr<KernelRuntime> build(const std::string& name, const std::string& code) const {
        const auto kernel_signature = fmt::format("{}$${}$${}$${}", name, signature, flags, code);
        const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature));

        // Hit the runtime cache
        if (const auto runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr)
            return runtime;

        // Compile into a temporary directory, then atomically rename the whole directory
        // NOTES: renaming a directory is atomic on both local and distributed filesystems,
        // avoiding the stale inode issue that occurs when renaming individual files
        const auto tmp_dir_path = make_tmp_dir() / get_uuid();
        make_dirs(tmp_dir_path);

        // Compile into the temporary directory
        const auto tmp_cubin_path = tmp_dir_path / "kernel.cubin";
        if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_PTX")) {
            const auto tmp_ptx_path = tmp_dir_path / "kernel.ptx";
            compile(code, tmp_dir_path, tmp_cubin_path, tmp_ptx_path);
        } else {
            compile(code, tmp_dir_path, tmp_cubin_path);
        }

        // Disassemble if needed
        if (get_env<int>("DG_JIT_DUMP_ASM") or get_env<int>("DG_JIT_DUMP_SASS")) {
            const auto tmp_sass_path = tmp_dir_path / "kernel.sass";
            disassemble(tmp_cubin_path, tmp_sass_path);
        }

        // Fsync before rename to ensure visibility on distributed filesystems
        fsync_dir(tmp_dir_path);

        // Atomically rename the temporary directory to the final cache path
        // NOTES: if another rank already created dir_path, rename will fail — that's fine
        make_dirs(dir_path.parent_path());
        std::error_code error_code;
        std::filesystem::rename(tmp_dir_path, dir_path, error_code);
        if (error_code) {
            // Another rank beat us, then clean up our dir and use the existing one
            // NOTES: avoid `std::filesystem::remove_all` here — it can segfault on
            // distributed filesystems, when concurrent processes operate
            // on the same parent directory, causing stale directory entries
            safe_remove_all(tmp_dir_path);
        }

        // Put into the runtime cache
        const auto runtime = kernel_runtime_cache->get(dir_path);
        DG_HOST_ASSERT(runtime != nullptr);
        return runtime;
    }

    // Dumps the SASS of `cubin_path` into `sass_path` via `cuobjdump`, asserting on a non-zero exit code
    static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) {
        // Disassemble the CUBIN file to SASS
        const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str());
        if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0))
            printf("Running cuobjdump command: %s\n", command.c_str());
        const auto [return_code, output] = call_external_command(command);
        if (return_code != 0) {
            printf("cuobjdump failed: %s\n", output.c_str());
            DG_HOST_ASSERT(false and "cuobjdump failed");
        }
    }

    // Compiles `code` into `cubin_path` (and into `ptx_path` when given), using `dir_path` for the source file
    virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional<std::filesystem::path> &ptx_path = std::nullopt) const = 0;
};

DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home);
DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path);
// JIT compiler backed by the external `nvcc` executable
class NVCCCompiler final: public Compiler {
    std::filesystem::path nvcc_path;

    // Runs `nvcc --version` and returns `(major, minor)` parsed from its "release X.Y" text
    // Asserts that the executable exists, runs successfully and is at least version 12.3
    std::pair<int, int> get_nvcc_version() const {
        DG_HOST_ASSERT(std::filesystem::exists(nvcc_path));

        // Call the version command
        const auto version_command = std::string(nvcc_path) + " --version";
        const auto [return_code, output] = call_external_command(version_command);
        DG_HOST_ASSERT(return_code == 0);

        // Extract the release number
        std::smatch match;
        DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))")));
        int major, minor;
        std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor);

        // The version should be at least 12.3, for the best performance with 12.9
        DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3");
        const bool is_older_than_12_9 = (major == 12 and minor < 9);
        if (is_older_than_12_9)
            printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n");
        return std::make_pair(major, minor);
    }

public:
    NVCCCompiler() {
        // Locate NVCC: `<cuda_home>/bin/nvcc` unless `DG_JIT_NVCC_COMPILER` is set
        nvcc_path = cuda_home / "bin" / "nvcc";
        const auto env_nvcc_path = get_env<std::string>("DG_JIT_NVCC_COMPILER");
        if (not env_nvcc_path.empty())
            nvcc_path = env_nvcc_path;

        // Override the compiler signature
        const auto [nvcc_major, nvcc_minor] = get_nvcc_version();
        signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor);

        // Override the compiler flags
        // Only NVCC >= 12.9 supports arch-specific family suffix
        const bool support_arch_family = nvcc_major > 12 or nvcc_minor >= 9;
        const auto arch = device_runtime->get_arch(false, support_arch_family);
        flags = fmt::format("{} -I{} --gpu-architecture=sm_{} "
                            "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi "
                            "-O3 --expt-relaxed-constexpr --expt-extended-lambda",
                            flags, library_include_path.c_str(), arch);
    }

    // Writes `code` to `<dir_path>/kernel.cu`, compiles it into `cubin_path` and,
    // if `ptx_path` is given, compiles it a second time into PTX
    void compile(const std::string &code, const std::filesystem::path& dir_path,
                 const std::filesystem::path &cubin_path,
                 const std::optional<std::filesystem::path> &ptx_path) const override {
        // Write the code into the cache directory
        const auto code_path = dir_path / "kernel.cu";
        put(code_path, code);

        // Whether to echo the commands before running them
        const bool print_command = get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0);

        // Compile
        // Avoid cwd files shadowing C++ standard library headers
        const auto compile_dir = make_tmp_dir();
        const auto command = fmt::format("cd {} && {} {} -cubin -o {} {}",
                                         compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags);
        if (print_command)
            printf("Running NVCC command: %s\n", command.c_str());
        const auto [return_code, output] = call_external_command(command);
        if (return_code != 0) {
            printf("NVCC compilation failed: %s\n", output.c_str());
            DG_HOST_ASSERT(false and "NVCC compilation failed");
        }

        // Compile to PTX if needed
        if (ptx_path.has_value()) {
            const auto ptx_command = fmt::format("cd {} && {} {} -ptx -o {} {}",
                                                 compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags);
            if (print_command)
                printf("Running NVCC PTX command: %s\n", ptx_command.c_str());
            const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command);
            if (ptx_return_code != 0) {
                printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str());
                DG_HOST_ASSERT(false and "NVCC PTX compilation failed");
            }
        }

        // Check local memory usage (on the log of the CUBIN compilation)
        if (get_env("DG_JIT_PTXAS_CHECK", 0))
            DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)")));

        // Print PTXAS log
        if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0))
            printf("%s", output.c_str());
    }
};
// JIT compiler backed by the in-process NVRTC library
class NVRTCCompiler final: public Compiler {
public:
    NVRTCCompiler() {
        // Override the compiler signature
        int major, minor;
        DG_NVRTC_CHECK(nvrtcVersion(&major, &minor));
        signature = fmt::format("NVRTC{}.{}", major, minor);
        DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVRTC version should be >= 12.3");

        // Build include directories list
        std::string include_dirs;
        include_dirs += fmt::format("-I{} ", library_include_path.string());
        include_dirs += fmt::format("-I{} ", (cuda_home / "include").string());

        // Add PCH support for version 12.8 and above
        // NOTES: PCH is vital for compilation speed
        std::string pch_flags;
        if (major > 12 or minor >= 8) {
            pch_flags = "--pch ";
            if (get_env<int>("DG_JIT_DEBUG", 0))
                pch_flags += "--pch-verbose=true ";
        }

        // Override the compiler flags
        // Only NVRTC >= 12.9 supports arch-specific family suffix
        const auto arch = device_runtime->get_arch(false, major > 12 or minor >= 9);
        flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128",
                            flags, include_dirs, arch, pch_flags);
    }

    // Writes `code` to `<dir_path>/kernel.cu`, compiles it in-process and stores the CUBIN
    // at `cubin_path` (and the PTX at `ptx_path` when given)
    // Asserts if the compilation fails, after printing the NVRTC log
    void compile(const std::string &code, const std::filesystem::path& dir_path,
                 const std::filesystem::path &cubin_path,
                 const std::optional<std::filesystem::path> &ptx_path) const override {
        // Write the code into the cache directory
        const auto code_path = dir_path / "kernel.cu";
        put(code_path, code);

        // Parse compilation options: `flags` is split on whitespace
        std::istringstream iss(flags);
        std::vector<std::string> options;
        std::string option;
        while (iss >> option)
            options.push_back(option);

        // Convert to C-style string array for NVRTC
        // NOTES: the pointers refer to `options`, which must stay alive and unmodified until compilation is done
        std::vector<const char*> option_cstrs;
        for (const auto& opt: options)
            option_cstrs.push_back(opt.c_str());

        // Print compiler command if requested
        if (get_env<int>("DG_JIT_DEBUG", 0) or get_env<int>("DG_JIT_PRINT_COMPILER_COMMAND", 0)) {
            printf("Compiling JIT runtime with NVRTC options: ");
            for (const auto& opt: options)
                printf("%s ", opt.c_str());
            printf("\n");
        }

        // Create NVRTC program and compile
        nvrtcProgram program;
        DG_NVRTC_CHECK(nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr));
        const auto compile_result = nvrtcCompileProgram(program, static_cast<int>(option_cstrs.size()), option_cstrs.data());

        // Get and print compiler log
        size_t log_size;
        DG_NVRTC_CHECK(nvrtcGetProgramLogSize(program, &log_size));
        if (get_env<int>("DG_JIT_DEBUG", 0) or compile_result != NVRTC_SUCCESS) {
            if (compile_result != NVRTC_SUCCESS)
                DG_HOST_ASSERT(log_size > 1);
            if (log_size > 1) {
                std::string compilation_log(log_size, '\0');
                DG_NVRTC_CHECK(nvrtcGetProgramLog(program, compilation_log.data()));
                printf("NVRTC log: %s\n", compilation_log.c_str());
            }
        }

        // Fail right here, like the NVCC compiler does, instead of querying PTX/CUBIN from a
        // program that did not compile and reporting an unrelated NVRTC error
        // NOTES: the program is released first, its status is irrelevant as we are about to fail
        if (compile_result != NVRTC_SUCCESS) {
            static_cast<void>(nvrtcDestroyProgram(&program));
            DG_HOST_ASSERT(false and "NVRTC compilation failed");
        }

        if (ptx_path.has_value()) {
            // Get PTX size and data if needed
            size_t ptx_size;
            DG_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size));
            std::string ptx_data(ptx_size, '\0');
            DG_NVRTC_CHECK(nvrtcGetPTX(program, ptx_data.data()));

            // Write into the file system
            put(ptx_path.value(), ptx_data);
        }

        // Get CUBIN size and data
        size_t cubin_size;
        DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size));
        std::string cubin_data(cubin_size, '\0');
        DG_NVRTC_CHECK(nvrtcGetCUBIN(program, cubin_data.data()));

        // Write into the file system
        put(cubin_path, cubin_data);

        // Cleanup
        DG_NVRTC_CHECK(nvrtcDestroyProgram(&program));
    }
};
// Lazily constructed compiler: NVRTC when `DG_JIT_USE_NVRTC` is non-zero, NVCC otherwise
static auto compiler = LazyInit<Compiler>([]() -> std::shared_ptr<Compiler> {
    const bool use_nvrtc = get_env<int>("DG_JIT_USE_NVRTC", 0);
    if (not use_nvrtc)
        return std::make_shared<NVCCCompiler>();
    return std::make_shared<NVRTCCompiler>();
});
} // namespace deep_gemm

View File

@@ -0,0 +1,138 @@
#pragma once
#include <cublasLt.h>
#include <torch/version.h>
#include <ATen/cuda/CUDAContext.h>
#include "../utils/exception.hpp"
#include "../utils/lazy_init.hpp"
#define PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3))
namespace deep_gemm {
class DeviceRuntime {
int num_sms = 0, tc_util = 0;
bool enable_pdl = false;
std::shared_ptr<cudaDeviceProp> cached_prop;
// cuBLASLt utils
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
public:
// Create the cuBLASLt handle ourselves
cublasLtHandle_t cublaslt_handle;
torch::Tensor cublaslt_workspace;
bool use_pytorch_managed_cublaslt_handle;
bool use_temp_cublaslt_workspace;
explicit DeviceRuntime() {
// Whether to use PyTorch cuBLASLt
// By default, we don't use it,
// as `at::cuda::getCurrentCUDABlasLtHandle` has large CPU overhead with some PyTorch versions
use_pytorch_managed_cublaslt_handle = get_env<int>("DG_USE_PYTORCH_CUBLASLT_HANDLE", 0) > 0;
#if not PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE
DG_HOST_ASSERT(not use_pytorch_managed_cublaslt_handle and "PyTorch does not support to get cuBLASLt handle");
#endif
// Whether to create workspace tensor on each call instead of holding one.
// Enabled by compute-sanitizer tests, which trigger `cudaErrorCudartUnloading`
// when the workspace tensor is destructed after CUDA driver shutdown.
use_temp_cublaslt_workspace = get_env<int>("DG_USE_TEMP_CUBLASLT_WORKSPACE", 0) > 0;
if (not use_pytorch_managed_cublaslt_handle)
DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle));
if (not use_temp_cublaslt_workspace)
cublaslt_workspace = torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA));
}
~DeviceRuntime() noexcept(false) {
if (not use_pytorch_managed_cublaslt_handle)
DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle));
}
cublasLtHandle_t get_cublaslt_handle() const {
#if PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE
if (use_pytorch_managed_cublaslt_handle)
return at::cuda::getCurrentCUDABlasLtHandle();
#endif
// Self-managed handle
return cublaslt_handle;
}
torch::Tensor get_cublaslt_workspace() const {
if (use_temp_cublaslt_workspace)
return torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA));
return cublaslt_workspace;
}
std::shared_ptr<cudaDeviceProp> get_prop() {
if (cached_prop == nullptr) {
int device_idx;
cudaDeviceProp prop;
DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx));
DG_CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&prop, device_idx));
cached_prop = std::make_shared<cudaDeviceProp>(prop);
}
return cached_prop;
}
std::pair<int, int> get_arch_pair() {
const auto prop = get_prop();
return {prop->major, prop->minor};
}
std::string get_arch(const bool& number_only = false,
const bool& support_arch_family = false) {
const auto [major, minor] = get_arch_pair();
if (major == 10 and minor != 1) {
if (number_only)
return "100";
return support_arch_family ? "100f" : "100a";
}
return std::to_string(major * 10 + minor) + (number_only ? "" : "a");
}
int get_arch_major() {
return get_arch_pair().first;
}
void set_num_sms(const int& new_num_sms) {
DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount);
num_sms = new_num_sms;
}
int get_num_sms() {
if (num_sms == 0)
num_sms = get_prop()->multiProcessorCount;
return num_sms;
}
int get_l2_cache_size() {
return get_prop()->l2CacheSize;
}
void set_tc_util(const int& new_tc_util) {
DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100);
tc_util = new_tc_util;
}
int get_tc_util() const {
return tc_util == 0 ? 100 : tc_util;
}
void set_pdl(const bool& new_enable_pdl) {
enable_pdl = new_enable_pdl;
}
bool get_pdl() const {
return enable_pdl;
}
};
static auto device_runtime = LazyInit<DeviceRuntime>([](){ return std::make_shared<DeviceRuntime>(); });
} // namespace deep_gemm

222
third_party/DeepGEMM/csrc/jit/handle.hpp vendored Normal file
View File

@@ -0,0 +1,222 @@
#pragma once
#include <cuda.h>
#include <cuda_runtime.h>
#include <dlfcn.h>
#include <filesystem>
#include "../utils/exception.hpp"
#include "../utils/compatibility.hpp"
namespace deep_gemm {
// Lazy loading all driver symbols
// Returns the `dlopen` handle of `libcuda.so.1`, opening the library on the first call
// and asserting if it cannot be loaded
static void* get_driver_handle() {
    static void* handle = nullptr;

    // Already loaded by a previous call
    if (handle != nullptr)
        return handle;

    handle = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_LOCAL);
    DG_HOST_ASSERT(handle != nullptr and "Failed to load CUDA driver `libcuda.so.1`");
    return handle;
}
// Macro to define wrapper functions named `lazy_cu{API name}`
// Each wrapper has the return type of the wrapped driver API and forwards its arguments to it.
// The symbol is resolved with `dlsym` on the handle from `get_driver_handle()` at the first call
// and kept in a function-local static afterwards; a missing symbol trips an assertion.
// NOTES: every distinct combination of argument types instantiates the template separately,
// each instantiation with its own cached pointer
#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \
template <typename... Args> \
static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \
using FuncType = decltype(&(name)); \
static FuncType func = nullptr; \
if (func == nullptr) { \
func = reinterpret_cast<FuncType>(dlsym(get_driver_handle(), #name)); \
DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA driver API"); \
} \
return func(std::forward<decltype(args)>(args)...); \
}
// Driver APIs wrapped unconditionally: error reporting, function attributes,
// module/library loading, kernel launch and tensor map encoding
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorName);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorString);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryLoadFromFile);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryUnload);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuKernelGetFunction);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled);
#if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API)
// Use CUDA runtime API
using LibraryHandle = cudaLibrary_t;
using KernelHandle = cudaKernel_t;
using LaunchConfigHandle = cudaLaunchConfig_t;
using LaunchAttrHandle = cudaLaunchAttribute;
#define DG_CUDA_UNIFIED_CHECK DG_CUDA_RUNTIME_CHECK
// Loads the CUBIN at `cubin_path` as a library and returns its kernel named `func_name`
// If `library_opt` is not null, the library handle is also stored there (e.g., to unload it later)
static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name,
                                LibraryHandle *library_opt = nullptr) {
    LibraryHandle library;
    KernelHandle kernel{};
    DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
    DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, func_name.c_str()));

    // The caller did not ask for the library handle
    if (library_opt == nullptr)
        return kernel;

    *library_opt = library;
    return kernel;
}
static void unload_library(const LibraryHandle& library) {
const auto error = cudaLibraryUnload(library);
DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading);
}
// Builds the launch configuration of `kernel`: grid/block dimensions, dynamic shared memory,
// stream, and the optional cluster dimension (`cluster_dim > 1`) and PDL (`enable_pdl`) attributes
// A positive `smem_size` is also set as the kernel's maximum dynamic shared memory size
// NOTES: the returned config points into a function-local static attribute array,
// which is rewritten by the next call of this function
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
                                                  const cudaStream_t& stream, const int& smem_size,
                                                  const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) {
    if (smem_size > 0)
        DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

    // Attribute storage
    // NOTES: must use `static`, or the array would be destructed while the config still points to it
    static LaunchAttrHandle attrs[2];

    LaunchConfigHandle config;
    config.stream = stream;
    config.gridDim = grid_dim;
    config.blockDim = block_dim;
    config.dynamicSmemBytes = smem_size;
    config.attrs = attrs;
    config.numAttrs = 0;

    // Cluster size
    if (cluster_dim > 1) {
        auto& cluster_attr = attrs[config.numAttrs ++];
        cluster_attr.id = cudaLaunchAttributeClusterDimension;
        cluster_attr.val.clusterDim = {static_cast<unsigned>(cluster_dim), 1, 1};
    }

    // Dependent kernel launch
    if (enable_pdl) {
        auto& pdl_attr = attrs[config.numAttrs ++];
        pdl_attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
        pdl_attr.val.programmaticStreamSerializationAllowed = 1;
    }
    return config;
}
// Launch `kernel` with `config` through the runtime API.
// The address of every argument is collected into a `void*` array and handed to `cudaLaunchKernelExC`;
// the status of that call is returned as is (not checked here).
template<typename... ActTypes>
static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) {
// One pointer per kernel parameter, in declaration order
void *ptr_args[] = { &args... };
return cudaLaunchKernelExC(&config, kernel, ptr_args);
}
#else
// Use CUDA driver API
// NOTES: all driver entry points in this branch are called through their `lazy_*` wrappers
using KernelHandle = CUfunction;
using LaunchConfigHandle = CUlaunchConfig;
using LaunchAttrHandle = CUlaunchAttribute;
// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4
// With it, kernels are discovered by enumeration and a library is a `CUlibrary`;
// otherwise the module API is used and a library is a `CUmodule`
#if CUDA_VERSION >= 12040
#define DG_JIT_USE_LIBRARY_ENUM_KERNELS
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount);
DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels);
using LibraryHandle = CUlibrary;
#else
using LibraryHandle = CUmodule;
#endif
// Error-check macro matching the selected API flavor
#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK
// Load the CUBIN at `cubin_path` through the driver API and return its kernel.
//  - `func_name`: only used by the module-API fallback, the enumeration path ignores it
//  - `library_opt`: when non-null, receives the handle of the loaded library/module
// With `DG_JIT_USE_LIBRARY_ENUM_KERNELS` the CUBIN must contain exactly one kernel, otherwise the cache
// directory is reported as corrupted and the host assertion fails.
// NOTE(review): when `library_opt` is null the handle is dropped without being unloaded - confirm intended
static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name,
LibraryHandle *library_opt = nullptr) {
LibraryHandle library;
KernelHandle kernel;
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
DG_CUDA_DRIVER_CHECK(lazy_cuLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0));
// A valid JIT artifact holds a single kernel, anything else hints at a broken cache entry
unsigned int num_kernels;
DG_CUDA_DRIVER_CHECK(lazy_cuLibraryGetKernelCount(&num_kernels, library));
if (num_kernels != 1) {
const auto dir_path = cubin_path.parent_path();
printf("Corrupted JIT cache directory (expected 1 kernel, found %u): %s, "
"please run `rm -rf %s` and restart your task.\n",
num_kernels, dir_path.c_str(), dir_path.c_str());
DG_HOST_ASSERT(false and "Corrupted JIT cache directory");
}
// Fetch the only kernel and convert it into a launchable function handle
CUkernel cu_kernel;
DG_CUDA_DRIVER_CHECK(lazy_cuLibraryEnumerateKernels(&cu_kernel, 1, library));
DG_CUDA_DRIVER_CHECK(lazy_cuKernelGetFunction(&kernel, cu_kernel));
#else
// Module API fallback: the kernel is looked up by its symbol name
DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str()));
DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str()));
#endif
// Hand the library back only if the caller asked for it
if (library_opt != nullptr)
*library_opt = library;
return kernel;
}
// Unload a library/module previously returned by `load_kernel` (driver API).
// `CUDA_ERROR_DEINITIALIZED` is accepted alongside `CUDA_SUCCESS`; any other status fails the host assertion.
static void unload_library(const LibraryHandle& library) {
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
const auto error = lazy_cuLibraryUnload(library);
#else
const auto error = lazy_cuModuleUnload(library);
#endif
DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED);
}
// Build a driver-API launch configuration for `kernel`.
//  - `smem_size`: dynamic shared memory in bytes, also raised on the function attribute when positive
//  - `cluster_dim`: cluster size along X, an attribute is only attached when it is larger than 1
//  - `enable_pdl`: attach the programmatic stream serialization attribute
// NOTE(review): the returned config points into a function-local `static` attribute array, so it is only
// meaningful until the next call of this function (and the array is shared by all calling threads) - confirm
// that every config is consumed by a launch before another one is constructed
static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel,
                                                  const cudaStream_t& stream, const int& smem_size,
                                                  const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) {
    // Opt in to the requested amount of dynamic shared memory
    if (smem_size > 0)
        DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size));

    // NOTES: the storage must be `static`, the config only keeps a pointer to it and is returned to the caller
    static LaunchAttrHandle attrs[2];

    // Basic launch geometry
    LaunchConfigHandle config;
    config.hStream = stream;
    config.sharedMemBytes = smem_size;
    config.gridDimX = grid_dim.x;
    config.gridDimY = grid_dim.y;
    config.gridDimZ = grid_dim.z;
    config.blockDimX = block_dim.x;
    config.blockDimY = block_dim.y;
    config.blockDimZ = block_dim.z;
    config.attrs = attrs;
    config.numAttrs = 0;

    // Claim the next free attribute slot
    const auto next_attr = [&config]() -> LaunchAttrHandle& {
        return attrs[config.numAttrs ++];
    };

    // Cluster size
    if (cluster_dim > 1) {
        auto& cluster_attr = next_attr();
        cluster_attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
        cluster_attr.value.clusterDim.x = static_cast<unsigned>(cluster_dim);
        cluster_attr.value.clusterDim.y = 1;
        cluster_attr.value.clusterDim.z = 1;
    }

    // Dependent kernel launch
    if (enable_pdl) {
        auto& pdl_attr = next_attr();
        pdl_attr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
        pdl_attr.value.programmaticStreamSerializationAllowed = 1;
    }
    return config;
}
// Launch `kernel` with `config` through the driver API.
// The address of every argument is collected into a `void*` array and handed to `cuLaunchKernelEx`
// (via its lazy wrapper, with no extra options); the status of that call is returned as is (not checked here).
template<typename... ActTypes>
static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) {
// One pointer per kernel parameter, in declaration order
void *ptr_args[] = { &args... };
return lazy_cuLaunchKernelEx(&config, kernel, ptr_args, nullptr);
}
#endif
} // namespace deep_gemm

View File

@@ -0,0 +1,80 @@
#pragma once
#include <filesystem>
#include <regex>
#include <string>
#include <vector>
#include "../utils/format.hpp"
#include "../utils/system.hpp"
namespace deep_gemm {
/// Hashes JIT sources together with the `<deep_gemm/...>` headers they include (recursively),
/// so that a change in any library header changes the resulting hash value.
class IncludeParser {
    // Per-file hash cache keyed by path; `std::nullopt` marks a file whose hash is still being computed,
    // which is how circular includes are detected
    std::unordered_map<std::string, std::optional<std::string>> cache;

    /// Collect the `deep_gemm/...` headers included by `code`.
    /// Only the exact form `#include <name>` is accepted; any other include spelling (quotes, extra spaces)
    /// is reported through `DG_HOST_UNREACHABLE`, mentioning `file_path` when it is non-empty.
    static std::vector<std::string> get_includes(const std::string& code, const std::filesystem::path& file_path = "") {
        // NOTES: compiled once instead of on every call, regex construction is expensive
        // and this function runs for every file reached during recursive hashing
        static const std::regex pattern(R"(#\s*include\s*[<"][^>"]+[>"])");

        std::vector<std::string> includes;
        // TODO: parse relative paths as well
        for (std::sregex_iterator iter(code.begin(), code.end(), pattern), end; iter != end; ++ iter) {
            const auto include_str = iter->str();
            const int len = include_str.length();

            // NOTES: the pattern guarantees at least 11 characters, so the indices below are in range
            const bool is_standard =
                include_str.compare(0, 10, "#include <") == 0 and include_str[len - 1] == '>' and
                include_str[10] != ' ' and include_str[len - 2] != ' ';
            if (is_standard) {
                std::string filename = include_str.substr(10, len - 11);
                if (filename.rfind("deep_gemm", 0) == 0) // We only parse `<deep_gemm/*>`
                    includes.push_back(std::move(filename));
            } else {
                std::string error_info = fmt::format("Non-standard include: {}", include_str);
                if (not file_path.empty())
                    error_info += fmt::format(" ({})", file_path.string());
                DG_HOST_UNREACHABLE(error_info);
            }
        }
        return includes;
    }

public:
    // Root directory that `<deep_gemm/...>` includes are resolved against, set by `prepare_init`
    static std::filesystem::path library_include_path;

    /// Point the parser at `<library_root_path>/include`.
    static void prepare_init(const std::string& library_root_path) {
        library_include_path = std::filesystem::path(library_root_path) / "include";
    }

    /// Hash the headers included by `code`; the text of `code` itself only contributes
    /// when `exclude_code` is false.
    std::string get_hash_value(const std::string& code, const bool& exclude_code = true) {
        std::stringstream ss;
        for (const auto& i: get_includes(code))
            ss << get_hash_value_by_path(library_include_path / i) << "$";
        if (not exclude_code)
            ss << "#" << get_hex_digest(code);
        return get_hex_digest(ss.str());
    }

    /// Hash the file at `path` (content plus its own includes), memoizing the result.
    /// Reports through `DG_HOST_UNREACHABLE` when the file cannot be opened or when the file is
    /// reached again while its own hash is still being computed.
    std::string get_hash_value_by_path(const std::filesystem::path& path) {
        // Convert the path once and reuse the key for all cache accesses
        const std::string key = path.string();

        // Check whether hit in cache
        if (const auto it = cache.find(key); it != cache.end()) {
            if (not it->second.has_value())
                DG_HOST_UNREACHABLE(fmt::format("Circular include may occur: {}", path.string()));
            return it->second.value();
        }

        // Read file and calculate hash recursively
        std::ifstream in(path);
        if (not in.is_open())
            DG_HOST_UNREACHABLE(fmt::format("Failed to open: {}", path.string()));
        std::string code((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());

        // Mark as in-progress before recursing, then publish the final value
        // NOTES: the recursion may insert into `cache`, so no iterator/reference is kept across it
        cache[key] = std::nullopt;
        auto hash_value = get_hash_value(code, false);
        cache[key] = hash_value;
        return hash_value;
    }
};
DG_DECLARE_STATIC_VAR_IN_CLASS(IncludeParser, library_include_path);
static auto include_parser = std::make_shared<IncludeParser>();
} // namespace deep_gemm

View File

@@ -0,0 +1,165 @@
#pragma once
#include <chrono>
#include "../utils/exception.hpp"
#include "../utils/format.hpp"
#include "../utils/system.hpp"
#include "device_runtime.hpp"
#include "handle.hpp"
#include "include_parser.hpp"
namespace deep_gemm {
/// Launch geometry and options requested by a kernel runtime.
struct LaunchArgs {
    // Grid size as (x, y)
    std::pair<int, int> grid_dim;
    // Threads per block
    int num_threads;
    // Dynamic shared memory in bytes
    int smem_size;
    // Cluster size
    int cluster_dim;
    // Whether programmatic dependent launch is requested
    bool enable_pdl;

    /// Full form with an explicit 2D grid.
    LaunchArgs(const std::pair<int, int>& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true):
        grid_dim(grid_dim),
        num_threads(num_threads),
        smem_size(smem_size),
        cluster_dim(cluster_dim),
        enable_pdl(enable_pdl) {}

    /// Convenience form for a 1D grid, the Y dimension is fixed to 1.
    LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true):
        LaunchArgs(std::pair<int, int>(grid_dim_x, 1), num_threads, smem_size, cluster_dim, enable_pdl) {}
};
// A loaded JIT kernel: owns the library loaded from `<dir_path>/kernel.cubin` and exposes its kernel handle.
// The library is unloaded by the destructor.
// NOTE(review): copy operations are not deleted although the destructor unloads `library`; a copy would
// unload the same handle twice - presumably instances are only held through `std::shared_ptr`, confirm
class KernelRuntime final {
public:
// CUDA toolkit root, must be set through `prepare_init` before any instance is constructed
static std::filesystem::path cuda_home;
LibraryHandle library;
KernelHandle kernel;
// Load `kernel.cubin` from the cache directory `dir_path`.
// Debug output is controlled by the `DG_JIT_DEBUG` and `DG_JIT_PRINT_LOAD_TIME` environment variables.
explicit KernelRuntime(const std::filesystem::path& dir_path) {
// Check `prepare_init`
DG_HOST_ASSERT(not cuda_home.empty());
// NOLINT(*-pro-type-member-init)
// NOTES: `cuobjdump_path` is only consumed by the symbol-dump fallback below
const auto cuobjdump_path = cuda_home / "bin" / "cuobjdump";
const auto cubin_path = dir_path / "kernel.cubin";
if (get_env<int>("DG_JIT_DEBUG"))
printf("Loading CUBIN: %s\n", cubin_path.c_str());
// Record start time
std::chrono::high_resolution_clock::time_point start_time;
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_JIT_PRINT_LOAD_TIME"))
start_time = std::chrono::high_resolution_clock::now();
#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS
// Load from the library
// NOTES: no symbol name is needed, the only kernel is found by enumeration
kernel = load_kernel(cubin_path, {}, &library);
#else
// Find the only symbol
// TODO: use kernel enumeration for newer drivers
// Entry symbols containing any of these substrings are not the kernel we are looking for
const std::vector<std::string> illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"};
// NOTES: the paths are pasted into the command line unquoted
const auto [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str()));
DG_HOST_ASSERT(exit_code == 0);
std::istringstream iss(symbols);
std::vector<std::string> symbol_names;
for (std::string line; std::getline(iss, line); ) {
// Keep lines starting with `STT_FUNC` that also mention `STO_ENTRY`, minus the filtered names
if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and
std::none_of(illegal_names.begin(), illegal_names.end(),
[&](const auto name) { return line.find(name) != std::string::npos; })) {
// The symbol name is the text after the last space of the line
const auto last_space = line.rfind(' ');
symbol_names.push_back(line.substr(last_space + 1));
}
}
// Print symbols
// NOTES: also printed right before the assertion below fails, to help diagnosing unexpected symbol tables
if (symbol_names.size() != 1 or get_env<int>("DG_JIT_DEBUG")) {
printf("Symbols: ");
printf(" > CUBIN: %s\n", cubin_path.c_str());
printf(" > Raw symbols: %s\n", symbols.c_str());
printf(" > Parsed symbols:\n");
for (const auto& symbol: symbol_names)
printf(" > %s, ", symbol.c_str());
}
DG_HOST_ASSERT(symbol_names.size() == 1);
// Load from the library
kernel = load_kernel(cubin_path, symbol_names[0], &library);
#endif
// Print load time
if (get_env<int>("DG_JIT_DEBUG") or get_env<int>("DG_JIT_PRINT_LOAD_TIME")) {
std::chrono::duration<double, std::milli> load_time = std::chrono::high_resolution_clock::now() - start_time;
printf("Load time (%s): %.2lf ms\n", dir_path.c_str(), load_time.count());
}
}
// Record the CUDA toolkit root used to locate `bin/cuobjdump`
static void prepare_init(const std::string& cuda_home_path_by_python) {
cuda_home = cuda_home_path_by_python;
}
// Return whether `dir_path` is a usable cache entry.
// A missing directory yields `false`; an existing directory that lacks `kernel.cu` or `kernel.cubin`
// is reported as corrupted and fails the host assertion.
static bool check_validity(const std::filesystem::path& dir_path) {
if (not std::filesystem::exists(dir_path))
return false;
// NOTES: if the directory exists, `kernel.cu` and `kernel.cubin` must both exist,
// because the directory is created atomically via rename
if (not std::filesystem::exists(dir_path / "kernel.cu") or
not std::filesystem::exists(dir_path / "kernel.cubin")) {
printf("Corrupted JIT cache directory (missing kernel.cu or kernel.cubin): %s, "
"please run `rm -rf %s` and restart your task.\n",
dir_path.c_str(), dir_path.c_str());
DG_HOST_ASSERT(false and "Corrupted JIT cache directory");
}
return true;
}
// NOTES: declared `noexcept(false)`, presumably because the assertion inside `unload_library` may throw
~KernelRuntime() noexcept(false) {
unload_library(library);
}
};
DG_DECLARE_STATIC_VAR_IN_CLASS(KernelRuntime, cuda_home);
/// CRTP base of all kernel runtimes: `Derived` provides `generate_impl(args)` (source generation)
/// and `launch_impl(kernel, config, args)` (argument marshalling and the actual launch).
template <typename Derived>
class LaunchRuntime {
public:
    /// Generate the kernel source for `args`, prefixed with a comment line carrying the hash of the
    /// library headers it includes. The full source is echoed when `DG_JIT_DEBUG` is set.
    template <typename Args>
    static std::string generate(const Args& args) {
        auto code = Derived::generate_impl(args);

        // NOTES: we require that `generate_impl`'s includes never change, so the hash is computed from
        // the first generated code only. A function-local static with an initializer is initialized
        // exactly once even with concurrent callers, unlike a manual "assign if empty" check
        static const std::string include_hash = include_parser->get_hash_value(code);

        // TODO: optimize string concat performance
        code = fmt::format("// Includes' hash value: {}\n{}", include_hash, code);
        if (get_env<int>("DG_JIT_DEBUG"))
            printf("Generated kernel code:\n%s\n", code.c_str());
        return code;
    }

    /// Launch the kernel held by `kernel_runtime` on the current CUDA stream, using `args.launch_args`
    /// for the launch geometry. The PDL flag in the launch arguments is always replaced by the
    /// device runtime setting.
    template <typename Args>
    static void launch(const std::shared_ptr<KernelRuntime>& kernel_runtime, const Args& args) {
        const auto kernel = kernel_runtime->kernel;
        const auto stream = at::cuda::getCurrentCUDAStream();
        LaunchArgs launch_args = args.launch_args;

        // Allow runtime override from Python.
        // NOTES: the default is enabled.
        launch_args.enable_pdl = device_runtime->get_pdl();

        const dim3 grid_dim = {static_cast<unsigned>(launch_args.grid_dim.first),
                               static_cast<unsigned>(launch_args.grid_dim.second),
                               1};
        const dim3 block_dim = {static_cast<unsigned>(launch_args.num_threads), 1, 1};
        auto config = construct_launch_config(kernel, stream, launch_args.smem_size,
                                              grid_dim, block_dim, launch_args.cluster_dim, launch_args.enable_pdl);

        // Describe the launch when debugging
        if (get_env<int>("DG_JIT_DEBUG")) {
            printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, pdl: %d, stream: %ld\n",
                   launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads,
                   launch_args.smem_size, launch_args.cluster_dim, launch_args.enable_pdl, stream.id());
        }

        // Launch in the derived class
        Derived::launch_impl(kernel, config, args);
    }
};
} // namespace deep_gemm

View File

@@ -0,0 +1,54 @@
#pragma once
#include <unordered_set>
#include <deep_gemm/common/types.cuh>
#include "config.hpp"
#include "runtime.hpp"
#include "../../utils/layout.hpp"
#include "../../utils/system.hpp"
namespace deep_gemm {
/// Pick the GEMM configuration for `desc` on the architecture described by `ArchSpec`.
/// The layout is chosen among `ArchSpec::get_layout_candidates` using `ArchSpec::compare`, then the
/// storage, pipeline and launch configs are derived from it. With `DG_JIT_DEBUG` or `DG_PRINT_CONFIGS`
/// set, the result is printed once per distinct descriptor.
template <typename ArchSpec>
static GemmConfig get_best_config(const GemmDesc& desc) {
    desc.check_validity();

    // Choose the best layout: start from the first candidate and let every other one challenge it
    const auto layout_candidates = ArchSpec::get_layout_candidates(desc);
    DG_HOST_ASSERT(not layout_candidates.empty());
    auto layout = layout_candidates[0];
    auto layout_info = ArchSpec::get_layout_info(desc, layout);
    for (size_t i = 1; i < layout_candidates.size(); ++ i) {
        const auto& candidate = layout_candidates[i];
        const auto candidate_info = ArchSpec::get_layout_info(desc, candidate);
        if (ArchSpec::compare(candidate_info, layout_info)) {
            layout = candidate;
            layout_info = candidate_info;
        }
    }

    // Infer other configs
    const auto storage_config = ArchSpec::get_storage_config(desc, layout);
    const auto pipeline_config = ArchSpec::get_pipeline_config(desc, layout, storage_config);
    const auto launch_config = ArchSpec::get_launch_config(desc, layout);
    const GemmConfig gemm_config {
        .layout = layout,
        .storage_config = storage_config,
        .pipeline_config = pipeline_config,
        .launch_config = launch_config
    };

    // Nothing to print unless requested
    if (not get_env<int>("DG_JIT_DEBUG") and not get_env<int>("DG_PRINT_CONFIGS"))
        return gemm_config;

    // Print configs for the first time, keyed by the textual form of the descriptor
    std::stringstream ss;
    ss << desc;
    const auto key = ss.str();
    static std::unordered_set<std::string> printed;
    if (printed.find(key) == printed.end()) {
        std::cout << desc << ": " << gemm_config << ", " << layout_info << std::endl;
        printed.insert(key);
    }
    return gemm_config;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,171 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
#include <c10/core/ScalarType.h>
#include <deep_gemm/common/types.cuh>
#include "../../utils/math.hpp"
namespace deep_gemm {
/// GEMM descriptors
struct GemmDesc {
GemmType gemm_type;
KernelType kernel_type;
int m, n, k, num_groups;
at::ScalarType a_dtype, b_dtype, cd_dtype;
cute::UMMA::Major major_a;
cute::UMMA::Major major_b;
bool with_accumulation;
// Requirements from users
int num_sms, tc_util;
std::string compiled_dims;
// Shape for heuristic generation
int expected_m = 0, expected_n = 0, expected_k = 0, expected_num_groups = 0;
int get_expected_m() const { return expected_m > 0 ? expected_m : m; }
int get_expected_n() const { return expected_n > 0 ? expected_n : n; }
int get_expected_k() const { return expected_k > 0 ? expected_k : k; }
int get_expected_num_groups() const { return expected_num_groups > 0 ? expected_num_groups : num_groups; }
MmaKind get_mma_kind() const {
return a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4;
}
void check_validity() const {
if (get_mma_kind() == MmaKind::BF16) {
DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16);
} else {
DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4);
DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4);
}
DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat);
DG_HOST_ASSERT(num_sms % 2 == 0);
}
friend std::ostream& operator << (std::ostream& os, const GemmDesc& desc) {
MmaKind mma_kind = desc.get_mma_kind();
os << "GemmDesc(gemm_type=" << static_cast<int>(desc.gemm_type)
<< ", kernel_type=" << static_cast<int>(desc.kernel_type)
<< ", m=" << desc.m << ", n=" << desc.n << ", k=" << desc.k
<< ", num_groups=" << desc.num_groups
<< ", major_a=" << static_cast<int>(desc.major_a)
<< ", major_b=" << static_cast<int>(desc.major_b)
<< ", mma_kind=" << static_cast<int>(mma_kind)
<< ", a_dtype=" << c10::toString(desc.a_dtype)
<< ", b_dtype=" << c10::toString(desc.b_dtype)
<< ", cd_dtype=" << c10::toString(desc.cd_dtype)
<< ", with_accumulation=" << static_cast<int>(desc.with_accumulation)
<< ", num_sms=" << desc.num_sms
<< ", tc_util=" << desc.tc_util
<< ", compiled_dims=" << desc.compiled_dims
<< ", expected_m=" << desc.expected_m
<< ", expected_n=" << desc.expected_n
<< ", expected_k=" << desc.expected_k
<< ", expected_num_groups=" << desc.expected_num_groups << ")";
return os;
}
};
/// GEMM configs
/// Block tiling and cluster shape of a GEMM.
struct Layout {
    // Whether A/B are swapped (kept as an integer)
    int swap_ab;
    // Block sizes
    int block_m, block_n, block_k;
    // Cluster sizes along M and N
    int cluster_m, cluster_n;

    /// Total number of blocks in a cluster.
    int get_cluster_size() const {
        const int cluster_size = cluster_m * cluster_n;
        return cluster_size;
    }

    friend std::ostream& operator << (std::ostream& os, const Layout& layout) {
        os << "Layout(swap_ab=" << layout.swap_ab;
        os << ", block_m=" << layout.block_m;
        os << ", block_n=" << layout.block_n;
        os << ", block_k=" << layout.block_k;
        os << ", cluster_m=" << layout.cluster_m;
        os << ", cluster_n=" << layout.cluster_n;
        return os << ")";
    }
};
/// Load/store block sizes and swizzle modes of a GEMM.
struct StorageConfig {
    // Block sizes used for loads
    int load_block_m, load_block_n;
    // Block sizes used for stores
    int store_block_m, store_block_n;
    // Swizzle modes of A, B and C/D
    int swizzle_a_mode, swizzle_b_mode;
    int swizzle_cd_mode;

    friend std::ostream& operator << (std::ostream& os, const StorageConfig& config) {
        os << "StorageConfig(";
        os << "load_block_m=" << config.load_block_m;
        os << ", load_block_n=" << config.load_block_n;
        os << ", store_block_m=" << config.store_block_m;
        os << ", store_block_n=" << config.store_block_n;
        os << ", swizzle_a_mode=" << config.swizzle_a_mode;
        os << ", swizzle_b_mode=" << config.swizzle_b_mode;
        os << ", swizzle_cd_mode=" << config.swizzle_cd_mode;
        return os << ")";
    }
};
/// Shared memory footprint and pipeline depth of a GEMM.
struct PipelineConfig {
    // Shared memory size
    int smem_size;
    // Number of pipeline stages
    int num_stages;

    friend std::ostream& operator << (std::ostream& os, const PipelineConfig& config) {
        os << "PipelineConfig(";
        os << "smem_size=" << config.smem_size;
        os << ", num_stages=" << config.num_stages;
        return os << ")";
    }
};
/// SM usage and thread layout of a GEMM launch.
struct LaunchConfig {
    // SM counts
    int num_sms;
    int num_sms_per_cluster;
    // Thread counts, in total and per role
    int num_threads;
    int num_tma_threads;
    int num_math_threads;
    int num_non_epilogue_threads;
    int num_epilogue_threads;

    friend std::ostream& operator << (std::ostream& os, const LaunchConfig& config) {
        os << "LaunchConfig(";
        os << "num_sms=" << config.num_sms;
        os << ", num_sms_per_cluster=" << config.num_sms_per_cluster;
        os << ", num_threads=" << config.num_threads;
        os << ", num_tma_threads=" << config.num_tma_threads;
        os << ", num_math_threads=" << config.num_math_threads;
        os << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads;
        os << ", num_epilogue_threads=" << config.num_epilogue_threads;
        return os << ")";
    }
};
/// Complete GEMM configuration: layout plus the storage, pipeline and launch configs.
struct GemmConfig {
    Layout layout;
    StorageConfig storage_config;
    PipelineConfig pipeline_config;
    LaunchConfig launch_config;

    /// Print all four parts using their own stream operators.
    friend std::ostream& operator << (std::ostream& os, const GemmConfig& config) {
        os << "GemmConfig(";
        os << "layout=" << config.layout;
        os << ", storage_config=" << config.storage_config;
        os << ", pipeline_config=" << config.pipeline_config;
        os << ", launch_config=" << config.launch_config;
        return os << ")";
    }
};
/// Config comparators
/// Metrics attached to a layout candidate, used when comparing candidates.
struct LayoutInfo {
    int num_waves;
    int last_wave_util;
    int64_t num_cycles;
    // The layout these metrics belong to
    Layout layout;

    /// Print the three metrics.
    /// NOTES: `layout` itself is not part of the printed text
    friend std::ostream& operator << (std::ostream& os, const LayoutInfo& config) {
        os << "LayoutInfo(";
        os << "num_waves=" << config.num_waves;
        os << ", last_wave_util=" << config.last_wave_util;
        os << ", num_cycles=" << config.num_cycles;
        return os << ")";
    }
};
} // namespace deep_gemm

View File

@@ -0,0 +1,240 @@
#pragma once
#include <algorithm>
#include <unordered_set>
#include <deep_gemm/layout/mega_moe.cuh>
#include "../../utils/exception.hpp"
#include "../../utils/math.hpp"
#include "../../utils/system.hpp"
#include "sm100.hpp"
namespace deep_gemm {
/// Configuration of the mega MoE kernel.
/// NOTES: this is an aggregate that is also initialized positionally, so the member order is part of its contract
struct MegaMoEConfig {
    // Block tiling
    int block_m, block_n, block_k;
    int load_block_m, load_block_n;
    int store_block_m;

    // SF block sizes (UTCCP 128-aligned)
    int sf_block_m, sf_block_n;

    // Pool capacity and SF-padded token count
    int num_max_pool_tokens;
    int num_padded_sf_pool_tokens;

    // Swizzle modes for TMA descriptors
    int swizzle_acts_mode, swizzle_weights_mode;

    // Number of experts to process per wave
    int num_experts_per_wave;

    // Pipeline stages and shared memory
    int num_stages, smem_size;

    // Thread layout
    int num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads;

    friend std::ostream& operator << (std::ostream& os, const MegaMoEConfig& config) {
        os << "MegaMoEConfig(";
        // Block tiling
        os << "block_m=" << config.block_m;
        os << ", block_n=" << config.block_n;
        os << ", block_k=" << config.block_k;
        os << ", load_block_m=" << config.load_block_m;
        os << ", load_block_n=" << config.load_block_n;
        os << ", store_block_m=" << config.store_block_m;
        // SF blocks
        os << ", sf_block_m=" << config.sf_block_m;
        os << ", sf_block_n=" << config.sf_block_n;
        // Pool
        os << ", num_max_pool_tokens=" << config.num_max_pool_tokens;
        os << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens;
        // Swizzle
        os << ", swizzle_acts_mode=" << config.swizzle_acts_mode;
        os << ", swizzle_weights_mode=" << config.swizzle_weights_mode;
        // Waves and pipeline
        os << ", num_experts_per_wave=" << config.num_experts_per_wave;
        os << ", num_stages=" << config.num_stages;
        os << ", smem_size=" << config.smem_size;
        // Threads
        os << ", num_dispatch_threads=" << config.num_dispatch_threads;
        os << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads;
        os << ", num_epilogue_threads=" << config.num_epilogue_threads;
        return os << ")";
    }
};
/// Choose the block configuration of the mega MoE kernel from the expected number of tokens per expert
/// (`num_tokens * num_ranks * num_topk / num_experts`).
/// Returns `(cluster_size, block_m, store_block_m, num_epilogue_threads)`, where the thread count is
/// 128 per epilogue warpgroup. The chosen `block_m` is asserted to be one of `layout::kCandidateBlockM`.
/// NOTES: `num_max_tokens_per_rank` is currently not used by the selection
static std::tuple<int, int, int, int> get_block_config_for_mega_moe(
    const int& num_ranks, const int& num_experts,
    const int& num_max_tokens_per_rank, const int& num_topk,
    const int& num_tokens) {
    // One row per regime, ordered by the upper bound (inclusive) of expected tokens per expert
    struct BlockChoice {
        double max_tokens_per_expert;
        int block_m, store_block_m, num_epilogue_warpgroups;
    };
    static constexpr BlockChoice kChoices[] = {
        // Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m
        {8.5, 16, 8, 2},
        // Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128
        {16.5, 32, 16, 2},
        // Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256
        {32.5, 64, 32, 1},
        // Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512
        {64.5, 96, 16, 2},
        // Medium batch size, Medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz128
        {96.5, 128, 32, 2},
    };
    // Prefill, or large EP decoding
    constexpr BlockChoice kLargestChoice = {0.0, 192, 32, 2};
    // All regimes use the same cluster size
    constexpr int cluster_size = 2;

    const float num_expected_tokens_per_expert = static_cast<float>(num_tokens) * num_ranks * num_topk / num_experts;

    // Take the first regime whose bound covers the expectation, the largest one otherwise
    BlockChoice choice = kLargestChoice;
    for (const auto& regime: kChoices) {
        if (num_expected_tokens_per_expert <= regime.max_tokens_per_expert) {
            choice = regime;
            break;
        }
    }
    const int block_m = choice.block_m;

    // Check whether our `block_m` lies in `kCandidateBlockM`
    DG_HOST_ASSERT(std::any_of(
        layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs,
        [=](const auto& candidate) { return candidate == block_m; })
    );

    // Return configs
    return {cluster_size, block_m, choice.store_block_m, choice.num_epilogue_warpgroups * 128};
}
/// Decide how many experts the mega MoE kernel processes per wave.
/// When fewer than one token per expert is expected, all experts of the rank go into a single wave.
/// Otherwise the smallest count whose L1 blocks (discounted by an imbalance factor) can occupy
/// `num_sms` SMs is taken, capped at `num_experts_per_rank` and rounded up to one of its divisors.
static int get_num_experts_per_wave_for_mega_moe(
    const int& num_experts_per_rank, const int& num_tokens, const int& num_topk,
    const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) {
    // Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens
    constexpr int kImbalanceFactor = 2;

    const float expected_tokens_per_expert = static_cast<float>(num_tokens) * num_topk / num_experts_per_rank;

    // Most experts don't have tokens, calculate all experts at once
    if (expected_tokens_per_expert < 1)
        return num_experts_per_rank;

    // Count L1 blocks per expert assuming tokens are evenly spread across experts
    const int num_m_blocks = ceil_div(static_cast<int>(std::ceil(expected_tokens_per_expert)), block_m);
    const int num_n_blocks = (2 * intermediate_hidden) / block_n;
    const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks;

    // Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy
    int num_experts_per_wave = 1;
    if (num_l1_blocks_per_expert > 0)
        num_experts_per_wave = ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert);
    if (num_experts_per_rank < num_experts_per_wave)
        num_experts_per_wave = num_experts_per_rank;

    // Round up to the nearest divisor of num_experts_per_rank so every wave processes the same count
    for (; num_experts_per_wave < num_experts_per_rank; ++ num_experts_per_wave) {
        if (num_experts_per_rank % num_experts_per_wave == 0)
            break;
    }
    return num_experts_per_wave;
}
// Compute the pipeline depth and the shared memory size of the mega MoE kernel.
// The shared memory is split into a stage-independent part (dispatch region, C/D output, amax reduction,
// barriers, tensor memory pointer) and a per-stage part (A/B tiles, SFA/SFB tiles, two barriers); the
// number of stages is the largest one that fits into `smem_capacity`.
// Returns `(num_stages, smem_size)` and asserts that at least 2 stages fit.
// NOTE(review): the byte counts below presumably mirror the layout used by the device kernel - any change
// here has to be kept in sync with it, confirm against the kernel source
static std::pair<int, int> get_pipeline_config_for_mega_moe(
const int& smem_capacity,
const int& num_experts, const int& hidden,
const int& block_m, const int& block_n, const int& block_k, const int& store_block_m,
const int& sf_block_m, const int& sf_block_n,
const int& num_dispatch_warps, const int& num_epilogue_warps) {
constexpr int kSmemAlignment = 1024;
constexpr int kNumEpilogueStages = 2;
constexpr int kNumTMAStoreStages = 2;
// Always multicast on A
// NOTES: hence each CTA only loads half of the `block_m` rows of A
const int load_block_m = block_m / 2;
// Dispatch region
// One `uint32_t` counter per expert plus the send buffers, each part aligned to `kSmemAlignment`
const int smem_expert_count_size = align(
num_experts * static_cast<int>(sizeof(uint32_t)), kSmemAlignment);
const int smem_send_buffers_size = align(
static_cast<int>(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()),
kSmemAlignment);
const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size;
// C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage)
// NOTES: 4 warps per warpgroup
const auto num_epilogue_warpgroups = num_epilogue_warps / 4;
const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages;
const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast<int>(sizeof(nv_bfloat16));
const int smem_cd = std::max(smem_cd_l1, smem_cd_l2);
// Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp)
// NOTES: 8 bytes per barrier
const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8;
// Amax reduction
// One `float` per (store row, epilogue warp) pair
const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast<int>(sizeof(float));
// Tensor memory pointer
const int smem_tmem_ptr = 4;
// SF is aligned to UTCCP 128-element granularity
// NOTES: 4 bytes per SF block element
const int smem_sfa_per_stage = sf_block_m * 4;
const int smem_sfb_per_stage = sf_block_n * 4;
// Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers
// NOTES: the tiles are counted as one byte per element
const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8;
// Fixed total
const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr;
// Select maximum num_stages
const int num_stages = (smem_capacity - smem_fixed) / smem_per_stage;
DG_HOST_ASSERT(num_stages >= 2);
return {num_stages, smem_fixed + num_stages * smem_per_stage};
}
/// Assemble the complete mega MoE configuration for the given problem.
/// Block sizes along N/K, swizzle modes and the dispatch/non-epilogue thread counts are fixed; the block
/// size along M, the epilogue thread count, the experts per wave and the pipeline depth are derived by the
/// helper functions. With `DG_JIT_DEBUG` or `DG_PRINT_CONFIGS` set, the result is printed once per
/// distinct problem.
static MegaMoEConfig get_mega_moe_config(
    const int& num_ranks, const int& num_experts, const int& num_experts_per_rank,
    const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const int& num_padded_sf_pool_tokens) {
    // Fixed block sizes
    const int block_n = 128;
    const int block_k = 128;
    // NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle
    const int swizzle_acts_mode = 128;
    const int swizzle_weights_mode = 128;
    // Fixed part of the thread layout
    const int num_dispatch_threads = 128;
    const int num_non_epilogue_threads = 128;

    // Block config
    const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] =
        get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens);
    const int load_block_m = block_m / 2;
    const int load_block_n = block_n;

    // SF blocks and token pool
    const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4);
    const int num_max_pool_tokens = layout::get_num_max_pool_tokens(
        num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);

    // Waves
    const int num_sms = device_runtime->get_num_sms();
    const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe(
        num_experts_per_rank, num_tokens, num_topk,
        intermediate_hidden, block_m, block_n, num_sms);

    // Pipeline
    // NOTES: the helper takes warp counts, 32 threads per warp
    const int num_dispatch_warps = num_dispatch_threads / 32;
    const int num_epilogue_warps = num_epilogue_threads / 32;
    const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe(
        SM100ArchSpec::smem_capacity,
        num_experts, hidden,
        block_m, block_n, block_k, store_block_m,
        sf_block_m, sf_block_n,
        num_dispatch_warps, num_epilogue_warps);

    // NOTES: positional initialization, must follow the member order of `MegaMoEConfig`
    const MegaMoEConfig config {
        block_m, block_n, block_k,
        load_block_m, load_block_n, store_block_m,
        sf_block_m, sf_block_n,
        num_max_pool_tokens, num_padded_sf_pool_tokens,
        swizzle_acts_mode, swizzle_weights_mode,
        num_experts_per_wave,
        num_stages, smem_size,
        num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads
    };

    // Nothing to print unless requested
    if (not get_env<int>("DG_JIT_DEBUG") and not get_env<int>("DG_PRINT_CONFIGS"))
        return config;

    // Print configs for the first time
    const auto key = fmt::format(
        "MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})",
        num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk);
    static std::unordered_set<std::string> printed;
    if (printed.find(key) == printed.end()) {
        std::cout << key << ": " << config << std::endl;
        printed.insert(key);
    }
    return config;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,62 @@
#pragma once
#include "../../jit/device_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/lazy_init.hpp"
namespace deep_gemm {
class HeuristicsRuntime {
static constexpr int kLegacyMKAlignmentForContiguousLayout = 128;
bool ignore_compile_dims = false;
int block_m_multiple_of = 1;
int block_n_multiple_of = 1;
int mk_alignment_for_contiguous_layout = kLegacyMKAlignmentForContiguousLayout;
public:
void set_ignore_compile_dims(const bool& new_value) {
ignore_compile_dims = new_value;
}
bool get_ignore_compile_dims() const {
return ignore_compile_dims;
}
void set_block_size_multiple_of(const int& new_block_m_multiple_of, const int& new_block_n_multiple_of) {
block_m_multiple_of = new_block_m_multiple_of;
block_n_multiple_of = new_block_n_multiple_of;
}
int get_block_m_multiple_of() const {
return block_m_multiple_of;
}
int get_block_n_multiple_of() const {
return block_n_multiple_of;
}
void set_mk_alignment_for_contiguous_layout(const int& new_value) {
mk_alignment_for_contiguous_layout = new_value;
}
int get_mk_alignment_for_contiguous_layout() const {
return mk_alignment_for_contiguous_layout;
}
static int get_theoretical_mk_alignment_for_contiguous_layout(const std::optional<int>& expected_m) {
if (device_runtime->get_arch_major() != 10)
return kLegacyMKAlignmentForContiguousLayout;
int block_m = 240, mma_step = 16;
if (expected_m.has_value()) {
// Reduce `block_m` while ensuring it covers `m`
for (; block_m > 32 and block_m - mma_step >= expected_m.value(); block_m -= mma_step);
}
return block_m;
}
};
static auto heuristics_runtime = LazyInit<HeuristicsRuntime>([](){ return std::make_shared<HeuristicsRuntime>(); });
} // namespace deep_gemm

View File

@@ -0,0 +1,269 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.cuh>
#include "common.hpp"
#include "runtime.hpp"
#include "utils.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
struct SM100ArchSpec {
static constexpr int smem_capacity = 232448;
/// Return the scaling-factor block sizes `(sf_block_m, sf_block_n)` for the given block sizes.
/// BF16 has no scaling factors and yields `{0, 0}`; for MXFP8/FP4 both block sizes are aligned
/// to 128 elements. Any other kind is reported through `DG_HOST_UNREACHABLE`.
static std::pair<int, int> get_sf_uttcp_aligned_block_sizes(
    const int& block_m, const int& block_n, const MmaKind& mma_kind) {
    constexpr int num_utccp_aligned_elems = 128;

    // No scaling factors at all
    if (mma_kind == MmaKind::BF16)
        return {0, 0};

    // Round both block sizes up to the UTCCP granularity
    if (mma_kind == MmaKind::MXFP8FP4) {
        const int sf_block_m = align(block_m, num_utccp_aligned_elems);
        const int sf_block_n = align(block_n, num_utccp_aligned_elems);
        return {sf_block_m, sf_block_n};
    }

    DG_HOST_UNREACHABLE("Unknown dtype");
}
// Enumerates every `Layout` (swap A/B flag, block M/N/K, cluster M/N) that is legal for `desc`
// M-grouped GEMMs short-circuit to a single fixed candidate; all other types go through the
// nested enumeration below, where each `continue` rejects one illegal combination
// Asserts (via `DG_HOST_ASSERT`) that at least one candidate survives
static std::vector<Layout> get_layout_candidates(const GemmDesc& desc) {
// Block K is always in a fixed manner
const int block_k = 128 / get_element_size(desc.get_mma_kind());
// Always enable swap A/B (and multicasting if possible) for m-grouped GEMMs
if (desc.gemm_type == GemmType::MGroupedContiguous or
desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout or
desc.gemm_type == GemmType::MGroupedMasked) {
const bool swap_ab = true;
const auto block_n = 128;
// Block M follows the configured M/K alignment of the contiguous layout
const auto block_m = heuristics_runtime->get_mk_alignment_for_contiguous_layout();
const auto cluster_m = 1;
// Use a cluster of 2 on N only if both the number of N blocks and the SM count are even
const auto cluster_n = ceil_div(desc.n, block_n) % 2 == 0 and desc.num_sms % 2 == 0 ? 2 : 1;
const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n};
std::vector<Layout> candidates = {layout};
return candidates;
}
// Enumerate all candidates
std::vector<Layout> candidates;
for (int swap_ab = 0; swap_ab < 2; ++ swap_ab) {
// Block M/N candidates
std::vector<int> block_m_candidates;
std::vector<int> block_n_candidates;
if (swap_ab) {
// With swap A/B: block M is every multiple of `step` up to 256, block N is fixed
int step = std::lcm(16, heuristics_runtime->get_block_m_multiple_of());
int end = 256;
for (int i = step; i <= end; i += step)
block_m_candidates.push_back(i);
// TODO: consider other block N
block_n_candidates = {128};
} else {
// NOTES: smaller block M can avoid TMA L2 OOB bound
// TODO: consider block M = 256
if (desc.m <= 32) block_m_candidates = {32};
else if (desc.m <= 64) block_m_candidates = {64};
else block_m_candidates = {128};
// Small block size for small shape
if (16 % heuristics_runtime->get_block_n_multiple_of() == 0)
block_n_candidates.push_back(16);
int step = std::lcm(32, heuristics_runtime->get_block_n_multiple_of());
// For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck
int end = desc.k <= 256 ? 128 : 256;
for (int i = step; i <= end; i += step)
block_n_candidates.push_back(i);
}
for (int cluster_m = 1; cluster_m <= 2; ++ cluster_m) {
// After swapping, layout A/D can only do on cluster N
if (swap_ab == 1 and cluster_m > 1)
continue;
for (int cluster_n = 1; cluster_n <= 2; ++ cluster_n) {
// We only support cluster 2
if (cluster_m * cluster_n > 2)
continue;
// Only support layout A/D
if (swap_ab == 0 and cluster_n > 1)
continue;
// SM count must be divisible
if (desc.num_sms % (cluster_m * cluster_n) != 0)
continue;
for (int block_m: block_m_candidates) {
// Ensure large swizzle sizes (32B swizzle yields poor performance)
const auto swizzle_a_requirement = desc.a_dtype == kPackedFP4 ? 128 : 64;
// Enforce swizzle alignment for MN major; otherwise check base MMA shape
const auto load_block_m_requirement = desc.major_a == cute::UMMA::Major::MN ? swizzle_a_requirement : 8;
// The per-CTA load size of A is block M divided by cluster N
if ((block_m / cluster_n) % load_block_m_requirement != 0)
continue;
// Shape must be divisible for multicast
if (ceil_div(desc.m, block_m) % cluster_m != 0)
continue;
for (int block_n: block_n_candidates) {
// Ensure large swizzle sizes (32B swizzle yields poor performance)
const auto swizzle_b_requirement = desc.b_dtype == kPackedFP4 ? 128 : 64;
// Enforce swizzle alignment for MN major; otherwise check base MMA shape
const auto load_block_n_requirement = desc.major_b == cute::UMMA::Major::MN ? swizzle_b_requirement : 8;
// The per-CTA load size of B is block N divided by cluster M
if ((block_n / cluster_m) % load_block_n_requirement != 0)
continue;
// Shape must be divisible for multicast
if (ceil_div(desc.n, block_n) % cluster_n != 0)
continue;
// SwapAB requires block N is layout A/D' UMMA M
constexpr int layout_ad_m = 128;
if (swap_ab and block_n != layout_ad_m)
continue;
// Check tensor memory capacity
// The budget is 512: two accumulators of `umma_n` each, plus the scale factor columns
const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, desc.get_mma_kind());
const auto tmem_sf_cols = desc.get_mma_kind() == MmaKind::MXFP8FP4 ? sf_block_m / 32 + sf_block_n / 32 : 0;
const auto umma_n = swap_ab ? block_m : block_n;
if (2 * umma_n + tmem_sf_cols > 512)
continue;
const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n};
// If at least one operand is K-major, only keep layouts where both A and B get 128B swizzle
// NOTE(review): the previous comment here read "When neither A nor B is MN major, 128B swizzle
// is always feasible", which describes both operands being K-major, while the condition below
// uses `or` — confirm which one is intended
if (desc.major_a == cute::UMMA::Major::K or desc.major_b == cute::UMMA::Major::K) {
const auto storage_config = get_storage_config(desc, layout);
if (storage_config.swizzle_a_mode != 128 or storage_config.swizzle_b_mode != 128)
continue;
}
candidates.push_back(layout);
}
}
}
}
}
DG_HOST_ASSERT(not candidates.empty());
return candidates;
}
// Derives the load/store tile sizes and the swizzle modes of A, B and C/D for `layout`
// Returned in the order: load block M/N, store block M/N, swizzle mode of A, B, C/D
static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) {
    constexpr int kLayoutAdM = 128;
    constexpr int kUmmaStepN = 16;

    // Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms)
    // Loads: block M is divided by cluster N, block N by cluster M
    const auto load_block_m = layout.block_m / layout.cluster_n;
    const auto load_block_n = layout.block_n / layout.cluster_m;
    // Stores: with swap A/B the store block M is one UMMA N step, otherwise it is capped at 128
    const auto store_block_n = layout.block_n;
    const auto store_block_m = layout.swap_ab ? kUmmaStepN : std::min(kLayoutAdM, layout.block_m);

    // Decide swizzling by the inner dim
    // TODO: support FP4 sub-byte
    const auto inner_dim_a = desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m;
    const auto inner_dim_b = desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n;
    const auto swizzle_mode_a = get_swizzle_mode(inner_dim_a, c10::elementSize(desc.a_dtype));
    const auto swizzle_mode_b = get_swizzle_mode(inner_dim_b, c10::elementSize(desc.b_dtype));
    const auto swizzle_mode_cd = get_swizzle_mode(store_block_n, c10::elementSize(desc.cd_dtype));

    return {
        load_block_m, load_block_n,
        store_block_m, store_block_n,
        swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd
    };
}
// Computes how many pipeline stages fit into `smem_capacity` and the resulting shared memory size
// Returns {total shared memory bytes, number of stages}; the stage count is capped at 32
static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) {
    constexpr int kNumMaxStages = 32;
    constexpr int kNumBarrierBytes = 8;

    // C/D for TMA stores
    int smem_cd;
    if (layout.swap_ab) {
        smem_cd = storage_config.store_block_m * storage_config.store_block_n * c10::elementSize(desc.cd_dtype) * 2;
    } else {
        smem_cd = storage_config.store_block_m * storage_config.swizzle_cd_mode * 2;
    }

    // TODO: remove SF barriers for BF16 GEMMs
    // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers
    // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages
    // NOTES: the last barrier is for tensor core utilization control
    const int smem_barriers = kNumMaxStages * kNumBarrierBytes * 3 + 2 * kNumBarrierBytes * 2 + kNumBarrierBytes;

    // Tensor memory pointer
    const int smem_tmem_ptr = 4;

    // A/B tiles per stage
    // TODO: consider FP4
    const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype);
    const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype);

    // SF A/B per stage, only the 1D1D kernel stores them (4 bytes per aligned row/column)
    int smem_sfa_per_stage = 0;
    int smem_sfb_per_stage = 0;
    if (desc.kernel_type == KernelType::Kernel1D1D) {
        const auto sf_block_sizes = get_sf_uttcp_aligned_block_sizes(
            layout.block_m, layout.block_n, desc.get_mma_kind());
        smem_sfa_per_stage = sf_block_sizes.first * 4;
        smem_sfb_per_stage = sf_block_sizes.second * 4;
    }

    // Fit as many stages as possible into what is left after the fixed part
    const int smem_extra = smem_cd + smem_barriers + smem_tmem_ptr;
    const int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage;
    const int num_stages = std::min((smem_capacity - smem_extra) / smem_per_stage, kNumMaxStages);
    return {
        smem_extra + num_stages * smem_per_stage,
        num_stages
    };
}
// Builds the launch configuration: one CTA per SM, clustered according to `layout`
// NOTE(review): the thread-count literals are positional; by analogy with the SM90 spec the first one
// (256) looks like the total thread count — confirm the others against the `LaunchConfig` declaration
static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) {
    const auto num_sms = desc.num_sms;
    const auto cluster_size = layout.get_cluster_size();
    return {
        num_sms,
        cluster_size,
        256,
        32, 128, 128, 128
    };
}
// Summarizes how `layout` tiles the expected problem over the SMs
// Returns {number of waves, blocks in the last wave, 0 (cycles are not modeled here), layout}
static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) {
    // Total number of output blocks over all groups
    const auto num_m_blocks = ceil_div(desc.get_expected_m(), layout.block_m);
    const auto num_n_blocks = ceil_div(desc.get_expected_n(), layout.block_n);
    const auto num_blocks = num_m_blocks * num_n_blocks * desc.get_expected_num_groups();

    // One wave runs at most `num_sms` blocks; a zero remainder means the last wave is full
    const auto num_waves = ceil_div(num_blocks, desc.num_sms);
    const auto num_last_blocks = num_blocks % desc.num_sms;
    const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks;

    // TODO: calculate expected cycles
    return {num_waves, last_wave_util, 0, layout};
}
// A regular comparator
// Returns true if layout `a` should be preferred over layout `b`
// The rules are applied in order, the first one that distinguishes the two decides
static bool compare(const LayoutInfo& a, const LayoutInfo& b) {
    const bool waves_differ = a.num_waves != b.num_waves;

    // 1. Single wave is always better
    if (waves_differ and (a.num_waves == 1 or b.num_waves == 1))
        return a.num_waves < b.num_waves;

    // 2. Doing multicast (larger cluster) is better
    const auto cluster_size_a = a.layout.get_cluster_size();
    const auto cluster_size_b = b.layout.get_cluster_size();
    if (cluster_size_a != cluster_size_b)
        return cluster_size_a > cluster_size_b;

    // 3. Smaller number of waves is better
    if (waves_differ)
        return a.num_waves < b.num_waves;

    // 4. Larger last wave utilization is better
    if (a.last_wave_util != b.last_wave_util)
        return a.last_wave_util > b.last_wave_util;

    // 5. More stages is better: with one block dim equal, the smaller other dim wins
    const auto block_sum_a = a.layout.block_m + a.layout.block_n;
    const auto block_sum_b = b.layout.block_m + b.layout.block_n;
    if (block_sum_a != block_sum_b)
        return block_sum_a < block_sum_b;

    // 6. Less shared memory C/D, more stages is better
    const auto block_area_a = a.layout.block_m * a.layout.block_n;
    const auto block_area_b = b.layout.block_m * b.layout.block_n;
    return block_area_a < block_area_b;
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,246 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.cuh>
#include "common.hpp"
#include "utils.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
struct SM90ArchSpec {
static constexpr int smem_capacity = 232448;
// Enumerates every `Layout` (block M/N/K, cluster M/N; swap A/B is always 0) that is legal for `desc`
// Block M/N candidate lists are built first, then every (cluster, block M, block N) combination is
// filtered by the unroll, multicast, register, swizzle and stage-count constraints below
// Asserts (via `DG_HOST_ASSERT`) that at least one candidate survives
static std::vector<Layout> get_layout_candidates(const GemmDesc& desc) {
// Block M candidates
std::vector<int> block_m_candidates;
if (desc.gemm_type == GemmType::Normal or
desc.gemm_type == GemmType::Batched or
desc.gemm_type == GemmType::KGroupedContiguous) {
// TODO: check 256's performance
block_m_candidates = {64, 128};
// NOTES: smaller block M can avoid TMA L2 OOB bound
if (desc.m <= 16) block_m_candidates.push_back(16);
if (desc.m <= 32) block_m_candidates.push_back(32);
// BF16 output GEMM supports 256
// (the check itself admits every output dtype other than FP32)
if (desc.cd_dtype != torch::kFloat)
block_m_candidates.push_back(256);
} else if (desc.gemm_type == GemmType::MGroupedContiguous or
desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) {
// Block M must match the configured M/K alignment of the contiguous layout
block_m_candidates = std::vector{heuristics_runtime->get_mk_alignment_for_contiguous_layout()};
} else if (desc.gemm_type == GemmType::MGroupedMasked) {
block_m_candidates = {64, 128};
}
// Block N candidates
std::vector<int> block_n_candidates;
int step = std::lcm(16, heuristics_runtime->get_block_n_multiple_of());
int start = step;
// Avoid bank conflicts for 1D1D kernel FP32 output
// In that case the enumeration starts at 24 instead of `step`, and 16 is added explicitly
if (desc.kernel_type == KernelType::Kernel1D1D and desc.cd_dtype == torch::kFloat) {
DG_HOST_ASSERT(desc.major_a == cute::UMMA::Major::K);
DG_HOST_ASSERT(desc.major_b == cute::UMMA::Major::K);
start = 24;
block_n_candidates.push_back(16);
}
// Register spills
int end = 256;
if (desc.kernel_type == KernelType::Kernel1D2D)
end = 192;
if (desc.kernel_type == KernelType::Kernel1D1D)
end = 160;
// Enumerate
for (int i = start; i <= end; i += step)
block_n_candidates.push_back(i);
// Block K is always in a fixed manner
const int block_k = 128 / get_element_size(desc.get_mma_kind());
// Disable multicast for performance
const bool disable_multicast =
// The number of k-groups is large (a heuristic)
(desc.gemm_type == GemmType::KGroupedContiguous and desc.num_groups > 4) or
// Not supported
(desc.gemm_type == GemmType::Batched);
// Enumerate all candidates
std::vector<Layout> candidates;
for (int cluster_m = 1; cluster_m <= (disable_multicast ? 1 : 2); ++ cluster_m) {
for (int cluster_n = 1; cluster_n <= (disable_multicast ? 1 : 2); ++ cluster_n) {
// We only support cluster 2
if (cluster_m * cluster_n > 2)
continue;
// SM count must be divisible
if (desc.num_sms % (cluster_m * cluster_n) != 0)
continue;
for (int block_m: block_m_candidates) {
for (int block_n: block_n_candidates) {
// 1D2D kernel unroll requirement
// For block N > block K, the difference must divide either block N or block K
if (desc.kernel_type == KernelType::Kernel1D2D and block_n > block_k and (block_n % (block_n - block_k) != 0 and block_k % (block_n - block_k) != 0))
continue;
// Multicast legality for masked layout
// TODO: add some comments about it
if ((desc.gemm_type == GemmType::MGroupedMasked or desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) and
ceil_div(desc.n, block_n) % (cluster_m * cluster_n) != 0)
continue;
// The block sizes cannot be too large (for enough registers), so at least one dim must be at most 128
if (block_m > 128 and block_n > 128)
continue;
// Calculate swizzling
const auto layout = Layout{0, block_m, block_n, block_k, cluster_m, cluster_n};
const auto storage_config = get_storage_config(desc, layout);
// Make sure swizzling is large enough (32B's performance is low)
if (storage_config.swizzle_a_mode % 64 != 0 or storage_config.swizzle_b_mode % 64 != 0)
continue;
// To hide TMA latency, the stage count should be at least 3; for small matrices, at least 4
// "Small" here means a block area below 128 x 192
int num_stages = get_pipeline_config(desc, layout, storage_config).num_stages;
if (num_stages < 3 or (block_m * block_n < 128 * 192 and num_stages < 4))
continue;
candidates.push_back(layout);
}
}
}
}
DG_HOST_ASSERT(not candidates.empty());
return candidates;
}
// Derives the load/store tile sizes and the swizzle modes of A, B and C/D for `layout`
// Returned in the order: load block M/N, store block M/N, swizzle mode of A, B, C/D
// Swap A/B is not supported here and is rejected by an assertion
static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) {
    constexpr int kWgmmaM = 64;

    // Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms)
    // TODO: support swap AB
    DG_HOST_ASSERT(layout.swap_ab == 0);
    const auto load_block_m = layout.block_m;
    const auto load_block_n = layout.block_n;

    // 1D1D kernel will do single warp-group stores
    const bool is_1d1d = desc.kernel_type == KernelType::Kernel1D1D;
    const auto store_block_m = is_1d1d ? kWgmmaM : layout.block_m;
    const auto store_block_n = layout.block_n;

    // Decide swizzling by the inner dim
    const auto inner_dim_a = desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m;
    const auto inner_dim_b = desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n;
    const auto swizzle_mode_a = get_swizzle_mode(inner_dim_a, c10::elementSize(desc.a_dtype));
    const auto swizzle_mode_b = get_swizzle_mode(inner_dim_b, c10::elementSize(desc.b_dtype));

    // We only enable swizzling for non-FP32 outputs
    int swizzle_mode_cd = 0;
    if (desc.cd_dtype != torch::kFloat)
        swizzle_mode_cd = get_swizzle_mode(store_block_n, c10::elementSize(desc.cd_dtype));

    return {
        load_block_m, load_block_n,
        store_block_m, store_block_n,
        swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd
    };
}
// Computes how many pipeline stages fit into `smem_capacity` and the resulting shared memory size
// Returns {total shared memory bytes, number of stages}; the stage count is capped at 16
static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) {
    constexpr int kNumMaxStages = 16;
    constexpr int kNumFloatBytes = static_cast<int>(sizeof(float));
    const bool has_sf = desc.kernel_type != KernelType::KernelNoSF;
    const bool is_1d1d = desc.kernel_type == KernelType::Kernel1D1D;
    const bool is_1d2d = desc.kernel_type == KernelType::Kernel1D2D;

    // TODO: consider swap AB
    // C/D for TMA stores
    // NOTES: 1024 is for TMA swizzling alignment requirement
    const int num_cd_elem_bytes = static_cast<int>(c10::elementSize(desc.cd_dtype));
    const int smem_cd = align(layout.block_m * layout.block_n * num_cd_elem_bytes, 1024);

    // Two 8-byte barriers for each possible stage
    const int smem_barriers = kNumMaxStages * 8 * 2;

    // Extra SFB sizes for 1D2D kernels: one float per K block,
    // doubled when block K is not a multiple of block N
    const int sfb_multiplier = layout.block_k % layout.block_n == 0 ? 1 : 2;
    const int smem_extra_sfb = is_1d2d ?
        align<int>(ceil_div(desc.k, layout.block_k) * kNumFloatBytes * sfb_multiplier, 8) : 0;

    // Extra tensormap space (4 `CUtensorMap` objects), only for k-grouped contiguous GEMMs
    const int smem_tensormap =
        desc.gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast<int>(sizeof(CUtensorMap)) : 0;

    // A/B tiles per stage
    const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype);
    const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype);

    // SF A/B per stage: SFA for every kernel with scale factors, SFB only for the 1D1D kernel
    const int smem_sfa_per_stage = has_sf ? align(layout.block_m * kNumFloatBytes, 128) : 0;
    const int smem_sfb_per_stage = is_1d1d ? align(layout.block_n * kNumFloatBytes, 128) : 0;

    // Fit as many stages as possible into what is left after the fixed part
    const int smem_extra = smem_cd + smem_barriers + smem_extra_sfb + smem_tensormap;
    const int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage;
    const int num_stages = std::min((smem_capacity - smem_extra) / smem_per_stage, kNumMaxStages);
    return {
        smem_extra + num_stages * smem_per_stage,
        num_stages
    };
}
// Builds the launch configuration: one CTA per SM, clustered according to `layout`
// 128 TMA threads are always used; math threads are 128 for block M <= 64, otherwise 256
static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) {
    const int num_tma_threads = 128;
    const int num_math_threads = layout.block_m > 64 ? 256 : 128;
    const int num_threads = num_tma_threads + num_math_threads;
    return {
        desc.num_sms,
        layout.get_cluster_size(),
        num_threads,
        num_tma_threads, num_math_threads,
        0, 0 // Meaningless for SM90
    };
}
// Estimates the cost of running the expected problem with `layout`
// Returns {number of waves, blocks in the last wave, estimated cycles, layout}; the estimate only
// models L1/L2 traffic and is what `compare` uses to rank layouts
// Requires A and B to share one dtype (asserted below)
static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) {
// Total number of output blocks over all groups
const auto num_blocks =
ceil_div(desc.get_expected_m(), layout.block_m) *
ceil_div(desc.get_expected_n(), layout.block_n) *
desc.get_expected_num_groups();
// One wave runs at most `num_sms` blocks; a zero remainder means the last wave is full
const auto num_waves = ceil_div(num_blocks, desc.num_sms);
const auto num_last_blocks = num_blocks % desc.num_sms;
const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks;
// Utils
// NOTES: the L2 figure is computed in floating point and truncated into `int`
const int l2_bandwidth_per_cycle = std::min(64. * desc.num_sms, 8e6 / (1.3e3)); // B/cycle
const int l1_bandwidth_per_cycle = 128 * desc.num_sms; // B/cycle
const int wgmma_m = 64;
const int elem_size_ab = c10::elementSize(desc.a_dtype);
const int elem_size_cd = c10::elementSize(desc.cd_dtype);
DG_HOST_ASSERT(desc.a_dtype == desc.b_dtype);
// Data movement per block
int64_t expected_k = desc.get_expected_k();
// L2 -> SMEM traffic of A/B: each operand's block is divided by the cluster size on the other dim
int64_t num_bytes_l2_ab = expected_k * (layout.block_m / layout.cluster_n + layout.block_n / layout.cluster_m) * elem_size_ab;
int64_t num_bytes_l1_ab = expected_k * (layout.block_m + layout.block_n) * elem_size_ab;
// Tensor core side: block M counts as at least `wgmma_m`, plus one pass over the C/D tile
int64_t num_bytes_l1_tc = expected_k * (std::max(wgmma_m, layout.block_m) + layout.block_n) * elem_size_ab
+ layout.block_m * layout.block_n * elem_size_cd;
// C/D traffic is doubled with accumulation
int64_t num_bytes_l1_l2_cd = layout.block_m * layout.block_n * elem_size_cd * (desc.with_accumulation ? 2 : 1);
// HBM bandwidth and total compute (Tensor/CUDA cores) are constant across configs
// We only model L1/L2 cycles as they are the primary variables between configs
int64_t num_l2_cycles = (num_bytes_l2_ab + num_bytes_l1_l2_cd) * num_blocks / l2_bandwidth_per_cycle;
int64_t num_l1_cycles = (num_bytes_l1_ab + num_bytes_l1_tc + num_bytes_l1_l2_cd) * num_blocks / l1_bandwidth_per_cycle;
// Fraction of the block slots (`num_waves * num_sms`) that is actually occupied
// NOTE(review): this is 0 / 0 if `num_blocks` is 0 — presumably callers never pass an empty problem; confirm
float wave_efficiency = static_cast<float>(num_blocks) / (num_waves * desc.num_sms);
// NOTES: the division is done in `float` and converted back to `int64_t`, so large counts lose precision
int64_t num_cycles = std::max(num_l1_cycles, num_l2_cycles) / wave_efficiency;
// Disable multicasting if only one wave exists
if (layout.cluster_n * layout.cluster_m > 1 and num_waves <= 1)
num_cycles = std::numeric_limits<int64_t>::max();
return {num_waves, last_wave_util, num_cycles, layout};
}
// A regular comparator
// Returns true if layout `a` has a strictly lower estimated cycle count than layout `b`
static bool compare(const LayoutInfo& a, const LayoutInfo& b) {
    const auto& num_cycles_a = a.num_cycles;
    const auto& num_cycles_b = b.num_cycles;
    return num_cycles_a < num_cycles_b;
}
};
} // namespace deep_gemm

View File

@@ -0,0 +1,23 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
// Reuse some types in the JIT modules
#include <deep_gemm/common/types.cuh>
#include "common.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
// Returns the largest of 128/64/32/16 that evenly divides the inner dimension in bytes
// (`block_size` elements of `elem_size` bytes each); no match is reported through `DG_HOST_UNREACHABLE`
template <typename size_type_t>
static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) {
    // `> 0` means interleaving
    // 16B actually means non-swizzling (but interleaving)
    constexpr int kSwizzleModes[] = {128, 64, 32, 16};
    const int num_inner_bytes = block_size * static_cast<int>(elem_size);
    for (const int mode: kSwizzleModes) {
        if (num_inner_bytes % mode == 0)
            return mode;
    }
    DG_HOST_UNREACHABLE("Unreachable");
}
} // namespace deep_gemm

View File

@@ -0,0 +1,12 @@
#pragma once
#include <optional>
#include <string>
namespace deep_gemm {
// Returns the given epilogue type name, or the identity epilogue if none is given
static std::string get_default_epilogue_type(const std::optional<std::string>& epilogue_type) {
    if (epilogue_type.has_value())
        return *epilogue_type;
    return "epilogue::transform::EpilogueIdentity";
}
} // namespace deep_gemm

View File

@@ -0,0 +1,267 @@
#pragma once
#include <cuda.h>
#include <torch/python.h>
#include "../heuristics/sm90.hpp"
#include "../../jit/handle.hpp"
#include "../../utils/math.hpp"
#include "../../utils/system.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
// Orders the two extents as (inner, outer): K-major puts `k` first, otherwise `mn` comes first
static std::pair<int, int> get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) {
    if (major == cute::UMMA::Major::K)
        return {k, mn};
    return {mn, k};
}
// Returns the (negative) tensor dimension index to read the stride from: -2 for K-major, otherwise -1
static int get_non_contiguous_dim(const cute::UMMA::Major& major) {
    if (major == cute::UMMA::Major::K)
        return -2;
    return -1;
}
// Returns `dim` if its one-letter `name` is listed in `compiled_dims`, otherwise 0
// Always returns 0 if the heuristics runtime is set to ignore compiled dims
static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) {
    if (heuristics_runtime->get_ignore_compile_dims())
        return 0;
    const bool is_compiled = compiled_dims.find(name) != std::string::npos;
    return is_compiled ? dim : 0;
}
// Spells a UMMA major as the C++ expression naming it (used when generating kernel code)
static std::string to_string(const cute::UMMA::Major& major) {
    if (major == cute::UMMA::Major::K)
        return "cute::UMMA::Major::K";
    if (major == cute::UMMA::Major::MN)
        return "cute::UMMA::Major::MN";
    DG_HOST_UNREACHABLE("Unknown major");
}
// Spells a GEMM type as the C++ expression naming it (used when generating kernel code)
static std::string to_string(const GemmType& type) {
    if (type == GemmType::Normal)
        return "GemmType::Normal";
    if (type == GemmType::MGroupedContiguous)
        return "GemmType::MGroupedContiguous";
    if (type == GemmType::MGroupedMasked)
        return "GemmType::MGroupedMasked";
    if (type == GemmType::MGroupedContiguousWithPsumLayout)
        return "GemmType::MGroupedContiguousWithPsumLayout";
    if (type == GemmType::KGroupedContiguous)
        return "GemmType::KGroupedContiguous";
    if (type == GemmType::Batched)
        return "GemmType::Batched";
    DG_HOST_UNREACHABLE("Unknown GEMM type");
}
// Spells a tensor dtype as the C++ type name used in generated kernel code
// Only int, float, BF16, FP8 (e4m3) and packed FP4 are supported
static std::string to_string(const at::ScalarType& dtype) {
    if (dtype == torch::kInt)
        return "int";
    if (dtype == torch::kFloat)
        return "float";
    if (dtype == torch::kBFloat16)
        return "cutlass::bfloat16_t";
    if (dtype == torch::kFloat8_e4m3fn)
        return "cutlass::float_e4m3_t";
    if (dtype == kPackedFP4)
        return "cutlass::detail::float_e2m1_unpacksmem_t";
    DG_HOST_UNREACHABLE("Unsupported dtype");
}
// Spells a float as a C++ expression: finite values become a hexadecimal float literal with an `f`
// suffix, infinities become `cute::numeric_limits` expressions, NaN is rejected
static std::string to_string(const float& v) {
    if (std::isinf(v)) {
        if (v > 0)
            return "cute::numeric_limits<float>::infinity()";
        return "-cute::numeric_limits<float>::infinity()";
    }
    if (std::isfinite(v))
        return fmt::format(R"({:a}f)", v);
    DG_HOST_UNREACHABLE("NaN input is not supported");
}
// Maps a tensor dtype to the element type of a TMA descriptor
//  - float maps to TF32 if `allow_tf32` is set, otherwise to FP32
//  - FP8 (e4m3) is described as plain bytes
//  - packed FP4 (CUDA 12.8+ only) picks the 16B- or 8B-aligned variant from `fp4_unpacked_smem`
static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype,
                                                          const bool& allow_tf32,
                                                          const bool& fp4_unpacked_smem) {
    if (dtype == torch::kFloat)
        return allow_tf32 ? CU_TENSOR_MAP_DATA_TYPE_TFLOAT32 : CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
    if (dtype == torch::kInt)
        return CU_TENSOR_MAP_DATA_TYPE_INT32;
    if (dtype == torch::kBFloat16)
        return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    if (dtype == torch::kFloat8_e4m3fn)
        return CU_TENSOR_MAP_DATA_TYPE_UINT8;
#if CUDA_VERSION >= 12080
    if (dtype == kPackedFP4) {
        return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B
                                 : CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B;
    }
#endif
    DG_HOST_UNREACHABLE("Unsupported dtype");
}
// Maps a swizzle mode in bytes (0/16/32/64/128) and a swizzle base to the TMA swizzle enum
// A non-zero base is only accepted on CUDA 12.8+, and only as base 32 with mode 128
// Modes 0 and 16 both mean no swizzling
static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
#if CUDA_VERSION >= 12080
    if (base != 0) {
        DG_HOST_ASSERT(base == 32 and mode == 128);
        return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
    }
#endif
    DG_HOST_ASSERT(base == 0);
    if (mode == 0 or mode == 16)
        return CU_TENSOR_MAP_SWIZZLE_NONE;
    if (mode == 32)
        return CU_TENSOR_MAP_SWIZZLE_32B;
    if (mode == 64)
        return CU_TENSOR_MAP_SWIZZLE_64B;
    if (mode == 128)
        return CU_TENSOR_MAP_SWIZZLE_128B;
    DG_HOST_UNREACHABLE("Unsupported swizzling mode");
}
// Encodes a 2D tiled TMA descriptor (`CUtensorMap`) for tensor `t`
//  - `gmem_inner_dim`/`gmem_outer_dim`: global tensor extents, inner dimension first
//  - `smem_inner_dim`/`smem_outer_dim`: box extents; the inner one is overridden whenever
//    swizzling is enabled (see below)
//  - `gmem_outer_stride`: stride of the outer dimension in elements, converted to bytes here
//  - `swizzle_mode`/`swizzle_base`: forwarded to `mode_into_tensor_map_swizzle`
//  - `allow_tf32`/`fp4_unpacked_smem`: forwarded to `aten_dtype_to_tensor_map_dtype`
// Driver failures are reported through `DG_CUDA_DRIVER_CHECK`
static CUtensorMap make_tma_2d_desc(const torch::Tensor& t,
                                    int gmem_inner_dim, int gmem_outer_dim,
                                    int smem_inner_dim, int smem_outer_dim,
                                    const int& gmem_outer_stride,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false,
                                    const bool& fp4_unpacked_smem = true) {
    const auto elem_size = static_cast<int>(t.element_size());
    // With swizzling, the inner box dim is exactly one swizzle atom wide
    if (swizzle_mode != 0)
        smem_inner_dim = swizzle_mode / elem_size;
    if (t.scalar_type() == kPackedFP4) {
        // Inner dim must be a multiple of 64B for .b4x16_p64
        DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0);
        // Fix FP4 packed smem
        if (not fp4_unpacked_smem and swizzle_mode != 0)
            smem_inner_dim = swizzle_mode * 2;
    }

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[2] = {static_cast<cuuint64_t>(gmem_inner_dim), static_cast<cuuint64_t>(gmem_outer_dim)};
    const cuuint32_t smem_dims[2] = {static_cast<cuuint32_t>(smem_inner_dim), static_cast<cuuint32_t>(smem_outer_dim)};
    // NOTES: widen before multiplying, as the stride in bytes may not fit into `int`
    const cuuint64_t gmem_strides[1] = {static_cast<cuuint64_t>(gmem_outer_stride) * static_cast<cuuint64_t>(elem_size), };
    const cuuint32_t elem_strides[2] = {1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d, pointer: %llu\n",
               gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim,
               gmem_outer_stride, swizzle_mode, swizzle_base, elem_size,
               reinterpret_cast<unsigned long long>(t.data_ptr()));
    }
    DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem),
        2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
// Encodes a 3D tiled TMA descriptor (`CUtensorMap`) for tensor `t`
//  - `gmem_dim_0..2`: global tensor extents, innermost dimension first
//  - `smem_dim_0..2`: box extents; the innermost one is overridden whenever swizzling is
//    enabled (see below)
//  - `gmem_stride_0`/`gmem_stride_1`: strides of dimensions 1 and 2 in elements, converted to bytes here
//  - `swizzle_mode`/`swizzle_base`: forwarded to `mode_into_tensor_map_swizzle`
//  - `allow_tf32`/`fp4_unpacked_smem`: forwarded to `aten_dtype_to_tensor_map_dtype`
// Driver failures are reported through `DG_CUDA_DRIVER_CHECK`
static CUtensorMap make_tma_3d_desc(const torch::Tensor& t,
                                    int gmem_dim_0, int gmem_dim_1, int gmem_dim_2,
                                    int smem_dim_0, int smem_dim_1, int smem_dim_2,
                                    const int& gmem_stride_0, const int& gmem_stride_1,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false,
                                    const bool& fp4_unpacked_smem = true) {
    const auto elem_size = static_cast<int>(t.element_size());
    // With swizzling, the inner box dim is exactly one swizzle atom wide
    if (swizzle_mode != 0)
        smem_dim_0 = swizzle_mode / elem_size;
    if (t.scalar_type() == kPackedFP4) {
        // Inner dim must be a multiple of 64B for .b4x16_p64
        DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_dim_0 % 128 == 0);
        // Fix fp4 packed smem
        if (not fp4_unpacked_smem and swizzle_mode != 0)
            smem_dim_0 = swizzle_mode * 2;
    }

    CUtensorMap tensor_map;
    const cuuint64_t gmem_dims[3] = {static_cast<cuuint64_t>(gmem_dim_0), static_cast<cuuint64_t>(gmem_dim_1), static_cast<cuuint64_t>(gmem_dim_2),};
    const cuuint32_t smem_dims[3] = {static_cast<cuuint32_t>(smem_dim_0), static_cast<cuuint32_t>(smem_dim_1), static_cast<cuuint32_t>(smem_dim_2)};
    // NOTES: widen before multiplying, as the strides in bytes (the outermost one in particular)
    // may not fit into `int`
    const auto elem_size_u64 = static_cast<cuuint64_t>(elem_size);
    const cuuint64_t gmem_strides[2] = {static_cast<cuuint64_t>(gmem_stride_0) * elem_size_u64, static_cast<cuuint64_t>(gmem_stride_1) * elem_size_u64};
    const cuuint32_t elem_strides[3] = {1, 1, 1};
    if (get_env<int>("DG_JIT_DEBUG")) {
        printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d, elem size: %d\n",
               gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2,
               gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size);
    }
    DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled(
        &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem),
        3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides,
        CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base),
        CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
    return tensor_map;
}
// Builds the 2D TMA descriptor of operand A
// The groups are stacked along M (`shape_m * num_groups`), so grouped inputs must be K-major
static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_m, const int& shape_k,
                                   const int& block_m, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode, const int& swizzle_base = 0,
                                   const bool& allow_tf32 = false) {
    if (num_groups > 1) {
        DG_HOST_ASSERT(major == cute::UMMA::Major::K);
    }

    // Order the global and box extents as (inner, outer) according to the major
    const std::pair<int, int> gmem_dims = get_inner_outer_dims(major, shape_k, shape_m * num_groups);
    const std::pair<int, int> smem_dims = get_inner_outer_dims(major, block_k, block_m);
    return make_tma_2d_desc(t,
                            gmem_dims.first, gmem_dims.second,
                            smem_dims.first, smem_dims.second,
                            outer_stride,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
// Builds the 2D TMA descriptor of operand B
// Unlike A, the groups always multiply the outer dimension, whichever major is used
static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major,
                                   const torch::Tensor& t,
                                   const int& shape_n, const int& shape_k,
                                   const int& block_n, const int& block_k,
                                   const int& outer_stride,
                                   const int& num_groups,
                                   const int& swizzle_mode, const int& swizzle_base = 0,
                                   const bool& allow_tf32 = false) {
    // Order the global and box extents as (inner, outer) according to the major
    const std::pair<int, int> gmem_dims = get_inner_outer_dims(major, shape_k, shape_n);
    const std::pair<int, int> smem_dims = get_inner_outer_dims(major, block_k, block_n);

    // `num_groups` is always applied into the outer dimensions
    const int gmem_outer_dim = gmem_dims.second * num_groups;
    return make_tma_2d_desc(t,
                            gmem_dims.first, gmem_outer_dim,
                            smem_dims.first, smem_dims.second,
                            outer_stride,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
// Builds the 2D TMA descriptor of the output C/D: N is the inner dimension and the groups are
// stacked along M
static CUtensorMap make_tma_cd_desc(const torch::Tensor& t,
                                    const int& shape_m, const int& shape_n,
                                    const int& block_m, const int& block_n,
                                    const int& outer_stride,
                                    const int& num_groups,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    const int gmem_inner_dim = shape_n;
    const int gmem_outer_dim = shape_m * num_groups;
    // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode`
    // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required
    return make_tma_2d_desc(t,
                            gmem_inner_dim, gmem_outer_dim,
                            block_n, block_m,
                            outer_stride,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
// Builds the 2D TMA descriptor of a scale factor tensor
// Only MN-major, non-swizzled scale factors are supported (both are asserted)
// The inner dimension is `shape_mn` after `get_tma_aligned_size`, and the same value is the outer stride
static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
                                    const torch::Tensor& t,
                                    int shape_mn, int shape_k,
                                    const int& block_mn, const int& gran_k,
                                    const int& num_groups,
                                    const int& swizzle_mode, const int& swizzle_base = 0,
                                    const bool& allow_tf32 = false) {
    DG_HOST_ASSERT(major == cute::UMMA::Major::MN);
    // TODO: maybe swizzle SF as well
    DG_HOST_ASSERT(swizzle_mode == 0);

    shape_mn = get_tma_aligned_size(shape_mn, static_cast<int>(t.element_size()));

    // One element covers `gran_k` of K for float tensors, and 4x that for any other dtype
    // NOTE(review): presumably 4 scale factors are packed into one element — confirm
    const auto k_per_elem = gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4);
    const auto gmem_outer_dim = ceil_div(shape_k, k_per_elem) * num_groups;
    return make_tma_2d_desc(t,
                            shape_mn, gmem_outer_dim,
                            block_mn, 1,
                            shape_mn,
                            swizzle_mode, swizzle_base,
                            allow_tf32);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,415 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime of the SM100 BF16 GEMM kernel: `generate_impl` emits the source code that instantiates
// `sm100_bf16_gemm_impl` for one configuration, and `launch_impl` launches the built kernel
class SM100BF16GemmRuntime final: public LaunchRuntime<SM100BF16GemmRuntime> {
public:
// Everything needed to generate and launch one kernel instance
struct Args {
// Problem description and the selected configuration (both feed the template arguments)
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
// Passed to the kernel as its first argument (`nullptr` for the normal GEMM)
void* grouped_layout;
// TMA descriptors of A, B and the output
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_cd;
};
// Returns the source code instantiating the kernel template for `args`
// NOTES: the `{}` placeholders are positional, so the arguments after the raw string must keep
// exactly the order of the template parameters
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_bf16_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_bf16_gemm_impl<
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{},
{}, {}, {},
{}
>);
}};
)",
// Majors of A and B
to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b),
// Compiled shape: each dim is either its real value or 0 (see `get_compiled_dim`)
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
// Block sizes, number of groups, swizzle modes and pipeline stages
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_desc.num_groups,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
// Thread counts, cluster size and whether the cluster is on N
args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads,
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms,
args.gemm_config.layout.swap_ab,
// GEMM type, accumulation flag, output dtype and tensor core utilization
to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, to_string(args.gemm_desc.cd_dtype),
args.gemm_desc.tc_util);
}
// Launches `kernel` with the runtime arguments: grouped layout, the M/N/K shape and the three
// TMA descriptors; launch failures are reported through `DG_CUDA_UNIFIED_CHECK`
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_cd));
}
};
// Runs a normal (non-grouped) BF16 GEMM on SM100: selects the best configuration, encodes the TMA
// descriptors of A, B and D, then JIT-builds and launches the kernel
// Passing `c` only enables accumulation; the descriptors are built from `a`, `b` and `d`
static void sm100_bf16_gemm(const torch::Tensor& a,
                            const torch::Tensor& b,
                            const std::optional<torch::Tensor>& c,
                            const torch::Tensor& d,
                            const int& m, const int& n, const int& k,
                            const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                            const std::string& compiled_dims) {
    // Describe the problem and let the heuristics pick the configuration
    const auto desc = GemmDesc {
        .gemm_type = GemmType::Normal,
        .kernel_type = KernelType::KernelNoSF,
        .m = m, .n = n, .k = k, .num_groups = 1,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM100ArchSpec>(desc);
    const auto& storage_config = config.storage_config;

    // TMA descriptors, each with a single group
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              storage_config.load_block_m, config.layout.block_k,
                                              stride_a, 1,
                                              storage_config.swizzle_a_mode);
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              storage_config.load_block_n, config.layout.block_k,
                                              stride_b, 1,
                                              storage_config.swizzle_b_mode);
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                storage_config.store_block_m, storage_config.store_block_n,
                                                stride_d, 1,
                                                storage_config.swizzle_cd_mode);

    // Launch
    const SM100BF16GemmRuntime::Args args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                  config.pipeline_config.smem_size,
                                  config.layout.get_cluster_size()),
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_cd = tensor_map_cd
    };
    const auto code = SM100BF16GemmRuntime::generate(args);
    const auto runtime = compiler->build("sm100_bf16_gemm", code);
    SM100BF16GemmRuntime::launch(runtime, args);
}
// Host entry point for the M-grouped contiguous SM100 BF16 GEMM.
//
// `use_psum_layout` selects between `GemmType::MGroupedContiguousWithPsumLayout`
// and `GemmType::MGroupedContiguous`. `expected_m_for_psum_layout` may only be
// supplied together with `use_psum_layout` (asserted below).
// `grouped_layout.data_ptr()` is handed to the kernel as the grouped-layout pointer.
static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
                                                 const torch::Tensor& b,
                                                 const torch::Tensor& d,
                                                 const torch::Tensor& grouped_layout,
                                                 const int& num_groups, const int& m, const int& n, const int& k,
                                                 const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                 const std::string& compiled_dims,
                                                 const bool& use_psum_layout,
                                                 const std::optional<int>& expected_m_for_psum_layout) {
    // Only psum layout can use expected m
    if (expected_m_for_psum_layout.has_value())
        DG_HOST_ASSERT(use_psum_layout);

    const auto has_expected_m = expected_m_for_psum_layout.has_value();
    const auto gemm_type = use_psum_layout ?
        GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;

    // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
    // Otherwise, treat the contiguous layout as a whole.
    const auto gemm_desc = GemmDesc {
        .gemm_type = gemm_type,
        .kernel_type = KernelType::KernelNoSF,
        .m = m, .n = n, .k = k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
        .expected_m = expected_m_for_psum_layout.value_or(m),
        .expected_n = n, .expected_k = k,
        .expected_num_groups = has_expected_m ? num_groups : 1
    };
    const auto gemm_config = get_best_config<SM100ArchSpec>(gemm_desc);

    // Short aliases for the parts of the selected configuration used below
    const auto& layout = gemm_config.layout;
    const auto& storage = gemm_config.storage_config;
    const auto& launch_cfg = gemm_config.launch_config;

    // A is described with a group count of 1
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              storage.load_block_m, layout.block_k,
                                              stride_a, 1,
                                              storage.swizzle_a_mode);

    // B is described with `num_groups` groups
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              storage.load_block_n, layout.block_k,
                                              stride_b, num_groups,
                                              storage.swizzle_b_mode);

    // The output is described with a group count of 1
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                storage.store_block_m, storage.store_block_n,
                                                stride_d, 1,
                                                storage.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(launch_cfg.num_sms, launch_cfg.num_threads,
                                        gemm_config.pipeline_config.smem_size,
                                        layout.get_cluster_size());
    const SM100BF16GemmRuntime::Args args = {
        .gemm_desc = gemm_desc,
        .gemm_config = gemm_config,
        .launch_args = launch_args,
        .grouped_layout = grouped_layout.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_cd = tensor_map_cd
    };
    const auto kernel_source = SM100BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", kernel_source);
    SM100BF16GemmRuntime::launch(kernel_runtime, args);
}
// Host entry point for the M-grouped masked SM100 BF16 GEMM.
//
// All three tensor maps (A, B, CD) are created with `num_groups` groups, and
// `masked_m.data_ptr()` is handed to the kernel as the grouped-layout pointer.
// `expected_m` only feeds the descriptor used for configuration selection.
static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a,
                                             const torch::Tensor& b,
                                             const torch::Tensor& d,
                                             const torch::Tensor& masked_m,
                                             const int& num_groups, const int& m, const int& n, const int& k,
                                             const int& expected_m,
                                             const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                             const std::string& compiled_dims) {
    const auto gemm_desc = GemmDesc {
        .gemm_type = GemmType::MGroupedMasked,
        .kernel_type = KernelType::KernelNoSF,
        .m = m, .n = n, .k = k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
        .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups
    };
    const auto gemm_config = get_best_config<SM100ArchSpec>(gemm_desc);

    // Short aliases for the parts of the selected configuration used below
    const auto& layout = gemm_config.layout;
    const auto& storage = gemm_config.storage_config;
    const auto& launch_cfg = gemm_config.launch_config;

    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              storage.load_block_m, layout.block_k,
                                              stride_a, num_groups,
                                              storage.swizzle_a_mode);

    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              storage.load_block_n, layout.block_k,
                                              stride_b, num_groups,
                                              storage.swizzle_b_mode);

    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                storage.store_block_m, storage.store_block_n,
                                                stride_d, num_groups,
                                                storage.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(launch_cfg.num_sms, launch_cfg.num_threads,
                                        gemm_config.pipeline_config.smem_size,
                                        layout.get_cluster_size());
    const SM100BF16GemmRuntime::Args args = {
        .gemm_desc = gemm_desc,
        .gemm_config = gemm_config,
        .launch_args = launch_args,
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_cd = tensor_map_cd
    };
    const auto kernel_source = SM100BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", kernel_source);
    SM100BF16GemmRuntime::launch(kernel_runtime, args);
}
// Host entry point for the K-grouped contiguous SM100 BF16 GEMM.
//
// `ks` holds the per-group K sizes on the host, `ks_tensor.data_ptr()` is handed
// to the kernel as the grouped-layout pointer.
//
// Requirements (all asserted):
//   - both `major_a` and `major_b` are `cute::UMMA::Major::MN`;
//   - `ks` is non-empty;
//   - every entry of `ks` is a multiple of 128.
//
// `c` is only inspected for presence: when it holds a value the descriptor is
// marked `with_accumulation`.
static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a,
                                      const torch::Tensor& b,
                                      const std::optional<torch::Tensor>& c,
                                      const torch::Tensor& d,
                                      const int& m, const int& n,
                                      const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                      const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                      const std::string& compiled_dims) {
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);

    // `std::max_element` returns `end()` on an empty range and dereferencing it
    // below would be undefined behaviour, so reject an empty group list up front
    DG_HOST_ASSERT(not ks.empty());

    int sum_k = 0;
    for (const auto k: ks) {
        sum_k += k;
        DG_HOST_ASSERT(k % 128 == 0);
    }
    const auto num_groups = static_cast<int>(ks.size());

    // Get config using max K for better performance
    const auto max_k = *std::max_element(ks.begin(), ks.end());
    const auto desc = GemmDesc {
        .gemm_type = GemmType::KGroupedContiguous,
        .kernel_type = KernelType::KernelNoSF,
        .m = m, .n = n, .k = sum_k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
        .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
    };
    const auto config = get_best_config<SM100ArchSpec>(desc);

    // Create tensor descriptors
    // NOTES: A and B span the summed K extent with a group count of 1,
    // while the output is described with `num_groups` groups
    const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
                                              config.storage_config.load_block_m,
                                              config.layout.block_k,
                                              static_cast<int>(a.stride(0)), 1,
                                              config.storage_config.swizzle_a_mode);
    const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
                                              config.storage_config.load_block_n,
                                              config.layout.block_k,
                                              static_cast<int>(b.stride(0)), 1,
                                              config.storage_config.swizzle_b_mode);
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                config.storage_config.store_block_m,
                                                config.storage_config.store_block_n,
                                                static_cast<int>(d.stride(1)), num_groups,
                                                config.storage_config.swizzle_cd_mode);

    // Launch kernel
    const SM100BF16GemmRuntime::Args args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                  config.pipeline_config.smem_size,
                                  config.layout.get_cluster_size()),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_cd = tensor_map_cd
    };
    const auto code = SM100BF16GemmRuntime::generate(args);
    const auto runtime = compiler->build("sm100_bf16_k_grouped_gemm", code);
    SM100BF16GemmRuntime::launch(runtime, args);
}
// Host entry point for the batched SM100 BF16 GEMM named `bhr, hdr -> bhd`.
//
// The problem is mapped onto a `GemmType::Batched` descriptor with
// M = `b`, N = `d`, K = `r` and `h` groups; both operands are declared K-major.
// Tensor maps are built with `make_tma_3d_desc` directly from the tensor strides.
static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
                                   const torch::Tensor& tensor_b,
                                   const torch::Tensor& tensor_d,
                                   const int& b, const int& h, const int& r, const int& d,
                                   const std::string& compiled_dims = "nk") {
    const auto gemm_desc = GemmDesc {
        .gemm_type = GemmType::Batched,
        .kernel_type = KernelType::KernelNoSF,
        .m = b, .n = d, .k = r, .num_groups = h,
        .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
        .cd_dtype = tensor_d.scalar_type(),
        .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
    };
    const auto gemm_config = get_best_config<SM100ArchSpec>(gemm_desc);

    // Short aliases for the parts of the selected configuration used below
    const auto& layout = gemm_config.layout;
    const auto& storage = gemm_config.storage_config;
    const auto& launch_cfg = gemm_config.launch_config;

    // A: extents (r, b, h), box (block_k, load_block_m, 1)
    const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h,
                                               layout.block_k, storage.load_block_m, 1,
                                               tensor_a.stride(0), tensor_a.stride(1),
                                               storage.swizzle_a_mode);
    // B: extents (r, d, h), box (block_k, load_block_n, 1); note the swapped stride order
    const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
                                               layout.block_k, storage.load_block_n, 1,
                                               tensor_b.stride(1), tensor_b.stride(0),
                                               storage.swizzle_b_mode);
    // D: extents (d, b, h), box (store_block_n, store_block_m, 1)
    const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h,
                                                storage.store_block_n, storage.store_block_m, 1,
                                                tensor_d.stride(0), tensor_d.stride(1),
                                                storage.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(launch_cfg.num_sms, launch_cfg.num_threads,
                                        gemm_config.pipeline_config.smem_size,
                                        layout.get_cluster_size());
    const SM100BF16GemmRuntime::Args args = {
        .gemm_desc = gemm_desc,
        .gemm_config = gemm_config,
        .launch_args = launch_args,
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_cd = tensor_map_cd
    };
    const auto kernel_source = SM100BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", kernel_source);
    SM100BF16GemmRuntime::launch(kernel_runtime, args);
}
// Host entry point for the batched SM100 BF16 GEMM named `bhd, hdr -> bhr`.
//
// The problem is mapped onto a `GemmType::Batched` descriptor with
// M = `b`, N = `r`, K = `d` and `h` groups; A is declared K-major and B MN-major.
// Tensor maps are built with `make_tma_3d_desc` directly from the tensor strides.
static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
                                   const torch::Tensor& tensor_b,
                                   const torch::Tensor& tensor_d,
                                   const int& b, const int& h, const int& r, const int& d,
                                   const std::string& compiled_dims = "nk") {
    const auto gemm_desc = GemmDesc {
        .gemm_type = GemmType::Batched,
        .kernel_type = KernelType::KernelNoSF,
        .m = b, .n = r, .k = d, .num_groups = h,
        .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(),
        .cd_dtype = tensor_d.scalar_type(),
        .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
    };
    const auto gemm_config = get_best_config<SM100ArchSpec>(gemm_desc);

    // Short aliases for the parts of the selected configuration used below
    const auto& layout = gemm_config.layout;
    const auto& storage = gemm_config.storage_config;
    const auto& launch_cfg = gemm_config.launch_config;

    // A: extents (d, b, h), box (block_k, load_block_m, 1)
    const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h,
                                               layout.block_k, storage.load_block_m, 1,
                                               tensor_a.stride(0), tensor_a.stride(1),
                                               storage.swizzle_a_mode);
    // B: extents (r, d, h), box (load_block_n, block_k, 1); note the swapped stride order
    const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h,
                                               storage.load_block_n, layout.block_k, 1,
                                               tensor_b.stride(1), tensor_b.stride(0),
                                               storage.swizzle_b_mode);
    // D: extents (r, b, h), box (store_block_n, store_block_m, 1)
    const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h,
                                                storage.store_block_n, storage.store_block_m, 1,
                                                tensor_d.stride(0), tensor_d.stride(1),
                                                storage.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(launch_cfg.num_sms, launch_cfg.num_threads,
                                        gemm_config.pipeline_config.smem_size,
                                        layout.get_cluster_size());
    const SM100BF16GemmRuntime::Args args = {
        .gemm_desc = gemm_desc,
        .gemm_config = gemm_config,
        .launch_args = launch_args,
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_cd = tensor_map_cd
    };
    const auto kernel_source = SM100BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", kernel_source);
    SM100BF16GemmRuntime::launch(kernel_runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,137 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper around the `sm100_bmn_bnk_mn_gemm_impl` kernel template.
// `generate_impl` produces the source text that instantiates the template, and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM100BmkBnkMnRuntime final: public LaunchRuntime<SM100BmkBnkMnRuntime> {
public:
// Everything needed to generate and launch one kernel instance
struct Args {
// Problem sizes: `m`/`n`/`k` are baked into the generated code,
// `s` is only passed to the kernel at launch time
int s, m, n, k;
// Block sizes, baked into the generated code
int block_m, block_n, block_k;
// Split factor, baked into the generated code
int split_factor;
// Swizzle modes (one shared by A and B, one for the output), baked into the generated code
int swizzle_ab_mode, swizzle_cd_mode;
// Number of stages, baked into the generated code
int num_stages;
// Number of threads, baked into the generated code
int num_threads;
// Not read in this class; presumably consumed by the `LaunchRuntime` base — TODO confirm
LaunchArgs launch_args;
// Tensor maps passed to the kernel at launch time
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
};
// Returns the source text instantiating the kernel template.
// The `{}` placeholders are filled positionally, so the order of the arguments
// after the template string must match the order of the template parameters.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_bmk_bnk_mn.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_bmn_bnk_mn_gemm_impl<
{}, {}, {},
{}, {}, {},
{},
{}, {},
{}, {}
>);
}};
)",
args.m, args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.split_factor,
args.swizzle_ab_mode, args.swizzle_cd_mode,
args.num_stages, args.num_threads);
}
// Launches the built kernel with `s` and the three tensor maps (A, B, D);
// the value returned by `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
// NOTE: `args` is received by value, so each launch copies the whole struct.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d));
}
};
// Host entry point for the SM100 `bmk, bnk -> mn` BF16 GEMM with split-K.
//
// Block sizes and the thread count are fixed (128 x 128 x 64, 128 threads).
// The function derives the swizzle modes, a split-K factor from the SM count,
// and the largest stage count (starting from 4) whose shared-memory footprint
// fits `SM100ArchSpec::smem_capacity`, then generates, builds and launches
// the kernel through `SM100BmkBnkMnRuntime`.
//
// Requirements (all asserted):
//   - `s`, `m`, `n`, `k` are positive;
//   - `k` is a multiple of `block_k`, `m` and `n` are multiples of 64;
//   - `s * max(m, n)` fits in an `int`.
static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                  const torch::Tensor &b,
                                  const torch::Tensor &d,
                                  const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_threads = 128;

    // Zero extents pass the divisibility checks below but would then divide by
    // zero (`num_sms / num_mn_blocks` and `ceil_div(..., split_factor)`)
    DG_HOST_ASSERT(s > 0 and m > 0 and n > 0 and k > 0);
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());

    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast<int>(d.element_size()));

    // Get best config
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Select best number of stages
    // NOTES: we select 4 as start, as it is tested to be faster than values > 4
    int num_stages = 4, smem_size = 0;
    while (true) {
        const int smem_cd = block_m * swizzle_cd_mode * 2;
        const int smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
        const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
        const int smem_barrier = num_stages * 8 * 3 + 2 * 8 * 2 + 8;
        const int smem_tmem_ptr = 4;

        smem_size = 0;
        smem_size += smem_cd;
        smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
        smem_size += smem_barrier;
        smem_size += smem_tmem_ptr;

        if (smem_size <= SM100ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    // The loop may stop at zero stages; that is rejected here
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    // NOTES: the ", " before "stages" was missing, which glued the split-K
    // factor and the "stages" label together in the debug output
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode);
    }

    const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);
    const auto tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode);

    const SM100BmkBnkMnRuntime::Args args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .swizzle_ab_mode = swizzle_ab_mode,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_threads = num_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d
    };
    const auto code = SM100BmkBnkMnRuntime::generate(args);
    const auto runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code);
    SM100BmkBnkMnRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,459 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper around the `sm100_fp8_fp4_gemm_1d1d_impl` kernel template.
// `generate_impl` produces the source text that instantiates the template, and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8FP4Gemm1D1DRuntime> {
public:
// Everything needed to generate and launch one kernel instance
struct Args {
// Problem description and the configuration selected for it
GemmDesc gemm_desc;
GemmConfig gemm_config;
// Not read in this class; presumably consumed by the `LaunchRuntime` base — TODO confirm
LaunchArgs launch_args;
// TODO: move into descriptor
// Optional epilogue type; resolved through `get_default_epilogue_type` during generation
const std::optional<std::string> epilogue_type;
// TODO: move into descriptor
// K granularities for A and B, baked into the generated code
int gran_k_a, gran_k_b;
// First kernel argument at launch time
void* grouped_layout;
// Tensor maps passed to the kernel at launch time (operands, scaling factors, output)
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_cd;
};
// Returns the source text instantiating the kernel template.
// The `{}` placeholders are filled positionally, so the order of the arguments
// after the template string must match the order of the template parameters.
static std::string generate_impl(const Args& args) {
// TODO: rename files
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_fp4_gemm_1d1d_impl<
{}, {},
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{},
{}, {},
{}, {}, {},
{}
>);
}};
)",
to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b),
args.gran_k_a, args.gran_k_b,
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_desc.num_groups,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads,
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms,
args.gemm_config.layout.swap_ab,
to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation,
to_string(args.gemm_desc.a_dtype), to_string(args.gemm_desc.b_dtype), to_string(args.gemm_desc.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
// Launches the built kernel. Kernel arguments are forwarded in this order:
// grouped-layout pointer, M, N, K (taken from `args.gemm_desc`), then the
// tensor maps for A, B, SFA, SFB and CD.
// The value returned by `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_cd));
}
};
// Host entry point for the plain (single-group) SM100 FP8/FP4 1D1D GEMM.
//
// Steps: describe the problem with a `GemmDesc`, pick a configuration via
// `get_best_config<SM100ArchSpec>`, create the tensor maps for A, B, the output
// and both scaling-factor tensors, then generate, build and launch the kernel
// through `SM100FP8FP4Gemm1D1DRuntime`.
//
// `c` is only inspected for presence: when it holds a value the descriptor is
// marked `with_accumulation`. The CD tensor map is always built from `d`, and
// its N extent is taken from `d.size(-1)` rather than from `n`.
// `epilogue_type` is forwarded to the runtime unchanged.
static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                    const torch::Tensor& b, const torch::Tensor& sfb,
                                    const std::optional<torch::Tensor>& c,
                                    const torch::Tensor& d,
                                    const int& m, const int& n, const int& k,
                                    const int& gran_k_a, const int& gran_k_b,
                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                    const std::string& compiled_dims,
                                    const std::optional<std::string>& epilogue_type = std::nullopt) {
    const auto desc = GemmDesc {
        .gemm_type = GemmType::Normal,
        .kernel_type = KernelType::Kernel1D1D,
        .m = m, .n = n, .k = k, .num_groups = 1,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM100ArchSpec>(desc);

    // NOTES: the previously declared `cd = c.value_or(d)` local was never read
    // (every descriptor below is built from `d`), so it has been dropped
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              config.storage_config.load_block_m,
                                              config.layout.block_k,
                                              static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
                                              config.storage_config.swizzle_a_mode);
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              config.storage_config.load_block_n,
                                              config.layout.block_k,
                                              static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
                                              config.storage_config.swizzle_b_mode);
    const auto tensor_map_cd = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
                                                config.storage_config.store_block_m,
                                                config.storage_config.store_block_n,
                                                static_cast<int>(d.stride(-2)), 1,
                                                config.storage_config.swizzle_cd_mode);
    const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                                 config.layout.block_m, gran_k_a, 1, 0);
    const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                                 config.layout.block_n, gran_k_b, 1, 0);

    // Launch
    const SM100FP8FP4Gemm1D1DRuntime::Args args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                  config.pipeline_config.smem_size,
                                  config.layout.get_cluster_size()),
        .epilogue_type = epilogue_type,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_cd = tensor_map_cd
    };
    const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
    const auto runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code);
    SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
// Host entry point for the M-grouped contiguous SM100 FP8/FP4 1D1D GEMM.
//
// `use_psum_layout` selects between `GemmType::MGroupedContiguousWithPsumLayout`
// and `GemmType::MGroupedContiguous`. `expected_m_for_psum_layout` may only be
// supplied together with `use_psum_layout` (asserted below).
// `grouped_layout.data_ptr()` is handed to the kernel as the grouped-layout
// pointer, and no epilogue type is requested (`std::nullopt`).
static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                         const torch::Tensor& b, const torch::Tensor& sfb,
                                                         const torch::Tensor& d,
                                                         const torch::Tensor& grouped_layout,
                                                         const int& num_groups, const int& m, const int& n, const int& k,
                                                         const int& gran_k_a, const int& gran_k_b,
                                                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                         const std::string& compiled_dims,
                                                         const bool& use_psum_layout,
                                                         const std::optional<int>& expected_m_for_psum_layout) {
    // Only psum layout can use expected m
    if (expected_m_for_psum_layout.has_value())
        DG_HOST_ASSERT(use_psum_layout);

    const auto has_expected_m = expected_m_for_psum_layout.has_value();
    const auto gemm_type = use_psum_layout ?
        GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;

    // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`.
    // Otherwise, treat the contiguous layout as a whole.
    const auto gemm_desc = GemmDesc {
        .gemm_type = gemm_type,
        .kernel_type = KernelType::Kernel1D1D,
        .m = m, .n = n, .k = k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims,
        .expected_m = expected_m_for_psum_layout.value_or(m),
        .expected_n = n, .expected_k = k,
        .expected_num_groups = has_expected_m ? num_groups : 1
    };
    const auto gemm_config = get_best_config<SM100ArchSpec>(gemm_desc);

    // Short aliases for the parts of the selected configuration used below
    const auto& layout = gemm_config.layout;
    const auto& storage = gemm_config.storage_config;
    const auto& launch_cfg = gemm_config.launch_config;

    // Operand and output descriptors: only B carries `num_groups` groups
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              storage.load_block_m, layout.block_k,
                                              stride_a, 1,
                                              storage.swizzle_a_mode);
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              storage.load_block_n, layout.block_k,
                                              stride_b, num_groups,
                                              storage.swizzle_b_mode);
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                storage.store_block_m, storage.store_block_n,
                                                stride_d, 1,
                                                storage.swizzle_cd_mode);

    // Scaling-factor descriptors: only SFB carries `num_groups` groups
    const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                                 layout.block_m, gran_k_a, 1, 0);
    const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                                 layout.block_n, gran_k_b, num_groups, 0);

    // Launch kernel
    const auto launch_args = LaunchArgs(launch_cfg.num_sms, launch_cfg.num_threads,
                                        gemm_config.pipeline_config.smem_size,
                                        layout.get_cluster_size());
    const SM100FP8FP4Gemm1D1DRuntime::Args args = {
        .gemm_desc = gemm_desc,
        .gemm_config = gemm_config,
        .launch_args = launch_args,
        .epilogue_type = std::nullopt,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .grouped_layout = grouped_layout.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_cd = tensor_map_cd
    };
    const auto kernel_source = SM100FP8FP4Gemm1D1DRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", kernel_source);
    SM100FP8FP4Gemm1D1DRuntime::launch(kernel_runtime, args);
}
// Host entry point for the M-grouped masked SM100 FP8/FP4 1D1D GEMM.
//
// Every tensor map (A, B, CD, SFA, SFB) is created with `num_groups` groups,
// and `masked_m.data_ptr()` is handed to the kernel as the grouped-layout
// pointer. `expected_m` only feeds the descriptor used for configuration
// selection. No epilogue type is requested (`std::nullopt`).
static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                     const torch::Tensor& b, const torch::Tensor& sfb,
                                                     const torch::Tensor& d,
                                                     const torch::Tensor& masked_m,
                                                     const int& num_groups, const int& m, const int& n, const int& k,
                                                     const int& expected_m,
                                                     const int& gran_k_a, const int& gran_k_b,
                                                     const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                     const std::string& compiled_dims) {
    const auto gemm_desc = GemmDesc {
        .gemm_type = GemmType::MGroupedMasked,
        .kernel_type = KernelType::Kernel1D1D,
        .m = m, .n = n, .k = k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims,
        .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups
    };
    const auto gemm_config = get_best_config<SM100ArchSpec>(gemm_desc);

    // Short aliases for the parts of the selected configuration used below
    const auto& layout = gemm_config.layout;
    const auto& storage = gemm_config.storage_config;
    const auto& launch_cfg = gemm_config.launch_config;

    // Operand and output descriptors
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              storage.load_block_m, layout.block_k,
                                              stride_a, num_groups,
                                              storage.swizzle_a_mode);
    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              storage.load_block_n, layout.block_k,
                                              stride_b, num_groups,
                                              storage.swizzle_b_mode);
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                storage.store_block_m, storage.store_block_n,
                                                stride_d, num_groups,
                                                storage.swizzle_cd_mode);

    // Scaling-factor descriptors
    const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                                 layout.block_m, gran_k_a, num_groups, 0);
    const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                                 layout.block_n, gran_k_b, num_groups, 0);

    // Launch kernel
    const auto launch_args = LaunchArgs(launch_cfg.num_sms, launch_cfg.num_threads,
                                        gemm_config.pipeline_config.smem_size,
                                        layout.get_cluster_size());
    const SM100FP8FP4Gemm1D1DRuntime::Args args = {
        .gemm_desc = gemm_desc,
        .gemm_config = gemm_config,
        .launch_args = launch_args,
        .epilogue_type = std::nullopt,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_cd = tensor_map_cd
    };
    const auto kernel_source = SM100FP8FP4Gemm1D1DRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", kernel_source);
    SM100FP8FP4Gemm1D1DRuntime::launch(kernel_runtime, args);
}
// Host entry point for the K-grouped contiguous SM100 FP8 1D1D GEMM.
//
// `ks` holds the per-group K sizes on the host, `ks_tensor.data_ptr()` is handed
// to the kernel as the grouped-layout pointer. The same `gran_k` is used for
// both A and B scaling factors.
//
// Requirements (all asserted):
//   - both `major_a` and `major_b` are `cute::UMMA::Major::MN`;
//   - `gran_k` is 32 or 128;
//   - `ks` is non-empty;
//   - every entry of `ks` is a multiple of `gran_k`.
//
// `c` is only inspected for presence: when it holds a value the descriptor is
// marked `with_accumulation`.
static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                          const torch::Tensor& b, const torch::Tensor& sfb,
                                          const std::optional<torch::Tensor>& c,
                                          const torch::Tensor& d,
                                          const int& m, const int& n,
                                          const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                          const int& gran_k,
                                          const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                          const std::string& compiled_dims) {
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
    DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);

    // `std::max_element` returns `end()` on an empty range and dereferencing it
    // below would be undefined behaviour, so reject an empty group list up front
    DG_HOST_ASSERT(not ks.empty());

    const int gran_k_a = gran_k;
    const int gran_k_b = gran_k;

    // Accumulate the total K and the total scaling-factor K, where each group
    // contributes `ceil_div(k, gran_k * 4)` scaling-factor units
    int sum_k = 0, sum_sf_k = 0;
    for (const auto k: ks) {
        sum_k += k, sum_sf_k += ceil_div(k, gran_k * 4);
        DG_HOST_ASSERT(k % gran_k == 0);
    }
    const auto num_groups = static_cast<int>(ks.size());

    // Get config using max K for better performance
    const auto max_k = *std::max_element(ks.begin(), ks.end());
    const auto desc = GemmDesc {
        .gemm_type = GemmType::KGroupedContiguous,
        .kernel_type = KernelType::Kernel1D1D,
        .m = m, .n = n, .k = sum_k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims,
        .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
    };
    const auto config = get_best_config<SM100ArchSpec>(desc);

    // Create tensor descriptors
    // NOTES: A and B span the summed K extent with a group count of 1,
    // while the output is described with `num_groups` groups
    const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
                                              config.storage_config.load_block_m,
                                              config.layout.block_k,
                                              static_cast<int>(a.stride(0)), 1,
                                              config.storage_config.swizzle_a_mode);
    const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
                                              config.storage_config.load_block_n,
                                              config.layout.block_k,
                                              static_cast<int>(b.stride(0)), 1,
                                              config.storage_config.swizzle_b_mode);
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                config.storage_config.store_block_m,
                                                config.storage_config.store_block_n,
                                                static_cast<int>(d.stride(1)), num_groups,
                                                config.storage_config.swizzle_cd_mode);
    // The scaling-factor K extent is the per-group rounded-up total scaled back by `gran_k * 4`
    const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * gran_k_a * 4,
                                                 config.layout.block_m, gran_k_a, 1, 0);
    const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * gran_k_b * 4,
                                                 config.layout.block_n, gran_k_b, 1, 0);

    // Launch kernel
    const SM100FP8FP4Gemm1D1DRuntime::Args args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                  config.pipeline_config.smem_size,
                                  config.layout.get_cluster_size()),
        .epilogue_type = std::nullopt,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_cd = tensor_map_cd
    };
    const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
    const auto runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code);
    SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
// Host-side launcher for the batched SM100 FP8 GEMM (1D1D kernel).
//
// Steps: describe the problem, ask the heuristics for a config, build 3D TMA descriptors
// for A/B/D (the third dimension is the batch), build the scaling-factor descriptors,
// then JIT-build and launch `SM100FP8FP4Gemm1D1DRuntime`.
//
// `c` is only inspected for presence (it switches on accumulation); every descriptor is
// built from `a`, `b`, `d`, `sfa` and `sfb`.
static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
                          const torch::Tensor& b, const torch::Tensor& sfb,
                          const std::optional<torch::Tensor>& c,
                          const torch::Tensor& d,
                          const int& batch_size, const int& m, const int& n, const int& k,
                          const int& gran_k_a, const int& gran_k_b,
                          const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                          const std::string& compiled_dims) {
    // Problem description for the heuristics; the batch count travels as `num_groups`
    const auto problem = GemmDesc {
        .gemm_type = GemmType::Batched,
        .kernel_type = KernelType::Kernel1D1D,
        .m = m, .n = n, .k = k, .num_groups = batch_size,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims
    };
    const auto cfg = get_best_config<SM100ArchSpec>(problem);
    const auto& storage = cfg.storage_config;
    const auto& tile = cfg.layout;

    // Operand A: (inner, outer, batch); which of K/M is "inner" depends on the major
    const int a_load_rows = storage.load_block_m;
    const int a_outer_axis = major_a == cute::UMMA::Major::K ? 1 : 2;
    const auto [a_inner_extent, a_outer_extent] = get_inner_outer_dims(major_a, k, m);
    const auto [a_inner_tile, a_outer_tile] = get_inner_outer_dims(major_a, tile.block_k, a_load_rows);
    const auto desc_a = make_tma_3d_desc(a, a_inner_extent, a_outer_extent, batch_size,
                                         a_inner_tile, a_outer_tile, 1,
                                         a.stride(a_outer_axis),
                                         a.stride(0),
                                         storage.swizzle_a_mode);

    // Operand B: same construction with N in place of M
    const int b_load_rows = storage.load_block_n;
    const int b_outer_axis = major_b == cute::UMMA::Major::K ? 1 : 2;
    const auto [b_inner_extent, b_outer_extent] = get_inner_outer_dims(major_b, k, n);
    const auto [b_inner_tile, b_outer_tile] = get_inner_outer_dims(major_b, tile.block_k, b_load_rows);
    const auto desc_b = make_tma_3d_desc(b, b_inner_extent, b_outer_extent, batch_size,
                                         b_inner_tile, b_outer_tile, 1,
                                         b.stride(b_outer_axis),
                                         b.stride(0),
                                         storage.swizzle_b_mode);

    // Output D: N is the inner dimension, M the outer one
    const int d_store_rows = storage.store_block_m;
    const int d_store_cols = storage.store_block_n;
    const auto desc_cd = make_tma_3d_desc(d, n, m, batch_size,
                                          d_store_cols, d_store_rows, 1,
                                          d.stride(1), d.stride(0),
                                          storage.swizzle_cd_mode);

    // Scaling factors for both operands, one set per batch entry
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           tile.block_m, gran_k_a, batch_size, 0);
    const auto desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                           tile.block_n, gran_k_b, batch_size, 0);

    // Assemble launch arguments (designated initializers follow the declaration order of `Args`)
    const SM100FP8FP4Gemm1D1DRuntime::Args launch_params = {
        .gemm_desc = problem,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.launch_config.num_sms, cfg.launch_config.num_threads,
                                  cfg.pipeline_config.smem_size,
                                  tile.get_cluster_size()),
        .epilogue_type = std::nullopt,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .grouped_layout = nullptr,
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_cd = desc_cd
    };

    // Generate the instantiation source, build it and run it
    const auto source = SM100FP8FP4Gemm1D1DRuntime::generate(launch_params);
    const auto built = compiler->build("sm100_fp8_gemm_1d1d", source);
    SM100FP8FP4Gemm1D1DRuntime::launch(built, launch_params);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,220 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "runtime_utils.hpp"
#include <deep_gemm/layout/mega_moe.cuh>
#include <deep_gemm/layout/sym_buffer.cuh>
#include "../heuristics/mega_moe.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `sm100_fp8_fp4_mega_moe_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM100FP8FP4MegaMoERuntime final : public LaunchRuntime<SM100FP8FP4MegaMoERuntime> {
public:
// Everything needed to generate and launch one kernel specialization.
struct Args {
// Templated arguments
// (formatted into the generated source by `generate_impl`, never passed at launch)
int num_max_tokens_per_rank;
int hidden, intermediate_hidden;
int num_experts, num_topk;
int num_ranks;
float activation_clamp;
bool fast_math;
MegaMoEConfig config;
// Runtime arguments
// (passed to the kernel by `launch_impl`, in this order)
void* y;
// May be null: the launcher below passes `nullptr` when no stats tensor is supplied
int* cumulative_local_expert_recv_stats;
int num_tokens;
layout::SymBuffer<> sym_buffer_ptrs;
// Tensormap
CUtensorMap tensor_map_l1_acts;
CUtensorMap tensor_map_l1_acts_sf;
CUtensorMap tensor_map_l1_weights;
CUtensorMap tensor_map_l1_weights_sf;
CUtensorMap tensor_map_l1_output;
CUtensorMap tensor_map_l2_acts;
CUtensorMap tensor_map_l2_acts_sf;
CUtensorMap tensor_map_l2_weights;
CUtensorMap tensor_map_l2_weights_sf;
// Launch configs
// NOTE: `launch_args.grid_dim.first` is also formatted into the generated source
LaunchArgs launch_args;
};
// Returns the source text handed to the JIT compiler. It includes the kernel header and
// takes the address of one `sm100_fp8_fp4_mega_moe_impl<...>` specialization.
// The 22 `{}` placeholders are filled positionally by the arguments that follow the raw
// string, so their order has to match the kernel's template parameter list (declared in
// the included header, not visible here). `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_fp4_mega_moe_impl<
{},
{}, {},
{}, {},
{},
{}, {}, {},
{},
{}, {},
{},
{},
{},
{}, {}, {},
{}, {},
{},
{}
>);
}};
)", args.num_max_tokens_per_rank,
args.hidden, args.intermediate_hidden,
args.num_experts, args.num_topk,
args.config.num_experts_per_wave,
args.config.block_m, args.config.block_n, args.config.block_k,
args.config.store_block_m,
args.config.sf_block_m, args.config.sf_block_n,
args.config.num_max_pool_tokens,
args.config.num_padded_sf_pool_tokens,
args.config.num_stages,
args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads,
args.launch_args.grid_dim.first, args.num_ranks,
to_string(args.activation_clamp),
args.fast_math ? "true" : "false");
}
// Launches the built kernel with the runtime arguments; the value returned by
// `launch_kernel` is handed to `DG_CUDA_UNIFIED_CHECK`.
// `args` is taken by value, so the whole struct is copied per launch (see TODO).
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.y,
args.cumulative_local_expert_recv_stats,
args.num_tokens,
args.sym_buffer_ptrs,
args.tensor_map_l1_acts,
args.tensor_map_l1_acts_sf,
args.tensor_map_l1_weights,
args.tensor_map_l1_weights_sf,
args.tensor_map_l1_output,
args.tensor_map_l2_acts,
args.tensor_map_l2_acts_sf,
args.tensor_map_l2_weights,
args.tensor_map_l2_weights_sf
));
}
};
// Host-side launcher for the SM100 FP8/FP4 mega-MoE kernel.
//
// Derives the rank/expert counts, queries `get_mega_moe_config`, builds the nine TMA
// descriptors the kernel consumes (activations, weights and their scaling factors for
// both layers, plus the layer-1 output view), then JIT-builds and launches
// `SM100FP8FP4MegaMoERuntime`.
//
// `cumulative_local_expert_recv_stats` is optional; a null pointer is passed when absent.
static void sm100_fp8_fp4_mega_moe(
    const torch::Tensor& y,
    const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf,
    const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf,
    const torch::Tensor& l1_weights, const torch::Tensor& l2_weights,
    const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf,
    const std::optional<torch::Tensor> cumulative_local_expert_recv_stats,
    const std::vector<int64_t>& sym_buffer_ptrs,
    const int& rank_idx, const int& num_max_tokens_per_rank,
    const int& num_experts_per_rank,
    const int& num_tokens, const int& num_topk,
    const int& hidden, const int& intermediate_hidden,
    const float& activation_clamp,
    const bool& fast_math
) {
    // One symmetric-buffer pointer is supplied per rank
    const auto num_ranks = static_cast<int>(sym_buffer_ptrs.size());
    const auto num_experts = num_experts_per_rank * num_ranks;
    // Padded SF pool size is read off the leading dimension of the L1 activation SF tensor
    const auto num_padded_sf_pool_tokens = static_cast<int>(l1_acts_sf.size(0));

    // Pick tile sizes, stages, thread counts, etc.
    const auto moe_cfg = get_mega_moe_config(
        num_ranks, num_experts, num_experts_per_rank,
        num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens);

    // K-granularity used for every scaling-factor descriptor below
    constexpr int kGranK = 32;

    // ---- Layer 1 inputs ----
    const auto l1_acts_desc = make_tma_2d_desc(
        l1_acts,
        hidden, moe_cfg.num_max_pool_tokens,
        moe_cfg.block_k, moe_cfg.load_block_m,
        static_cast<int>(l1_acts.stride(-2)),
        moe_cfg.swizzle_acts_mode);
    const auto l1_acts_sf_desc = make_tma_sf_desc(
        cute::UMMA::Major::MN, l1_acts_sf,
        moe_cfg.num_padded_sf_pool_tokens, hidden,
        moe_cfg.sf_block_m, kGranK,
        1, 0);
    const auto l1_weights_desc = make_tma_2d_desc(
        l1_weights,
        hidden, num_experts_per_rank * intermediate_hidden * 2,
        moe_cfg.block_k, moe_cfg.load_block_n,
        static_cast<int>(l1_weights.stride(-2)),
        moe_cfg.swizzle_weights_mode);
    const auto l1_weights_sf_desc = make_tma_sf_desc(
        cute::UMMA::Major::MN, l1_weights_sf,
        intermediate_hidden * 2, hidden,
        moe_cfg.block_n, kGranK,
        num_experts_per_rank, 0);

    // ---- Layer 1 output / layer 2 activations ----
    // NOTES: both descriptors view the same tensor (`l2_acts`).
    // The post-SwiGLU output is half as wide in N (`BLOCK_N / 2` per input tile),
    // hence the store box and the swizzle mode are halved as well (128 -> 64).
    const auto l1_output_desc = make_tma_2d_desc(
        l2_acts,
        intermediate_hidden, moe_cfg.num_max_pool_tokens,
        moe_cfg.block_n / 2, moe_cfg.store_block_m,
        static_cast<int>(l2_acts.stride(-2)),
        moe_cfg.swizzle_acts_mode / 2);
    const auto l2_acts_desc = make_tma_2d_desc(
        l2_acts,
        intermediate_hidden, moe_cfg.num_max_pool_tokens,
        moe_cfg.block_k, moe_cfg.load_block_m,
        static_cast<int>(l2_acts.stride(-2)),
        moe_cfg.swizzle_acts_mode);
    const auto l2_acts_sf_desc = make_tma_sf_desc(
        cute::UMMA::Major::MN, l2_acts_sf,
        moe_cfg.num_padded_sf_pool_tokens, intermediate_hidden,
        moe_cfg.sf_block_m, kGranK,
        1, 0);

    // ---- Layer 2 weights ----
    const auto l2_weights_desc = make_tma_2d_desc(
        l2_weights,
        intermediate_hidden, num_experts_per_rank * hidden,
        moe_cfg.block_k, moe_cfg.load_block_n,
        static_cast<int>(l2_weights.stride(-2)),
        moe_cfg.swizzle_weights_mode);
    const auto l2_weights_sf_desc = make_tma_sf_desc(
        cute::UMMA::Major::MN, l2_weights_sf,
        hidden, intermediate_hidden,
        moe_cfg.block_n, kGranK,
        num_experts_per_rank, 0);

    // The stats tensor is optional: fall back to a null pointer
    int* const recv_stats_ptr = cumulative_local_expert_recv_stats.has_value()
        ? cumulative_local_expert_recv_stats->data_ptr<int>()
        : nullptr;

    // Launch geometry: one block per SM, all three thread roles summed per block
    const auto num_sms = device_runtime->get_num_sms();
    const auto threads_per_block =
        moe_cfg.num_dispatch_threads + moe_cfg.num_non_epilogue_threads + moe_cfg.num_epilogue_threads;

    // Designated initializers follow the declaration order of `Args`
    const SM100FP8FP4MegaMoERuntime::Args launch_params = {
        .num_max_tokens_per_rank = num_max_tokens_per_rank,
        .hidden = hidden, .intermediate_hidden = intermediate_hidden,
        .num_experts = num_experts, .num_topk = num_topk,
        .num_ranks = num_ranks,
        .activation_clamp = activation_clamp,
        .fast_math = fast_math,
        .config = moe_cfg,
        .y = y.data_ptr(),
        .cumulative_local_expert_recv_stats = recv_stats_ptr,
        .num_tokens = num_tokens,
        .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx),
        .tensor_map_l1_acts = l1_acts_desc,
        .tensor_map_l1_acts_sf = l1_acts_sf_desc,
        .tensor_map_l1_weights = l1_weights_desc,
        .tensor_map_l1_weights_sf = l1_weights_sf_desc,
        .tensor_map_l1_output = l1_output_desc,
        .tensor_map_l2_acts = l2_acts_desc,
        .tensor_map_l2_acts_sf = l2_acts_sf_desc,
        .tensor_map_l2_weights = l2_weights_desc,
        .tensor_map_l2_weights_sf = l2_weights_sf_desc,
        .launch_args = LaunchArgs(num_sms,
                                  threads_per_block,
                                  moe_cfg.smem_size, 2)
    };

    // Generate, build and run
    const auto source = SM100FP8FP4MegaMoERuntime::generate(launch_params);
    const auto built = compiler->build("sm100_fp8_fp4_mega_moe", source);
    SM100FP8FP4MegaMoERuntime::launch(built, launch_params);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,416 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `sm100_fp8_gemm_1d1d_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime<SM100FP8FP4Gemm1D1DRuntime> {
public:
// Everything needed to generate and launch one kernel specialization.
struct Args {
// Problem shape; `m`, `n`, `k` are used both for code generation and at launch
int m, n, k, num_groups;
int gran_k_a, gran_k_b;
// NOTE: these two members are references, not copies: the referenced objects must
// stay alive for as long as this `Args` (and any copy of it) is in use
const std::string& compiled_dims;
const std::optional<std::string>& epilogue_type;
GemmConfig gemm_config;
LaunchArgs launch_args;
// Passed to the kernel as its first argument; the launchers in this file pass
// `nullptr` for the non-grouped and batched cases
void* grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_cd;
};
// Returns the source text handed to the JIT compiler. It includes the kernel header and
// takes the address of one `sm100_fp8_gemm_1d1d_impl<...>` specialization.
// The 26 `{}` placeholders are filled positionally by the arguments that follow the raw
// string, so their order has to match the kernel's template parameter list (declared in
// the included header, not visible here). `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp8_gemm_1d1d_impl<
{}, {},
{}, {},
{}, {}, {},
{}, {}, {},
{},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}, {}, {},
{}
>);
}};
)",
to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b),
args.gran_k_a, args.gran_k_b,
// Each dimension goes through `get_compiled_dim` together with `compiled_dims`
get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims),
args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k,
args.num_groups,
args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode,
args.gemm_config.num_stages,
args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads,
args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a,
args.gemm_config.num_sms,
to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation,
to_string(args.gemm_config.a_dtype), to_string(args.gemm_config.b_dtype), to_string(args.gemm_config.cd_dtype),
get_default_epilogue_type(args.epilogue_type));
}
// Launches the built kernel with the runtime arguments; the value returned by
// `launch_kernel` is handed to `DG_CUDA_UNIFIED_CHECK`.
// `args` is taken by value, so the whole struct is copied per launch (see TODO).
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout, args.m, args.n, args.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_cd));
}
};
// Host-side launcher for a single (non-grouped) SM100 FP8/FP4 GEMM with 1D1D scaling.
//
// Queries the heuristics for a config, builds the TMA descriptors for A, B, D and both
// scaling-factor tensors, then JIT-builds and launches `SM100FP8FP4Gemm1D1DRuntime`.
//
// Parameters:
//   a, b           operand tensors; their layout is selected by `major_a` / `major_b`
//   sfa, sfb       scaling factors for A and B, with K-granularity `gran_k_a` / `gran_k_b`
//   c              only inspected for presence: it switches on accumulation in the config.
//                  All descriptors are built from `d`, never from `c`.
//   d              output tensor
//   m, n, k        problem shape
//   compiled_dims  forwarded to code generation (stored by reference in `Args`)
//   epilogue_type  optional epilogue selector, forwarded to code generation
static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                    const torch::Tensor& b, const torch::Tensor& sfb,
                                    const std::optional<torch::Tensor>& c,
                                    const torch::Tensor& d,
                                    const int& m, const int& n, const int& k,
                                    const int& gran_k_a, const int& gran_k_b,
                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                    const std::string& compiled_dims,
                                    const std::optional<std::string>& epilogue_type = std::nullopt) {
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::Normal, KernelType::Kernel1D1D,
        m, n, k, 1, major_a, major_b,
        a.scalar_type(), b.scalar_type(),
        d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // NOTE: the previous `const auto& cd = c.value_or(d);` was never read and has been
    // dropped; it only produced an unused tensor-handle copy.

    // Create tensor descriptors
    const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                               SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                               config.block_k,
                                               static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
                                               config.smem_config.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                               SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                               config.block_k,
                                               static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), 1,
                                               config.smem_config.swizzle_b_mode);
    // The output width is taken from `d.size(-1)` rather than from `n`
    const auto& tensor_map_cd = make_tma_cd_desc(d, m, static_cast<int>(d.size(-1)),
                                                 SM100ArchSpec::get_cd_store_block_m(config.block_m),
                                                 SM100ArchSpec::get_cd_store_block_n(config.block_n),
                                                 static_cast<int>(d.stride(-2)), 1,
                                                 config.smem_config.swizzle_cd_mode);
    const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                                  config.block_m, gran_k_a, 1, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                                  config.block_n, gran_k_b, 1, 0);

    // Launch
    // `compiled_dims` and `epilogue_type` are reference members of `Args`; they bind to
    // this function's parameters, which outlive `args`
    const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .num_groups = 1,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .compiled_dims = compiled_dims,
        .epilogue_type = epilogue_type,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_cd = tensor_map_cd
    };
    const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code);
    SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
// Host-side launcher for the M-grouped (contiguous layout) SM100 FP8/FP4 GEMM.
//
// A and D are described as single tensors (group count 1), while B and its scaling
// factors carry `num_groups`. `grouped_layout` is forwarded to the kernel as a raw pointer.
// When `use_psum_layout` is set the psum variant of the GEMM type is selected.
static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                         const torch::Tensor& b, const torch::Tensor& sfb,
                                                         const torch::Tensor& d,
                                                         const torch::Tensor& grouped_layout,
                                                         const int& num_groups, const int& m, const int& n, const int& k,
                                                         const int& gran_k_a, const int& gran_k_b,
                                                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                         const std::string& compiled_dims,
                                                         const bool& use_psum_layout,
                                                         const std::optional<int>& expected_m_for_psum_layout) {
    const auto kind = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout
                                      : GemmType::MGroupedContiguous;

    // NOTES: with a dynamic M, the heuristics see `expected_m` per group and the real group count;
    // otherwise the whole contiguous layout is treated as one problem of height `m`.
    const bool has_expected_m = expected_m_for_psum_layout.has_value();
    const int heuristic_m = expected_m_for_psum_layout.value_or(m);
    const int heuristic_groups = has_expected_m ? num_groups : 1;
    const auto& cfg = get_best_config<SM100ArchSpec>(
        kind, KernelType::Kernel1D1D,
        heuristic_m, n, k, heuristic_groups, major_a, major_b,
        a.scalar_type(), b.scalar_type(),
        d.scalar_type(), false,
        device_runtime->get_num_sms());

    // Box sizes used by the descriptors
    const auto load_rows_a = SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m);
    const auto load_rows_b = SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n);
    const auto store_rows = SM100ArchSpec::get_cd_store_block_m(cfg.block_m);
    const auto store_cols = SM100ArchSpec::get_cd_store_block_n(cfg.block_n);

    // Operand / output descriptors
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        load_rows_a,
                                        cfg.block_k,
                                        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), 1,
                                        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        load_rows_b,
                                        cfg.block_k,
                                        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
                                        cfg.smem_config.swizzle_b_mode);
    const auto desc_cd = make_tma_cd_desc(d, m, n,
                                          store_rows,
                                          store_cols,
                                          static_cast<int>(d.stride(-2)), 1,
                                          cfg.smem_config.swizzle_cd_mode);

    // Scaling-factor descriptors
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, gran_k_a, 1, 0);
    const auto desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                           cfg.block_n, gran_k_b, num_groups, 0);

    // Assemble launch arguments (designated initializers follow the declaration order of `Args`)
    const SM100FP8FP4Gemm1D1DRuntime::Args& launch_params = {
        .m = m, .n = n, .k = k,
        .num_groups = num_groups,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .compiled_dims = compiled_dims,
        .epilogue_type = std::nullopt,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  cfg.smem_config.smem_size,
                                  cfg.multicast_config.num_multicast),
        .grouped_layout = grouped_layout.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_cd = desc_cd
    };

    // Generate, build and run
    const auto source = SM100FP8FP4Gemm1D1DRuntime::generate(launch_params);
    const auto built = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", source);
    SM100FP8FP4Gemm1D1DRuntime::launch(built, launch_params);
}
// Host-side launcher for the M-grouped (masked layout) SM100 FP8/FP4 GEMM.
//
// Every descriptor (A, B, D and both scaling-factor tensors) carries `num_groups`.
// The heuristics are queried with `expected_m` instead of `m`, while the descriptors use
// the real `m`. `masked_m` is forwarded to the kernel as its grouped-layout pointer.
static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                     const torch::Tensor& b, const torch::Tensor& sfb,
                                                     const torch::Tensor& d,
                                                     const torch::Tensor& masked_m,
                                                     const int& num_groups, const int& m, const int& n, const int& k,
                                                     const int& expected_m,
                                                     const int& gran_k_a, const int& gran_k_b,
                                                     const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                     const std::string& compiled_dims) {
    const auto& cfg = get_best_config<SM100ArchSpec>(
        GemmType::MGroupedMasked, KernelType::Kernel1D1D,
        expected_m, n, k, num_groups, major_a, major_b,
        a.scalar_type(), b.scalar_type(),
        d.scalar_type(), false,
        device_runtime->get_num_sms());

    // Box sizes used by the descriptors
    const auto load_rows_a = SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m);
    const auto load_rows_b = SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n);
    const auto store_rows = SM100ArchSpec::get_cd_store_block_m(cfg.block_m);
    const auto store_cols = SM100ArchSpec::get_cd_store_block_n(cfg.block_n);

    // Operand / output descriptors, all grouped
    const auto desc_a = make_tma_a_desc(major_a, a, m, k,
                                        load_rows_a,
                                        cfg.block_k,
                                        static_cast<int>(a.stride(get_non_contiguous_dim(major_a))), num_groups,
                                        cfg.smem_config.swizzle_a_mode);
    const auto desc_b = make_tma_b_desc(major_b, b, n, k,
                                        load_rows_b,
                                        cfg.block_k,
                                        static_cast<int>(b.stride(get_non_contiguous_dim(major_b))), num_groups,
                                        cfg.smem_config.swizzle_b_mode);
    const auto desc_cd = make_tma_cd_desc(d, m, n,
                                          store_rows,
                                          store_cols,
                                          static_cast<int>(d.stride(-2)), num_groups,
                                          cfg.smem_config.swizzle_cd_mode);

    // Scaling-factor descriptors, also grouped
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, gran_k_a, num_groups, 0);
    const auto desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                           cfg.block_n, gran_k_b, num_groups, 0);

    // Assemble launch arguments (designated initializers follow the declaration order of `Args`)
    const SM100FP8FP4Gemm1D1DRuntime::Args& launch_params = {
        .m = m, .n = n, .k = k,
        .num_groups = num_groups,
        .gran_k_a = gran_k_a,
        .gran_k_b = gran_k_b,
        .compiled_dims = compiled_dims,
        .epilogue_type = std::nullopt,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  cfg.smem_config.smem_size,
                                  cfg.multicast_config.num_multicast),
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_cd = desc_cd
    };

    // Generate, build and run
    const auto source = SM100FP8FP4Gemm1D1DRuntime::generate(launch_params);
    const auto built = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", source);
    SM100FP8FP4Gemm1D1DRuntime::launch(built, launch_params);
}
// Host-side launcher for the K-grouped (contiguous) SM100 FP8 GEMM with 1D1D scaling.
//
// Parameters:
//   a, b        operand tensors; both must be MN-major (asserted below)
//   sfa, sfb    scaling factors for A and B
//   c           only inspected for presence: it switches on accumulation in the config
//   d           output tensor; its descriptor carries one group per entry of `ks`
//   m, n        output shape per group
//   ks          host-side list of per-group K sizes; must be non-empty and each entry a
//               multiple of 128 (both asserted)
//   ks_tensor   forwarded to the kernel as its grouped-layout pointer
//               (presumably the device-side copy of `ks` — TODO confirm against the caller)
//   compiled_dims  forwarded to code generation
static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                          const torch::Tensor& b, const torch::Tensor& sfb,
                                          const std::optional<torch::Tensor>& c,
                                          const torch::Tensor& d,
                                          const int& m, const int& n,
                                          const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                          const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                          const std::string& compiled_dims) {
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);
    // An empty list would make `std::max_element` return `ks.end()`, and dereferencing
    // that below is undefined behaviour; fail loudly instead
    DG_HOST_ASSERT(not ks.empty());

    // Total K over all groups, and total number of 512-wide scaling-factor chunks
    // (each group is rounded up to a multiple of 512 separately)
    int sum_k = 0, sum_sf_k = 0;
    for (const auto& k: ks) {
        sum_k += k, sum_sf_k += ceil_div(k, 512);
        DG_HOST_ASSERT(k % 128 == 0);
    }
    const auto& num_groups = static_cast<int>(ks.size());

    // Get config using max K for better performance
    const auto& max_k = *std::max_element(ks.begin(), ks.end());
    const auto& config = get_best_config<SM100ArchSpec>(
        GemmType::KGroupedContiguous, KernelType::Kernel1D1D,
        m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN,
        a.scalar_type(), b.scalar_type(),
        d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Create tensor descriptors
    // A and B are described over the concatenated K extent (`sum_k`) as single tensors
    const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
                                               SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m),
                                               config.block_k,
                                               static_cast<int>(a.stride(0)), 1,
                                               config.smem_config.swizzle_a_mode);
    const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
                                               SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n),
                                               config.block_k,
                                               static_cast<int>(b.stride(0)), 1,
                                               config.smem_config.swizzle_b_mode);
    const auto& tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                 SM100ArchSpec::get_cd_store_block_m(config.block_m),
                                                 SM100ArchSpec::get_cd_store_block_n(config.block_n),
                                                 static_cast<int>(d.stride(1)), num_groups,
                                                 config.smem_config.swizzle_cd_mode);
    // Scaling factors span the padded K extent (`sum_sf_k * 512`)
    const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512,
                                                  config.block_m, config.block_k, 1, 0);
    const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512,
                                                  config.block_n, config.block_k, 1, 0);

    // Launch kernel
    const SM100FP8FP4Gemm1D1DRuntime::Args& args = {
        .m = m, .n = n, .k = sum_k,
        .num_groups = num_groups,
        .gran_k_a = 128,
        .gran_k_b = 128,
        .compiled_dims = compiled_dims,
        .epilogue_type = std::nullopt,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads,
                                  config.smem_config.smem_size,
                                  config.multicast_config.num_multicast),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_sfa = tensor_map_sfa,
        .tensor_map_sfb = tensor_map_sfb,
        .tensor_map_cd = tensor_map_cd
    };
    const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args);
    const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code);
    SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args);
}
// Host-side launcher for the batched SM100 FP8 GEMM (1D1D kernel).
//
// Builds 3D TMA descriptors for A, B and D (third dimension = batch), the scaling-factor
// descriptors, then JIT-builds and launches `SM100FP8FP4Gemm1D1DRuntime`.
// `c` is only inspected for presence (it switches on accumulation in the config).
// The scaling-factor granularity passed to the kernel is fixed at 128 for both operands.
static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
                          const torch::Tensor& b, const torch::Tensor& sfb,
                          const std::optional<torch::Tensor>& c,
                          const torch::Tensor& d,
                          const int& batch_size, const int& m, const int& n, const int& k,
                          const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                          const std::string& compiled_dims) {
    // The batch count is given to the heuristics as the group count
    const auto& cfg = get_best_config<SM100ArchSpec>(
        GemmType::Batched, KernelType::Kernel1D1D,
        m, n, k, batch_size, major_a, major_b,
        a.scalar_type(), b.scalar_type(),
        d.scalar_type(), c.has_value(),
        device_runtime->get_num_sms());

    // Operand A: (inner, outer, batch); which of K/M is "inner" depends on the major
    const int a_load_rows = SM100ArchSpec::get_ab_load_block_m(cfg.multicast_config, cfg.block_m);
    const int a_outer_axis = major_a == cute::UMMA::Major::K ? 1 : 2;
    const auto [a_inner_extent, a_outer_extent] = get_inner_outer_dims(major_a, k, m);
    const auto [a_inner_tile, a_outer_tile] = get_inner_outer_dims(major_a, cfg.block_k, a_load_rows);
    const auto desc_a = make_tma_3d_desc(a, a_inner_extent, a_outer_extent, batch_size,
                                         a_inner_tile, a_outer_tile, 1,
                                         a.stride(a_outer_axis),
                                         a.stride(0),
                                         cfg.smem_config.swizzle_a_mode);

    // Operand B: same construction with N in place of M
    const int b_load_rows = SM100ArchSpec::get_ab_load_block_n(cfg.multicast_config, cfg.block_n);
    const int b_outer_axis = major_b == cute::UMMA::Major::K ? 1 : 2;
    const auto [b_inner_extent, b_outer_extent] = get_inner_outer_dims(major_b, k, n);
    const auto [b_inner_tile, b_outer_tile] = get_inner_outer_dims(major_b, cfg.block_k, b_load_rows);
    const auto desc_b = make_tma_3d_desc(b, b_inner_extent, b_outer_extent, batch_size,
                                         b_inner_tile, b_outer_tile, 1,
                                         b.stride(b_outer_axis),
                                         b.stride(0),
                                         cfg.smem_config.swizzle_b_mode);

    // Output D: N is the inner dimension, M the outer one
    const int d_store_rows = SM100ArchSpec::get_cd_store_block_m(cfg.block_m);
    const int d_store_cols = SM100ArchSpec::get_cd_store_block_n(cfg.block_n);
    const auto desc_cd = make_tma_3d_desc(d, n, m, batch_size,
                                          d_store_cols, d_store_rows, 1,
                                          d.stride(1), d.stride(0),
                                          cfg.smem_config.swizzle_cd_mode);

    // Scaling factors: the K-granularity argument here is the config's `block_k`
    const auto desc_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                           cfg.block_m, cfg.block_k, batch_size, 0);
    const auto desc_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                           cfg.block_n, cfg.block_k, batch_size, 0);

    // Assemble launch arguments (designated initializers follow the declaration order of `Args`)
    const SM100FP8FP4Gemm1D1DRuntime::Args& launch_params = {
        .m = m, .n = n, .k = k,
        .num_groups = batch_size,
        .gran_k_a = 128,
        .gran_k_b = 128,
        .compiled_dims = compiled_dims,
        .epilogue_type = std::nullopt,
        .gemm_config = cfg,
        .launch_args = LaunchArgs(cfg.num_sms, cfg.thread_config.num_threads,
                                  cfg.smem_config.smem_size,
                                  cfg.multicast_config.num_multicast),
        .grouped_layout = nullptr,
        .tensor_map_a = desc_a,
        .tensor_map_b = desc_b,
        .tensor_map_sfa = desc_sfa,
        .tensor_map_sfb = desc_sfb,
        .tensor_map_cd = desc_cd
    };

    // Generate, build and run
    const auto source = SM100FP8FP4Gemm1D1DRuntime::generate(launch_params);
    const auto built = compiler->build("sm100_fp8_gemm_1d1d", source);
    SM100FP8FP4Gemm1D1DRuntime::launch(built, launch_params);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,149 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `sm100_tf32_hc_prenorm_gemm_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM100BF16HCPrenormGemmRuntime final: public LaunchRuntime<SM100BF16HCPrenormGemmRuntime> {
public:
// Everything needed to generate and launch one kernel specialization.
struct Args {
// `n` and `k` are formatted into the generated source; `m` is only passed at launch
int m, n, k;
int block_m, block_n, block_k;
int num_splits;
int swizzle_cd_mode;
int num_stages;
int num_mma_threads, num_cast_and_reduce_threads;
LaunchArgs launch_args;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
// Passed to the kernel as its last argument
float* sqr_sum;
};
// Returns the source text handed to the JIT compiler. It includes the kernel header and
// takes the address of one `sm100_tf32_hc_prenorm_gemm_impl<...>` specialization.
// The 10 `{}` placeholders are filled positionally by the arguments that follow the raw
// string, so their order has to match the kernel's template parameter list (declared in
// the included header, not visible here). `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_tf32_hc_prenorm_gemm_impl<
{}, {},
{}, {}, {},
{},
{},
{},
{}, {}
>);
}};
)",
args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.num_splits,
args.swizzle_cd_mode,
args.num_stages,
args.num_mma_threads, args.num_cast_and_reduce_threads);
}
// Launches the built kernel with the runtime arguments; the value returned by
// `launch_kernel` is handed to `DG_CUDA_UNIFIED_CHECK`.
// `args` is taken by value, so the whole struct is copied per launch (see TODO).
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum));
}
};
// Host-side launcher for the SM100 TF32 "HC prenorm" GEMM.
//
// Uses fixed 64x64 (M x K) tiles and 128 + 128 threads; `block_n` is `n` rounded up to a
// multiple of 16. Picks the largest stage count (starting at 12) whose shared-memory
// footprint fits `SM100ArchSpec::smem_capacity`, then JIT-builds and launches the kernel
// on `num_splits * ceil_div(m, block_m)` blocks.
//
// Parameters:
//   a, b        operand tensors, both described as K-major
//   d           output tensor; described with a 2D descriptor when `num_splits == 1`
//               and with a 3D descriptor (third dimension = `num_splits`) otherwise
//   sqr_sum     float buffer forwarded to the kernel
//   m, n, k     problem shape; requires `n <= 128`, `n % 8 == 0` and `k % 64 == 0`
//   num_splits  number of splits (labelled "split K" in the debug output)
static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a,
                                       const torch::Tensor& b,
                                       const torch::Tensor& d,
                                       const torch::Tensor& sqr_sum,
                                       const int& m, const int& n, const int& k,
                                       const int& num_splits) {
    constexpr int block_m = 64;
    constexpr int block_k = 64;
    constexpr int num_mma_threads = 128;
    constexpr int num_cast_and_reduce_threads = 128;

    // A single N tile has to cover the whole of `n`
    const int block_n = align(n, 16);
    DG_HOST_ASSERT(n <= block_n);
    DG_HOST_ASSERT(n <= 128 and n % 8 == 0);
    DG_HOST_ASSERT(k % block_k == 0);

    // Create tensor descriptors
    const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
    const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
                                              block_m, block_k,
                                              static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                              get_swizzle_mode(block_k, a.element_size()), 0,
                                              true);
    const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
                                              block_n, block_k,
                                              static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                              get_swizzle_mode(block_k, b.element_size()), 0,
                                              true);
    // Without splitting, D is a plain 2D output; with splitting, each split gets its own slice
    const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
                                                                 block_m, block_n,
                                                                 static_cast<int>(d.stride(-2)), 1,
                                                                 swizzle_cd_mode)
                                              : make_tma_3d_desc(d, n, m, num_splits,
                                                                 block_n, block_m, 1,
                                                                 static_cast<int>(d.stride(-2)),
                                                                 static_cast<int>(d.stride(-3)),
                                                                 swizzle_cd_mode);

    // Calculate stages
    // Start from 12 and step down until the footprint fits into shared memory
    int num_stages = 12, smem_size = 0;
    while (num_stages > 0) {
        const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(nv_bfloat16));
        const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(float));
        const int smem_cd = block_m * swizzle_cd_mode;
        const int smem_barriers = (num_stages * 4 + 1) * 8;
        const int smem_tmem_ptr = 4;
        smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages +
                    smem_cd + smem_barriers + smem_tmem_ptr;
        if (smem_size <= SM100ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    // NOTE: the adjacent string literals are concatenated by the compiler, so each piece
    // must carry its own trailing separator (the one after "split K: %d" used to be
    // missing, which glued it to "stages:")
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split K: %d, "
               "stages: %d, shared memory: %d, swizzle CD: %d\n",
               m, n, k, block_m, block_n, block_k, num_splits,
               num_stages, smem_size, swizzle_cd_mode);
    }

    // Launch
    const SM100BF16HCPrenormGemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .num_splits = num_splits,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_mma_threads = num_mma_threads,
        .num_cast_and_reduce_threads = num_cast_and_reduce_threads,
        .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
        .sqr_sum = sqr_sum.data_ptr<float>()
    };
    const auto code = SM100BF16HCPrenormGemmRuntime::generate(args);
    const auto runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code);
    SM100BF16HCPrenormGemmRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,432 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `sm90_bf16_gemm_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template and
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM90BF16GemmRuntime final: public LaunchRuntime<SM90BF16GemmRuntime> {
public:
// Everything needed to generate and launch one kernel specialization.
struct Args {
// Problem description (shape, majors, dtypes, GEMM type, compiled dims)
GemmDesc gemm_desc;
// Tile/storage/pipeline/launch configuration read by `generate_impl`
GemmConfig gemm_config;
LaunchArgs launch_args;
// Passed to the kernel as its first argument
void *grouped_layout;
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_cd;
};
// Returns the source text handed to the JIT compiler. It includes the kernel header and
// takes the address of one `sm90_bf16_gemm_impl<...>` specialization.
// The 21 `{}` placeholders are filled positionally by the arguments that follow the raw
// string, so their order has to match the kernel's template parameter list (declared in
// the included header, not visible here). `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_bf16_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_bf16_gemm_impl<
{}, {},
{}, {}, {},
{},
{}, {}, {},
{}, {}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)",
// TODO: add CD dtype
// NOTE(review): `cd_dtype` is already forwarded as the last argument below, so this
// TODO may be stale — confirm before removing
to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b),
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_desc.num_groups,
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_config.storage_config.swizzle_a_mode,
args.gemm_config.storage_config.swizzle_b_mode,
args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads,
// TODO: refactor with cluster M/N
// Currently passes the total cluster size plus a flag telling whether `cluster_n > 1`
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms,
to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation,
to_string(args.gemm_desc.cd_dtype));
}
// Launches the built kernel with the runtime arguments; the value returned by
// `launch_kernel` is handed to `DG_CUDA_UNIFIED_CHECK`.
// `args` is taken by value, so the whole struct is copied per launch (see TODO).
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.grouped_layout,
args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_cd));
}
};
// Host-side launcher for the plain (single-group) SM90 BF16 GEMM.
// Builds the GEMM descriptor, asks the SM90 heuristics for a configuration, creates the
// TMA descriptors for A, B and the output `d`, then JIT-builds and launches the kernel.
// `c` is only inspected through `has_value()`, which switches accumulation on.
static void sm90_bf16_gemm(const torch::Tensor& a,
                           const torch::Tensor& b,
                           const std::optional<torch::Tensor>& c,
                           const torch::Tensor& d,
                           const int& m, const int& n, const int& k,
                           const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                           const std::string& compiled_dims) {
    const auto desc = GemmDesc {
        .gemm_type = GemmType::Normal,
        .kernel_type = KernelType::KernelNoSF,
        .m = m,
        .n = n,
        .k = k,
        .num_groups = 1,
        .a_dtype = a.scalar_type(),
        .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a,
        .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);
    const auto& storage_cfg = config.storage_config;

    // Requires no TMA splits
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tma_a = make_tma_a_desc(major_a, a, m, k,
                                       storage_cfg.load_block_m, config.layout.block_k,
                                       stride_a, 1,
                                       storage_cfg.swizzle_a_mode);

    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tma_b = make_tma_b_desc(major_b, b, n, k,
                                       storage_cfg.load_block_n, config.layout.block_k,
                                       stride_b, 1,
                                       storage_cfg.swizzle_b_mode);

    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tma_cd = make_tma_cd_desc(d, m, n,
                                         storage_cfg.store_block_m, storage_cfg.store_block_n,
                                         stride_d, 1,
                                         storage_cfg.swizzle_cd_mode);

    // Launch: one block per configured SM, no grouped layout for the plain GEMM
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90BF16GemmRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = nullptr,
        .tensor_map_a = tma_a,
        .tensor_map_b = tma_b,
        .tensor_map_cd = tma_cd,
    };
    const auto source = SM90BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_bf16_gemm", source);
    SM90BF16GemmRuntime::launch(kernel_runtime, args);
}
// Host-side launcher for the M-grouped contiguous SM90 BF16 GEMM.
// Requires a BF16 output, K-major A and `k` divisible by 64. `use_psum_layout` selects
// the psum-layout GEMM type; `expected_m_for_psum_layout` may only be given together
// with it. `m_indices` is handed to the kernel as the grouped-layout pointer.
static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
                                                const torch::Tensor& b,
                                                const torch::Tensor& d,
                                                const torch::Tensor& m_indices,
                                                const int& num_groups, const int& m, const int& n, const int& k,
                                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                                const std::string& compiled_dims,
                                                const bool& use_psum_layout,
                                                const std::optional<int>& expected_m_for_psum_layout) {
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K);
    DG_HOST_ASSERT(k % 64 == 0);

    const auto gemm_type = use_psum_layout ?
        GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous;

    // Only psum layout can use expected m
    if (expected_m_for_psum_layout) {
        DG_HOST_ASSERT(use_psum_layout);
    }

    const auto desc = GemmDesc {
        .gemm_type = gemm_type,
        .kernel_type = KernelType::KernelNoSF,
        .m = m,
        .n = n,
        .k = k,
        .num_groups = num_groups,
        .a_dtype = a.scalar_type(),
        .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a,
        .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims,
        .expected_m = expected_m_for_psum_layout.value_or(m),
        .expected_n = n,
        .expected_k = k,
        .expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);
    const auto& storage_cfg = config.storage_config;

    // Requires no TMA splits
    // Only the B descriptor is created with `num_groups`; A and CD use 1
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tma_a = make_tma_a_desc(major_a, a, m, k,
                                       storage_cfg.load_block_m, config.layout.block_k,
                                       stride_a, 1,
                                       storage_cfg.swizzle_a_mode);

    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tma_b = make_tma_b_desc(major_b, b, n, k,
                                       storage_cfg.load_block_n, config.layout.block_k,
                                       stride_b, num_groups,
                                       storage_cfg.swizzle_b_mode);

    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tma_cd = make_tma_cd_desc(d, m, n,
                                         storage_cfg.store_block_m, storage_cfg.store_block_n,
                                         stride_d, 1,
                                         storage_cfg.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90BF16GemmRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = tma_a,
        .tensor_map_b = tma_b,
        .tensor_map_cd = tma_cd,
    };
    const auto source = SM90BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", source);
    SM90BF16GemmRuntime::launch(kernel_runtime, args);
}
// Host-side launcher for the M-grouped masked SM90 BF16 GEMM.
// Requires a BF16 output, K-major A and B, and `k` divisible by 64. `expected_m` is
// passed to the heuristics through the descriptor; `masked_m` is handed to the kernel
// as the grouped-layout pointer.
static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
                                            const torch::Tensor& b,
                                            const torch::Tensor& d,
                                            const torch::Tensor& masked_m,
                                            const int& num_groups, const int& m, const int& n, const int& k,
                                            const int& expected_m,
                                            const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                            const std::string& compiled_dims) {
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);
    DG_HOST_ASSERT(k % 64 == 0);

    const auto desc = GemmDesc {
        .gemm_type = GemmType::MGroupedMasked,
        .kernel_type = KernelType::KernelNoSF,
        .m = m,
        .n = n,
        .k = k,
        .num_groups = num_groups,
        .a_dtype = a.scalar_type(),
        .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a,
        .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims,
        .expected_m = expected_m,
        .expected_n = 0,
        .expected_k = 0,
        .expected_num_groups = num_groups
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);
    const auto& storage_cfg = config.storage_config;

    // Requires no TMA splits
    // All three descriptors are created with `num_groups`
    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tma_a = make_tma_a_desc(major_a, a, m, k,
                                       storage_cfg.load_block_m, config.layout.block_k,
                                       stride_a, num_groups,
                                       storage_cfg.swizzle_a_mode);

    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tma_b = make_tma_b_desc(major_b, b, n, k,
                                       storage_cfg.load_block_n, config.layout.block_k,
                                       stride_b, num_groups,
                                       storage_cfg.swizzle_b_mode);

    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tma_cd = make_tma_cd_desc(d, m, n,
                                         storage_cfg.store_block_m, storage_cfg.store_block_n,
                                         stride_d, num_groups,
                                         storage_cfg.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90BF16GemmRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = tma_a,
        .tensor_map_b = tma_b,
        .tensor_map_cd = tma_cd,
    };
    const auto source = SM90BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", source);
    SM90BF16GemmRuntime::launch(kernel_runtime, args);
}
// Host-side launcher for the K-grouped contiguous SM90 BF16 GEMM.
// Requires MN-major A and B, and every entry of `ks` to be a multiple of 128.
// The descriptor carries the summed K as `k` and the largest per-group K as
// `expected_k`. `c` is only inspected through `has_value()` (accumulation flag), and
// `ks_tensor` is handed to the kernel as the grouped-layout pointer.
static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a,
                                     const torch::Tensor& b,
                                     const std::optional<torch::Tensor>& c,
                                     const torch::Tensor& d,
                                     const int& m, const int& n,
                                     const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                     const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                     const std::string& compiled_dims) {
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN);

    // Accumulate the total and the maximum K in a single pass.
    // The maximum is tracked here rather than with `*std::max_element(...)`, which would
    // dereference `ks.end()` (undefined behaviour) when `ks` is empty. This mirrors the
    // loop in `sm90_k_grouped_fp8_gemm_1d1d`; an empty `ks` yields `max_k == 0`.
    int sum_k = 0, max_k = 0;
    for (const auto k: ks) {
        sum_k += k;
        max_k = std::max(max_k, k);
        DG_HOST_ASSERT(k % 128 == 0);
    }
    const auto num_groups = static_cast<int>(ks.size());

    // Get config using max K for better performance
    const auto desc = GemmDesc {
        .gemm_type = GemmType::KGroupedContiguous,
        .kernel_type = KernelType::KernelNoSF,
        .m = m, .n = n, .k = sum_k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
        .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);

    // Create tensor descriptors
    // A and B span the summed K; only the CD descriptor is created with `num_groups`
    const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k,
                                              config.storage_config.load_block_m,
                                              config.layout.block_k,
                                              static_cast<int>(a.stride(0)), 1,
                                              config.storage_config.swizzle_a_mode);
    const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k,
                                              config.storage_config.load_block_n,
                                              config.layout.block_k,
                                              static_cast<int>(b.stride(0)), 1,
                                              config.storage_config.swizzle_b_mode);
    const auto tensor_map_cd = make_tma_cd_desc(d, m, n,
                                                config.storage_config.store_block_m,
                                                config.storage_config.store_block_n,
                                                static_cast<int>(d.stride(1)), num_groups,
                                                config.storage_config.swizzle_cd_mode);

    // Launch kernel
    const SM90BF16GemmRuntime::Args& args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                  config.pipeline_config.smem_size,
                                  config.layout.get_cluster_size()),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_cd = tensor_map_cd,
    };
    const auto code = SM90BF16GemmRuntime::generate(args);
    const auto runtime = compiler->build("sm90_bf16_k_grouped_gemm", code);
    SM90BF16GemmRuntime::launch(runtime, args);
}
// Host-side launcher for the batched SM90 BF16 GEMM named "bhr,hdr->bhd".
// The descriptor maps m = b, n = d, k = r and num_groups = h, with both operands
// K-major and no accumulation. All three tensors are described with 3D TMA descriptors.
static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a,
                                  const torch::Tensor& tensor_b,
                                  const torch::Tensor& tensor_d,
                                  const int& b, const int& h, const int& r, const int& d,
                                  const std::string& compiled_dims = "nk") {
    const auto desc = GemmDesc {
        .gemm_type = GemmType::Batched,
        .kernel_type = KernelType::KernelNoSF,
        .m = b,
        .n = d,
        .k = r,
        .num_groups = h,
        .a_dtype = tensor_a.scalar_type(),
        .b_dtype = tensor_b.scalar_type(),
        .cd_dtype = tensor_d.scalar_type(),
        .major_a = cute::UMMA::Major::K,
        .major_b = cute::UMMA::Major::K,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);
    const auto& storage_cfg = config.storage_config;

    // Tile extents used as TMA box sizes
    const int tile_load_m = storage_cfg.load_block_m;
    const int tile_load_n = storage_cfg.load_block_n;
    const int tile_store_m = storage_cfg.store_block_m;
    const int tile_store_n = storage_cfg.store_block_n;

    // A is described as (r, b, h) with strides taken from dims 0 and 1
    const auto tma_a = make_tma_3d_desc(tensor_a, r, b, h,
                                        config.layout.block_k, tile_load_m, 1,
                                        tensor_a.stride(0), tensor_a.stride(1),
                                        storage_cfg.swizzle_a_mode);
    // B is described as (r, d, h); note the stride order is dim 1 then dim 0
    const auto tma_b = make_tma_3d_desc(tensor_b, r, d, h,
                                        config.layout.block_k, tile_load_n, 1,
                                        tensor_b.stride(1), tensor_b.stride(0),
                                        storage_cfg.swizzle_b_mode);
    // D is described as (d, b, h)
    const auto tma_cd = make_tma_3d_desc(tensor_d, d, b, h,
                                         tile_store_n, tile_store_m, 1,
                                         tensor_d.stride(0), tensor_d.stride(1),
                                         storage_cfg.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90BF16GemmRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = nullptr,
        .tensor_map_a = tma_a,
        .tensor_map_b = tma_b,
        .tensor_map_cd = tma_cd,
    };
    const auto source = SM90BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", source);
    SM90BF16GemmRuntime::launch(kernel_runtime, args);
}
// Host-side launcher for the batched SM90 BF16 GEMM named "bhd,hdr->bhr".
// The descriptor maps m = b, n = r, k = d and num_groups = h, with A K-major,
// B MN-major and no accumulation. All three tensors use 3D TMA descriptors.
static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a,
                                  const torch::Tensor& tensor_b,
                                  const torch::Tensor& tensor_d,
                                  const int& b, const int& h, const int& r, const int& d,
                                  const std::string& compiled_dims = "nk") {
    const auto desc = GemmDesc {
        .gemm_type = GemmType::Batched,
        .kernel_type = KernelType::KernelNoSF,
        .m = b,
        .n = r,
        .k = d,
        .num_groups = h,
        .a_dtype = tensor_a.scalar_type(),
        .b_dtype = tensor_b.scalar_type(),
        .cd_dtype = tensor_d.scalar_type(),
        .major_a = cute::UMMA::Major::K,
        .major_b = cute::UMMA::Major::MN,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);
    const auto& storage_cfg = config.storage_config;

    // Tile extents used as TMA box sizes
    const int tile_load_m = storage_cfg.load_block_m;
    const int tile_load_n = storage_cfg.load_block_n;
    const int tile_store_m = storage_cfg.store_block_m;
    const int tile_store_n = storage_cfg.store_block_n;

    // A is described as (d, b, h) with the K block first in the box
    const auto tma_a = make_tma_3d_desc(tensor_a, d, b, h,
                                        config.layout.block_k, tile_load_m, 1,
                                        tensor_a.stride(0), tensor_a.stride(1),
                                        storage_cfg.swizzle_a_mode);
    // B is described as (r, d, h); here the N block comes first in the box and the
    // stride order is dim 1 then dim 0
    const auto tma_b = make_tma_3d_desc(tensor_b, r, d, h,
                                        tile_load_n, config.layout.block_k, 1,
                                        tensor_b.stride(1), tensor_b.stride(0),
                                        storage_cfg.swizzle_b_mode);
    // D is described as (r, b, h)
    const auto tma_cd = make_tma_3d_desc(tensor_d, r, b, h,
                                         tile_store_n, tile_store_m, 1,
                                         tensor_d.stride(0), tensor_d.stride(1),
                                         storage_cfg.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90BF16GemmRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .grouped_layout = nullptr,
        .tensor_map_a = tma_a,
        .tensor_map_b = tma_b,
        .tensor_map_cd = tma_cd,
    };
    const auto source = SM90BF16GemmRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", source);
    SM90BF16GemmRuntime::launch(kernel_runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,131 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the SM90 kernel `sm90_bmn_bnk_mn_gemm_impl`.
// `generate_impl` renders the source that instantiates the kernel template;
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM90BmkBnkMnRuntime final: public LaunchRuntime<SM90BmkBnkMnRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
struct Args {
// Problem sizes: `m`, `n`, `k` become template arguments, `s` is a runtime argument
int s, m, n, k;
int block_m, block_n, block_k;
int split_factor;
int num_stages;
int num_tma_threads, num_math_threads;
LaunchArgs launch_args;
// TMA descriptors for A and B
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
// Output pointer, passed to the kernel as-is
float* d;
};
// Renders the translation unit that instantiates the kernel template.
// The `{}` slots are filled positionally, so the argument list must stay in the same
// order as the template parameters. `{{` / `}}` are escaped literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_bmk_bnk_mn.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_bmn_bnk_mn_gemm_impl<
{}, {}, {},
{}, {}, {},
{},
{},
{}, {}
>);
}};
)",
// Shape, block sizes, split factor, stage count, then TMA/math thread counts
args.m, args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.split_factor,
args.num_stages,
args.num_tma_threads, args.num_math_threads);
}
// Launches the built kernel with `s`, both TMA descriptors and the output pointer;
// the result of `launch_kernel` goes through `DG_CUDA_UNIFIED_CHECK`.
// NOTE(review): `args` is taken by value (non-const); presumably `launch_kernel` needs
// mutable lvalues for the kernel parameters - confirm before changing to a reference.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.s, args.tensor_map_a, args.tensor_map_b, args.d));
}
};
// Host-side launcher for the SM90 `sm90_bmn_bnk_mn_gemm_impl` kernel.
// Uses fixed 128x128x64 blocks with 128 TMA threads and 256 math threads, derives a
// split factor from the SM count, picks the largest stage count (starting at 4) whose
// shared-memory footprint fits `SM90ArchSpec::smem_capacity`, then JIT-builds and
// launches the kernel. `d` is accessed as `float`.
// Asserts: `k` divisible by 64, `m` and `n` divisible by 64, `s * max(m, n)` fits in
// `int`, the A/B swizzle mode is 128, and at least one M/N block exists.
static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a,
                                 const torch::Tensor &b,
                                 const torch::Tensor &d,
                                 const int &s, const int &m, const int &n, const int &k) {
    constexpr int block_m = 128;
    constexpr int block_n = 128;
    constexpr int block_k = 64;
    constexpr int num_tma_threads = 128;
    constexpr int num_math_threads = 256;
    DG_HOST_ASSERT(k % block_k == 0);
    DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0);
    DG_HOST_ASSERT(static_cast<int64_t>(s) * static_cast<int64_t>(std::max(m, n)) <= std::numeric_limits<int>::max());
    const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast<int>(a.element_size()));
    DG_HOST_ASSERT(swizzle_ab_mode == 128);

    // Get best config
    const int num_sms = device_runtime->get_num_sms();
    const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n);
    // `m == 0` or `n == 0` passes the divisibility asserts above but would make the
    // division below divide by zero, so fail with a clear assertion instead
    DG_HOST_ASSERT(num_mn_blocks > 0);
    const int num_sk_blocks = s * (k / block_k);
    const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1));

    // Select best number of stages
    int num_stages = 4, smem_size = 0;
    while (true) {
        const int smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t);
        const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t);
        // 16 bytes of barrier storage per stage
        const int smem_barrier = num_stages * 8 * 2;

        smem_size = 0;
        smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages;
        smem_size += smem_barrier;

        if (smem_size <= SM90ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    // The loop also exits at zero stages (zero bytes always fits), so reject that here
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    if (get_env("DG_JIT_DEBUG", 0)) {
        // Fixed: a ", " separator was missing between the split-K factor and the stages
        printf("S: %d, M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split-K factor: %d, "
               "stages: %d, shared memory: %d, swizzle AB: %d\n",
               s, m, n, k, block_m, block_n, block_k, split_factor,
               num_stages, smem_size, swizzle_ab_mode);
    }

    // A and B are described as 2D tensors of `s * m` / `s * n` rows by `k` columns
    const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode);
    const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode);

    const SM90BmkBnkMnRuntime::Args& args = {
        .s = s, .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .split_factor = split_factor,
        .num_stages = num_stages,
        .num_tma_threads = num_tma_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .d = d.data_ptr<float>()
    };
    const auto code = SM90BmkBnkMnRuntime::generate(args);
    const auto runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code);
    SM90BmkBnkMnRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,229 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the SM90 FP8 1D1D GEMM kernel (`sm90_fp8_gemm_1d1d_impl`).
// `generate_impl` renders the source that instantiates the kernel template;
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime<SM90FP8Gemm1D1DRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
struct Args {
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
// Raw pointers passed as kernel arguments. In this file the plain GEMM sets all four
// to `nullptr`; the K-grouped GEMM sets them to the A/B data, the `ks` tensor and
// the tensor-map buffer
void *gmem_a_ptr;
void *gmem_b_ptr;
void *grouped_layout;
void *tensor_map_buffer;
// TMA descriptors for A, B, the two scaling-factor tensors and the C/D output
CUtensorMap tensor_map_a_base;
CUtensorMap tensor_map_b_base;
CUtensorMap tensor_map_sfa;
CUtensorMap tensor_map_sfb;
CUtensorMap tensor_map_cd;
};
// Renders the translation unit that instantiates `sm90_fp8_gemm_1d1d_impl`.
// The `{}` slots are filled positionally, so the argument list must stay in the same
// order as the template parameters. `{{` / `}}` are escaped literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d1d_impl<
{}, {}, {},
{},
{}, {}, {},
{}, {},
{},
{}, {},
{}, {},
{},
{}, {}
>);
}};
)",
// M, N, K as produced by `get_compiled_dim` from the shape and `compiled_dims`
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_desc.num_groups,
// Block sizes, then swizzle modes for A and B (no CD swizzle slot for this kernel)
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads,
// Cluster size, followed by a boolean that is true when `cluster_n > 1`
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type),
to_string(args.gemm_desc.cd_dtype));
}
// Launches the built kernel with the raw pointers, the runtime M/N/K and the five TMA
// descriptors; the result of `launch_kernel` goes through `DG_CUDA_UNIFIED_CHECK`.
// NOTE(review): `args` is taken by value (non-const); presumably `launch_kernel` needs
// mutable lvalues for the kernel parameters - confirm before changing to a reference.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.gmem_a_ptr, args.gmem_b_ptr,
args.grouped_layout,
args.tensor_map_buffer,
args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a_base, args.tensor_map_b_base,
args.tensor_map_sfa, args.tensor_map_sfb,
args.tensor_map_cd));
}
};
// Host-side launcher for the plain SM90 FP8 1D1D GEMM.
// Requires `c` to be present, an FP32 output `d`, and K-major A and B. After picking
// a configuration it also requires the A/B swizzle modes to equal `block_k`. The
// kernel receives TMA descriptors only; all raw pointers in `Args` are null.
static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                               const torch::Tensor& b, const torch::Tensor& sfb,
                               const std::optional<torch::Tensor>& c,
                               const torch::Tensor& d,
                               const int& m, const int& n, const int& k,
                               const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                               const std::string& compiled_dims) {
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto desc = GemmDesc {
        .gemm_type = GemmType::Normal,
        .kernel_type = KernelType::Kernel1D1D,
        .m = m,
        .n = n,
        .k = k,
        .num_groups = 1,
        .a_dtype = a.scalar_type(),
        .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a,
        .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);

    // Requires no TMA splits
    DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
    DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);

    const auto& storage_cfg = config.storage_config;
    const auto& tile_cfg = config.layout;

    // A and B use `k` as the outer stride
    const auto tma_a = make_tma_a_desc(major_a, a, m, k,
                                       storage_cfg.load_block_m, tile_cfg.block_k,
                                       k, 1,
                                       storage_cfg.swizzle_a_mode);
    const auto tma_b = make_tma_b_desc(major_b, b, n, k,
                                       storage_cfg.load_block_n, tile_cfg.block_k,
                                       k, 1,
                                       storage_cfg.swizzle_b_mode);
    // Scaling factors are described as MN-major
    const auto tma_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                          tile_cfg.block_m, tile_cfg.block_k, 1, 0);
    const auto tma_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k,
                                          tile_cfg.block_n, tile_cfg.block_k, 1, 0);
    // The output descriptor is created with swizzle mode 0
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tma_cd = make_tma_cd_desc(d, m, n,
                                         storage_cfg.store_block_m, storage_cfg.store_block_n,
                                         stride_d, 1,
                                         0);

    // Launch
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90FP8Gemm1D1DRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .gmem_a_ptr = nullptr,
        .gmem_b_ptr = nullptr,
        .grouped_layout = nullptr,
        .tensor_map_buffer = nullptr,
        .tensor_map_a_base = tma_a,
        .tensor_map_b_base = tma_b,
        .tensor_map_sfa = tma_sfa,
        .tensor_map_sfb = tma_sfb,
        .tensor_map_cd = tma_cd,
    };
    const auto source = SM90FP8Gemm1D1DRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_fp8_gemm_1d1d", source);
    SM90FP8Gemm1D1DRuntime::launch(kernel_runtime, args);
}
// Host-side launcher for the K-grouped contiguous SM90 FP8 1D1D GEMM.
// Requires `c` to be present, an FP32 output `d`, K-major A and B, and every entry of
// `ks` to be a multiple of 128. The base A/B descriptors are built from the first
// non-zero K; the kernel additionally receives the raw A/B pointers, the `ks` tensor
// and `tensor_map_buffer`.
static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
                                         const torch::Tensor& b, const torch::Tensor& sfb,
                                         const std::optional<torch::Tensor>& c,
                                         const torch::Tensor& d,
                                         const int& m, const int& n,
                                         const std::vector<int>& ks, const torch::Tensor& ks_tensor,
                                         const torch::Tensor& tensor_map_buffer,
                                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b,
                                         const std::string& compiled_dims) {
    DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // TODO: refactor with the mk alignment function
    // One pass over the groups: first non-zero K, total K, total number of 128-wide
    // scaling-factor columns, and the largest K
    const auto num_groups = static_cast<int>(ks.size());
    int first_k = 0;
    int sum_k = 0;
    int sum_sf_k = 0;
    int max_k = 0;
    for (int i = 0; i < num_groups; ++ i) {
        if (first_k == 0 and ks[i] != 0) {
            first_k = ks[i];
        }
        sum_k += ks[i];
        sum_sf_k += ceil_div(ks[i], 128);
        max_k = std::max(max_k, ks[i]);
        DG_HOST_ASSERT(ks[i] % 128 == 0);
    }

    // Get config using max K for better performance
    const auto desc = GemmDesc {
        .gemm_type = GemmType::KGroupedContiguous,
        .kernel_type = KernelType::Kernel1D1D,
        .m = m,
        .n = n,
        .k = sum_k,
        .num_groups = num_groups,
        .a_dtype = a.scalar_type(),
        .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a,
        .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims,
        .expected_m = m,
        .expected_n = n,
        .expected_k = max_k,
        .expected_num_groups = num_groups
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);

    // Requires no TMA splits
    DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
    DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);

    const auto& storage_cfg = config.storage_config;
    const auto& tile_cfg = config.layout;

    // Base descriptors for A and B use the first non-zero K as both extent and stride
    const auto tma_a_base = make_tma_a_desc(major_a, a, m, first_k,
                                            storage_cfg.load_block_m, tile_cfg.block_k,
                                            first_k, 1,
                                            storage_cfg.swizzle_a_mode);
    const auto tma_b_base = make_tma_b_desc(major_b, b, n, first_k,
                                            storage_cfg.load_block_n, tile_cfg.block_k,
                                            first_k, 1,
                                            storage_cfg.swizzle_b_mode);
    // Scaling-factor descriptors span all groups
    const auto tma_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128,
                                          tile_cfg.block_m, tile_cfg.block_k, 1, 0);
    const auto tma_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128,
                                          tile_cfg.block_n, tile_cfg.block_k, 1, 0);
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tma_cd = make_tma_cd_desc(d, m, n,
                                         storage_cfg.store_block_m, storage_cfg.store_block_n,
                                         stride_d, num_groups,
                                         storage_cfg.swizzle_cd_mode);

    // Launch
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90FP8Gemm1D1DRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .gmem_a_ptr = a.data_ptr(),
        .gmem_b_ptr = b.data_ptr(),
        .grouped_layout = ks_tensor.data_ptr(),
        .tensor_map_buffer = tensor_map_buffer.data_ptr(),
        .tensor_map_a_base = tma_a_base,
        .tensor_map_b_base = tma_b_base,
        .tensor_map_sfa = tma_sfa,
        .tensor_map_sfb = tma_sfb,
        .tensor_map_cd = tma_cd,
    };
    const auto source = SM90FP8Gemm1D1DRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_fp8_gemm_1d1d", source);
    SM90FP8Gemm1D1DRuntime::launch(kernel_runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,361 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../heuristics/sm90.hpp"
#include "epilogue.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the SM90 FP8 1D2D GEMM kernel (`sm90_fp8_gemm_1d2d_impl`).
// `generate_impl` renders the source that instantiates the kernel template;
// `launch_impl` forwards the runtime arguments to `launch_kernel`.
class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime<SM90FP8Gemm1D2DRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
struct Args {
GemmDesc gemm_desc;
GemmConfig gemm_config;
LaunchArgs launch_args;
// TODO: move this into `gemm_desc`
// Stored as a reference: the referenced optional must outlive this struct and any
// copy of it (`launch_impl` receives a by-value copy)
const std::optional<std::string>& epilogue_type;
cute::UMMA::Major major_sfb;
// `sfb` is passed to the kernel as a raw pointer rather than a TMA descriptor
void *sfb, *grouped_layout;
// TMA descriptors for A, B, the output D and the A scaling factors
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
CUtensorMap tensor_map_sfa;
};
// Renders the translation unit that instantiates `sm90_fp8_gemm_1d2d_impl`.
// The `{}` slots are filled positionally, so the argument list must stay in the same
// order as the template parameters. `{{` / `}}` are escaped literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_fp8_gemm_1d2d_impl<
{},
{}, {}, {},
{},
{}, {}, {},
{}, {}, {},
{},
{}, {},
{}, {},
{}, {},
{}
>);
}};
)",
// TODO: add CD dtype
// Major of the B scaling factors comes first
to_string(args.major_sfb),
// M, N, K as produced by `get_compiled_dim` from the shape and `compiled_dims`
get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims),
get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims),
args.gemm_desc.num_groups,
// Block sizes, then swizzle modes for A, B and CD
args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k,
args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode,
args.gemm_config.pipeline_config.num_stages,
args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads,
// Cluster size, followed by a boolean that is true when `cluster_n > 1`
args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1,
args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type),
// Epilogue type as resolved by `get_default_epilogue_type` from the optional
get_default_epilogue_type(args.epilogue_type));
}
// Launches the built kernel with the `sfb` and grouped-layout pointers, the runtime
// M/N/K and the four TMA descriptors; the result of `launch_kernel` goes through
// `DG_CUDA_UNIFIED_CHECK`.
// NOTE(review): `args` is taken by value (non-const); presumably `launch_kernel` needs
// mutable lvalues for the kernel parameters - confirm before changing to a reference.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.sfb, args.grouped_layout,
args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k,
args.tensor_map_a, args.tensor_map_b,
args.tensor_map_d, args.tensor_map_sfa));
}
};
// Host-side launcher for the plain SM90 FP8 1D2D GEMM.
// Requires `c` to be absent, a BF16 output `d`, and K-major A and B. After picking a
// configuration it also requires the A/B swizzle modes to equal `block_k`. `sfb` is
// passed to the kernel as a raw pointer; `epilogue_type` is forwarded by reference
// into the runtime arguments.
static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                               const torch::Tensor& b, const torch::Tensor& sfb,
                               const std::optional<torch::Tensor>& c,
                               const torch::Tensor& d,
                               const int& m, const int& n, const int& k,
                               const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
                               const std::string& compiled_dims,
                               const std::optional<std::string>& epilogue_type = std::nullopt) {
    DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto desc = GemmDesc {
        .gemm_type = GemmType::Normal,
        .kernel_type = KernelType::Kernel1D2D,
        .m = m,
        .n = n,
        .k = k,
        .num_groups = 1,
        .a_dtype = a.scalar_type(),
        .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a,
        .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(),
        .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);

    // Requires no TMA splits
    DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
    DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);

    const auto& storage_cfg = config.storage_config;
    const auto& tile_cfg = config.layout;

    const auto stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tma_a = make_tma_a_desc(major_a, a, m, k,
                                       storage_cfg.load_block_m, tile_cfg.block_k,
                                       stride_a, 1,
                                       storage_cfg.swizzle_a_mode);

    const auto stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tma_b = make_tma_b_desc(major_b, b, n, k,
                                       storage_cfg.load_block_n, tile_cfg.block_k,
                                       stride_b, 1,
                                       storage_cfg.swizzle_b_mode);

    // The output width is taken from `d` itself rather than from `n`
    const auto width_d = static_cast<int>(d.size(-1));
    const auto stride_d = static_cast<int>(d.stride(-2));
    const auto tma_d = make_tma_cd_desc(d, m, width_d,
                                        storage_cfg.store_block_m, storage_cfg.store_block_n,
                                        stride_d, 1,
                                        storage_cfg.swizzle_cd_mode);

    // Scaling factors of A are described as MN-major
    const auto tma_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                          tile_cfg.block_m, tile_cfg.block_k, 1, 0);

    // Launch
    const auto launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads,
                                        config.pipeline_config.smem_size,
                                        config.layout.get_cluster_size());
    const SM90FP8Gemm1D2DRuntime::Args args {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = launch_args,
        .epilogue_type = epilogue_type,
        .major_sfb = major_sfb,
        .sfb = sfb.data_ptr(),
        .grouped_layout = nullptr,
        .tensor_map_a = tma_a,
        .tensor_map_b = tma_b,
        .tensor_map_d = tma_d,
        .tensor_map_sfa = tma_sfa,
    };
    const auto source = SM90FP8Gemm1D2DRuntime::generate(args);
    const auto kernel_runtime = compiler->build("sm90_fp8_gemm_1d2d", source);
    SM90FP8Gemm1D2DRuntime::launch(kernel_runtime, args);
}
static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                    const torch::Tensor& b, const torch::Tensor& sfb,
                                                    const torch::Tensor& d,
                                                    const torch::Tensor& m_indices,
                                                    const int& num_groups, const int& m, const int& n, const int& k,
                                                    const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
                                                    const std::string& compiled_dims,
                                                    const bool& use_psum_layout,
                                                    const std::optional<int>& expected_m_for_psum_layout) {
    // Host-side launcher for the SM90 FP8 1D2D kernel in M-grouped contiguous mode:
    // describes the problem, asks the heuristics for a config, builds the TMA
    // descriptors, then JIT-builds and launches the kernel.
    // `m_indices` is forwarded untouched to the kernel as its `grouped_layout` pointer.

    // The output must be BF16 and both operands must be K-major
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    // Only psum layout can use expected m
    const bool has_expected_m = expected_m_for_psum_layout.has_value();
    if (has_expected_m) {
        DG_HOST_ASSERT(use_psum_layout);
    }

    // Select the grouped GEMM flavour
    GemmType gemm_type = GemmType::MGroupedContiguous;
    if (use_psum_layout)
        gemm_type = GemmType::MGroupedContiguousWithPsumLayout;

    // Without an explicit hint, the heuristics see the full `m` and a single group
    const auto desc = GemmDesc {
        .gemm_type = gemm_type,
        .kernel_type = KernelType::Kernel1D2D,
        .m = m, .n = n, .k = k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
        .expected_m = has_expected_m ? expected_m_for_psum_layout.value() : m,
        .expected_n = n, .expected_k = k,
        .expected_num_groups = has_expected_m ? num_groups : 1
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);

    // Requires no TMA splits
    DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
    DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);

    // Shorthands for the chosen configuration
    const auto& storage = config.storage_config;
    const auto& layout = config.layout;
    const auto& launch = config.launch_config;

    // TMA descriptors: only B is built with `num_groups`, the others use a single group
    const int stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              storage.load_block_m, layout.block_k,
                                              stride_a, 1,
                                              storage.swizzle_a_mode);
    const int stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              storage.load_block_n, layout.block_k,
                                              stride_b, num_groups,
                                              storage.swizzle_b_mode);
    const int stride_d = static_cast<int>(d.stride(-2));
    const auto tensor_map_d = make_tma_cd_desc(d, m, n,
                                               storage.store_block_m, storage.store_block_n,
                                               stride_d, 1,
                                               storage.swizzle_cd_mode);
    const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                                 layout.block_m, layout.block_k, 1, 0);

    // Launch
    const SM90FP8Gemm1D2DRuntime::Args args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(launch.num_sms, launch.num_threads,
                                  config.pipeline_config.smem_size,
                                  layout.get_cluster_size()),
        .epilogue_type = std::nullopt,
        .major_sfb = major_sfb,
        .sfb = sfb.data_ptr(),
        .grouped_layout = m_indices.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
        .tensor_map_sfa = tensor_map_sfa,
    };
    const auto code = SM90FP8Gemm1D2DRuntime::generate(args);
    const auto runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code);
    SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
                                                const torch::Tensor& b, const torch::Tensor& sfb,
                                                const torch::Tensor& d,
                                                const torch::Tensor& masked_m,
                                                const int& num_groups, const int& m, const int& n, const int& k,
                                                const int& expected_m,
                                                const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
                                                const std::string& compiled_dims) {
    // Host-side launcher for the SM90 FP8 1D2D kernel in M-grouped masked mode.
    // `masked_m` is forwarded untouched to the kernel as its `grouped_layout` pointer;
    // `expected_m` is only a hint for the configuration heuristics.

    // The output must be BF16 and both operands must be K-major
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto desc = GemmDesc {
        .gemm_type = GemmType::MGroupedMasked,
        .kernel_type = KernelType::Kernel1D2D,
        .m = m, .n = n, .k = k, .num_groups = num_groups,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = false,
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims,
        .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);

    // Requires no TMA splits
    DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
    DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);

    // Shorthands for the chosen configuration
    const auto& storage = config.storage_config;
    const auto& layout = config.layout;
    const auto& launch = config.launch_config;

    // TMA descriptors: every tensor (A, B, D and SFA) is built with `num_groups`
    const int stride_a = static_cast<int>(a.stride(get_non_contiguous_dim(major_a)));
    const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k,
                                              storage.load_block_m, layout.block_k,
                                              stride_a, num_groups,
                                              storage.swizzle_a_mode);
    const int stride_b = static_cast<int>(b.stride(get_non_contiguous_dim(major_b)));
    const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k,
                                              storage.load_block_n, layout.block_k,
                                              stride_b, num_groups,
                                              storage.swizzle_b_mode);
    const int stride_d = static_cast<int>(d.stride(-2));
    const auto tensor_map_d = make_tma_cd_desc(d, m, n,
                                               storage.store_block_m, storage.store_block_n,
                                               stride_d, num_groups,
                                               storage.swizzle_cd_mode);
    const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                                 layout.block_m, layout.block_k, num_groups, 0);

    // Launch
    const SM90FP8Gemm1D2DRuntime::Args args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(launch.num_sms, launch.num_threads,
                                  config.pipeline_config.smem_size,
                                  layout.get_cluster_size()),
        .epilogue_type = std::nullopt,
        .major_sfb = major_sfb,
        .sfb = sfb.data_ptr(),
        .grouped_layout = masked_m.data_ptr(),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
        .tensor_map_sfa = tensor_map_sfa,
    };
    const auto code = SM90FP8Gemm1D2DRuntime::generate(args);
    const auto runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
    SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa,
                         const torch::Tensor& b, const torch::Tensor& sfb,
                         const std::optional<torch::Tensor>& c,
                         const torch::Tensor& d,
                         const int& batch_size, const int& m, const int& n, const int& k,
                         const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb,
                         const std::string& compiled_dims) {
    // Host-side launcher for the batched variant of the SM90 FP8 1D2D kernel.
    // A, B and D are described with 3D TMA descriptors whose outermost extent is `batch_size`.
    // NOTE(review): `c` is only consulted for `with_accumulation`; its data pointer is never
    // passed on here - presumably the caller arranges for D to already hold C. Confirm.

    // The output must be BF16 and both operands must be K-major
    DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16);
    DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K);

    const auto desc = GemmDesc {
        .gemm_type = GemmType::Batched,
        .kernel_type = KernelType::Kernel1D2D,
        .m = m, .n = n, .k = k, .num_groups = batch_size,
        .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(),
        .cd_dtype = d.scalar_type(),
        .major_a = major_a, .major_b = major_b,
        .with_accumulation = c.has_value(),
        .num_sms = device_runtime->get_num_sms(),
        .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims
    };
    const auto config = get_best_config<SM90ArchSpec>(desc);

    // Requires no TMA splits
    DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k);
    DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k);

    // Shorthands for the chosen configuration
    const auto& storage = config.storage_config;
    const auto& layout = config.layout;
    const auto& launch = config.launch_config;

    // Tile extents used by the 3D descriptors
    const int tile_load_m = storage.load_block_m;
    const int tile_load_n = storage.load_block_n;
    const int tile_store_m = storage.store_block_m;
    const int tile_store_n = storage.store_block_n;

    // Extents are passed innermost-first, followed by the box shape and
    // the strides of tensor dimensions 1 and 0
    const auto tensor_map_a = make_tma_3d_desc(a, k, m, batch_size,
                                               layout.block_k, tile_load_m, 1,
                                               a.stride(1), a.stride(0),
                                               storage.swizzle_a_mode);
    const auto tensor_map_b = make_tma_3d_desc(b, k, n, batch_size,
                                               layout.block_k, tile_load_n, 1,
                                               b.stride(1), b.stride(0),
                                               storage.swizzle_b_mode);
    const auto tensor_map_d = make_tma_3d_desc(d, n, m, batch_size,
                                               tile_store_n, tile_store_m, 1,
                                               d.stride(1), d.stride(0),
                                               storage.swizzle_cd_mode);
    const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k,
                                                 layout.block_m, layout.block_k, batch_size, 0);

    // Launch
    const SM90FP8Gemm1D2DRuntime::Args args = {
        .gemm_desc = desc,
        .gemm_config = config,
        .launch_args = LaunchArgs(launch.num_sms, launch.num_threads,
                                  config.pipeline_config.smem_size,
                                  layout.get_cluster_size()),
        .epilogue_type = std::nullopt,
        .major_sfb = major_sfb,
        .sfb = sfb.data_ptr(),
        .grouped_layout = nullptr,
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
        .tensor_map_sfa = tensor_map_sfa,
    };
    const auto code = SM90FP8Gemm1D2DRuntime::generate(args);
    const auto runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
    SM90FP8Gemm1D2DRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,152 @@
#pragma once
#include <torch/python.h>
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime for the `sm90_tf32_hc_prenorm_gemm_impl` kernel.
// `generate_impl` produces the source that instantiates the kernel template, and
// `launch_impl` forwards the runtime arguments to the compiled kernel.
class SM90BF16HCPrenormGemmRuntime final: public LaunchRuntime<SM90BF16HCPrenormGemmRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
// Field order matters: callers fill this struct with designated initializers.
struct Args {
// Problem shape; `n` and `k` become template arguments, `m` is passed at launch
int m, n, k;
// Tile shape (template arguments)
int block_m, block_n, block_k;
// Template argument; the caller also uses it to size the launch grid
int num_splits;
// Swizzle mode of the output (CD) tensor map (template argument)
int swizzle_cd_mode;
// Number of pipeline stages (template argument)
int num_stages;
// Thread counts (template arguments)
int num_math_threads, num_tma_threads;
// Launch configuration; not referenced by the code in this class
LaunchArgs launch_args;
// TMA descriptors passed to the kernel at launch
CUtensorMap tensor_map_a;
CUtensorMap tensor_map_b;
CUtensorMap tensor_map_d;
// Extra float buffer passed to the kernel as its last argument
float* sqr_sum;
};
// Returns source code that includes the kernel header and takes the address of
// `sm90_tf32_hc_prenorm_gemm_impl` instantiated with, in order: n, k, block_m, block_n,
// block_k, num_splits, swizzle_cd_mode, num_stages, num_math_threads, num_tma_threads.
// Doubled braces in the template are `fmt` escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm90_tf32_hc_prenorm_gemm_impl<
{}, {},
{}, {}, {},
{},
{},
{},
{}, {}
>);
}};
)",
args.n, args.k,
args.block_m, args.block_n, args.block_k,
args.num_splits,
args.swizzle_cd_mode,
args.num_stages,
args.num_math_threads, args.num_tma_threads);
}
// Launches the compiled kernel with the runtime arguments `m`, the three tensor maps
// and `sqr_sum`, in that order; a failed launch is reported through `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
// TODO: optimize `args` copy
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum));
}
};
static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a,
                                      const torch::Tensor& b,
                                      const torch::Tensor& d,
                                      const torch::Tensor& sqr_sum,
                                      const int& m, const int& n, const int& k,
                                      const int& num_splits) {
    // Host-side launcher for `sm90_tf32_hc_prenorm_gemm_impl`.
    //   a, b       : input operands, both described as K-major
    //   d          : output; described with a 2D descriptor when `num_splits == 1`,
    //                otherwise with a 3D descriptor whose outermost extent is `num_splits`
    //   sqr_sum    : float buffer handed to the kernel as-is
    //   m, n, k    : problem shape; requires `n <= 32`, `n % 8 == 0` and `k % 64 == 0`
    //   num_splits : multiplies the launch grid size (see `launch_args` below)
    constexpr int block_m = 64;
    constexpr int block_k = 64;
    constexpr int num_math_threads = 128;
    constexpr int num_tma_threads = 128;
    constexpr int num_threads = num_math_threads + num_tma_threads;

    // A single N tile covers the whole N dimension
    const int block_n = align(n, 16);
    DG_HOST_ASSERT(n <= block_n);

    // Only support small N for now
    DG_HOST_ASSERT(n <= 32 and n % 8 == 0);
    DG_HOST_ASSERT(k % block_k == 0);

    // The output swizzle mode is derived from float-sized elements
    const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float));
    const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k,
                                              block_m, block_k,
                                              static_cast<int>(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                              get_swizzle_mode(block_k, a.element_size()), 0,
                                              true);
    const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k,
                                              block_n, block_k,
                                              static_cast<int>(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1,
                                              get_swizzle_mode(block_k, b.element_size()), 0,
                                              true);
    const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n,
                                                                 block_m, block_n,
                                                                 static_cast<int>(d.stride(-2)), 1,
                                                                 swizzle_cd_mode)
                                              : make_tma_3d_desc(d, n, m, num_splits,
                                                                 block_n, block_m, 1,
                                                                 static_cast<int>(d.stride(-2)),
                                                                 static_cast<int>(d.stride(-3)),
                                                                 swizzle_cd_mode);

    // Calculate stages: start from 12 and shrink until the buffers fit in shared memory
    int num_stages = 12, smem_size = 0;
    while (num_stages > 0) {
        const int smem_a_per_stage = block_m * block_k * static_cast<int>(sizeof(nv_bfloat16));
        const int smem_b_per_stage = block_n * block_k * static_cast<int>(sizeof(float));
        const int smem_cd = block_m * swizzle_cd_mode;
        // Two 8-byte barriers per stage
        const int smem_barriers = num_stages * 2 * 8;

        smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages +
                    smem_cd + smem_barriers;
        if (smem_size <= SM90ArchSpec::smem_capacity)
            break;
        -- num_stages;
    }
    DG_HOST_ASSERT(num_stages > 0);

    // Print configs
    // NOTES: the literals below are concatenated, so the first line must end with ", "
    // to keep "split K" and "stages" from being glued together in the output
    if (get_env("DG_JIT_DEBUG", 0)) {
        printf("M: %d, N: %d, K: %d -> "
               "block M: %d, block N: %d, block K: %d, split K: %d, "
               "stages: %d, shared memory: %d, swizzle CD: %d\n",
               m, n, k, block_m, block_n, block_k, num_splits,
               num_stages, smem_size, swizzle_cd_mode);
    }

    // The launch always requests the full shared memory capacity, whatever was computed above
    smem_size = SM90ArchSpec::smem_capacity;

    // Launch
    const SM90BF16HCPrenormGemmRuntime::Args& args = {
        .m = m, .n = n, .k = k,
        .block_m = block_m, .block_n = block_n, .block_k = block_k,
        .num_splits = num_splits,
        .swizzle_cd_mode = swizzle_cd_mode,
        .num_stages = num_stages,
        .num_math_threads = num_math_threads,
        .num_tma_threads = num_tma_threads,
        // One block per (split, M tile) pair
        .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size),
        .tensor_map_a = tensor_map_a,
        .tensor_map_b = tensor_map_b,
        .tensor_map_d = tensor_map_d,
        .sqr_sum = sqr_sum.data_ptr<float>()
    };
    const auto code = SM90BF16HCPrenormGemmRuntime::generate(args);
    const auto runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code);
    SM90BF16HCPrenormGemmRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,81 @@
#pragma once
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../../utils/exception.hpp"
namespace deep_gemm {
// JIT runtime for the `smxx_clean_logits` kernel.
// `generate_impl` produces the source that instantiates the kernel template, and
// `launch_impl` forwards the runtime arguments to the compiled kernel.
class SMXXCleanLogitsRuntime final: public LaunchRuntime<SMXXCleanLogitsRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
// Field order matters: callers fill this struct with designated initializers.
struct Args {
// Template argument
int next_n;
// Runtime arguments passed at launch
int seq_len;
int seq_len_kv;
// Passed to the kernel after a cast to `int64_t`
uint64_t stride_logits;
// May be null: the caller passes `nullptr` when no start offsets are given
int* cu_seq_len_k_start;
int* cu_seq_len_k_end;
// Untyped logits pointer; the element type is carried by `logits_dtype`
void* logits;
// Template argument (converted with `to_string`)
at::ScalarType logits_dtype;
// Template arguments
int block_kv;
int num_warps;
// Launch configuration; not referenced by the code in this class
LaunchArgs launch_args;
};
// Returns source code that includes the kernel header and takes the address of
// `smxx_clean_logits` instantiated with, in order: next_n, block_kv, num_warps
// and the logits type name.
// Doubled braces in the template are `fmt` escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/smxx_clean_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&smxx_clean_logits<
{}, {}, {}, {}
>);
}};
)", args.next_n, args.block_kv, args.num_warps, to_string(args.logits_dtype));
}
// Launches the compiled kernel with seq_len, seq_len_kv, the signed logits stride,
// the two sequence-bound pointers and the logits pointer, in that order.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.seq_len, args.seq_len_kv, static_cast<int64_t>(args.stride_logits),
args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits
));
}
};
static void smxx_clean_logits(const torch::Tensor& logits,
                              const std::optional<torch::Tensor>& cu_seq_len_k_start,
                              const torch::Tensor& cu_seq_len_k_end,
                              const int& next_n,
                              const int& seq_len, const int& seq_len_kv,
                              const uint64_t &stride_logits) {
    // Host-side launcher for the `smxx_clean_logits` kernel.
    // `cu_seq_len_k_start` is optional; the kernel receives a null pointer when it is absent.

    // Fixed kernel configuration
    constexpr int kBlockKV = 8192;
    constexpr int kNumWarps = 8;
    constexpr int kNumThreads = kNumWarps * 32;
    // Shared memory: one float per KV element of a block
    constexpr int kSmemSize = kBlockKV * static_cast<int>(sizeof(float));

    // Resolve the optional start offsets
    int* k_start_ptr = nullptr;
    if (cu_seq_len_k_start.has_value())
        k_start_ptr = cu_seq_len_k_start.value().data_ptr<int>();

    // Launch: one block per SM
    const SMXXCleanLogitsRuntime::Args args = {
        .next_n = next_n,
        .seq_len = seq_len,
        .seq_len_kv = seq_len_kv,
        .stride_logits = stride_logits,
        .cu_seq_len_k_start = k_start_ptr,
        .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
        .logits = logits.data_ptr(),
        .logits_dtype = logits.scalar_type(),
        .block_kv = kBlockKV,
        .num_warps = kNumWarps,
        .launch_args = LaunchArgs(device_runtime->get_num_sms(),
                                  kNumThreads, kSmemSize)
    };
    const auto code = SMXXCleanLogitsRuntime::generate(args);
    const auto runtime = compiler->build("smxx_clean_logits", code);
    SMXXCleanLogitsRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,151 @@
#pragma once
#include <cublasLt.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDADataType.h>
#include <cute/arch/mma_sm100_umma.hpp>
#include "../../jit/device_runtime.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/compatibility.hpp"
namespace deep_gemm {
static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld,
                                const std::optional<int>& batch_count = std::nullopt,
                                const std::optional<int>& batch_offset = std::nullopt) {
    // Creates a cuBLASLt matrix layout of `rows` x `cols` with leading dimension `ld`.
    // When `batch_count` is given, `batch_offset` must be given too, and both are set
    // as the batch count and strided batch offset of the layout.
    // The caller owns the returned layout (`call_cublaslt_api` destroys the ones it receives).
    cublasLtMatrixLayout_t layout;
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld));

    // Non-batched layouts are done here
    if (not batch_count.has_value())
        return layout;

    // The strided batch offset attribute is set as a 64-bit integer, so widen the value first
    DG_HOST_ASSERT(batch_offset.has_value());
    const int64_t batch_offset_int64 = batch_offset.value();
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value())));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64)));
    return layout;
}
static void call_cublaslt_api(const cublasOperation_t& trans_a,
                              const cublasOperation_t& trans_b,
                              const cublasLtMatrixLayout_t& layout_a,
                              const cublasLtMatrixLayout_t& layout_b,
                              const cublasLtMatrixLayout_t& layout_d,
                              const torch::Tensor& a,
                              const torch::Tensor& b,
                              const torch::Tensor& d,
                              const bool& accumulate) {
    // Runs one cuBLASLt matmul on the current CUDA stream and releases every cuBLASLt
    // object involved, including the three layouts handed in by the caller.
    //   trans_a/trans_b, layout_a/layout_b : describe the cuBLASLt "A" and "B" operands
    //   a, b                               : tensors in the caller's order; note that `b` is
    //                                        bound to cuBLASLt "A" and `a` to cuBLASLt "B"
    //   d                                  : used as both the "C" and the "D" operand
    //   accumulate                         : beta is 1 when true, 0 otherwise

    // Releases the cuBLASLt objects if this function is left early.
    // NOTE(review): this assumes `DG_CUBLASLT_CHECK` / `DG_HOST_ASSERT` report failures by
    // throwing - confirm in `utils/exception.hpp`. Return codes are ignored in the destructor
    // on purpose: another error is already propagating and a destructor must not throw.
    struct ResourceGuard {
        cublasLtMatrixLayout_t owned_layout_a, owned_layout_b, owned_layout_d;
        cublasLtMatmulDesc_t desc = nullptr;
        cublasLtMatmulPreference_t pref = nullptr;
        bool armed = true;

        ResourceGuard(const cublasLtMatrixLayout_t& la,
                      const cublasLtMatrixLayout_t& lb,
                      const cublasLtMatrixLayout_t& ld):
            owned_layout_a(la), owned_layout_b(lb), owned_layout_d(ld) {}
        ResourceGuard(const ResourceGuard&) = delete;
        ResourceGuard& operator=(const ResourceGuard&) = delete;

        ~ResourceGuard() {
            if (not armed)
                return;
            if (pref != nullptr)
                (void) cublasLtMatmulPreferenceDestroy(pref);
            (void) cublasLtMatrixLayoutDestroy(owned_layout_a);
            (void) cublasLtMatrixLayoutDestroy(owned_layout_b);
            (void) cublasLtMatrixLayoutDestroy(owned_layout_d);
            if (desc != nullptr)
                (void) cublasLtMatmulDescDestroy(desc);
        }
    };
    ResourceGuard guard(layout_a, layout_b, layout_d);

    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
    cudaDataType_t scale_type = CUDA_R_32F;

    // Operation description
    cublasLtMatmulDesc_t desc;
    DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type));
    guard.desc = desc;
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b)));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type)));

#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE
    // Tell cuBLASLt how many SMs the runtime is configured to use
    const int math_sms = device_runtime->get_num_sms();
    DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms)));
#endif

#if DG_FP8_COMPATIBLE and DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE
    // Explicitly disable fast accumulation for FP8 (E4M3) inputs
    bool fp8_fast_accumulate = false;
    if (a.scalar_type() == torch::kFloat8_e4m3fn)
        DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate)));
#endif

    // Get cuBLASLt handle, workspace, and stream
    const auto handle = device_runtime->get_cublaslt_handle();
    const auto workspace = device_runtime->get_cublaslt_workspace();
    const auto workspace_bytes = workspace.nbytes();
    const auto stream = at::cuda::getCurrentCUDAStream();

    // Algorithm selection
    cublasLtMatmulPreference_t pref;
    cublasLtMatmulHeuristicResult_t heuristic;
    int num_heuristic_results = 0;
    uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE;
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref));
    guard.pref = pref;
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
                                                           &workspace_bytes, sizeof(workspace_bytes)));
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK,
                                                           &reduction_scheme_mask, sizeof(reduction_scheme_mask)));
    DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d,
                                                     pref, 1, &heuristic, &num_heuristic_results));
    DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM");

    // Call: D = alpha * (A @ B) + beta * C
    const float alpha = 1.0, beta = accumulate ? 1.0 : 0.0;
    DG_CUBLASLT_CHECK(cublasLtMatmul(handle,                                 // Light handle
                                     desc,                                   // Operation description
                                     &alpha,                                 // Alpha
                                     b.data_ptr(), layout_a,                 // A
                                     a.data_ptr(), layout_b,                 // B
                                     &beta,                                  // Beta
                                     d.data_ptr(), layout_d,                 // C
                                     d.data_ptr(), layout_d,                 // D
                                     &heuristic.algo,                        // Algorithm
                                     workspace.data_ptr(), workspace_bytes,  // Workspace
                                     stream));                               // Stream

    // Free memory
    // NOTES: from here on the explicit, checked destroys take over from the guard,
    // so nothing is ever destroyed twice
    guard.armed = false;
    DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b));
    DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d));
    DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc));
}
static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs,
                          const torch::Tensor& out,
                          const int& m, const int& n, const int& k,
                          const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major,
                          const bool& accumulate) {
    // Single (non-batched) GEMM through cuBLASLt.
    // The operands are swapped on the way in: `rhs` is described as the cuBLASLt "A" matrix
    // and `lhs` as the "B" matrix, and the output layout is `n` x `m`.
    const bool lhs_is_k_major = (a_major == cute::UMMA::Major::K);
    const bool rhs_is_k_major = (b_major == cute::UMMA::Major::K);

    // Transposition flags for the swapped operands
    const cublasOperation_t trans_a = rhs_is_k_major ? CUBLAS_OP_T : CUBLAS_OP_N;
    const cublasOperation_t trans_b = lhs_is_k_major ? CUBLAS_OP_N : CUBLAS_OP_T;

    // Element types, again following the swap
    const auto cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type());
    const auto cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type());
    const auto cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type());

    // Matrix layouts: K-major operands use `stride(0)` as leading dimension, others `stride(1)`
    cublasLtMatrixLayout_t layout_a;
    if (rhs_is_k_major) {
        layout_a = get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0));
    } else {
        layout_a = get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1));
    }
    cublasLtMatrixLayout_t layout_b;
    if (lhs_is_k_major) {
        layout_b = get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0));
    } else {
        layout_b = get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1));
    }
    const cublasLtMatrixLayout_t layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0));

    // The callee takes over the layouts and destroys them
    call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, accumulate);
}
static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    // Batched BF16 matmul through cuBLASLt with `h` as the batch count.
    // NOTE(review): going by the name, this is presumably the contraction "bhr, hdr -> bhd"
    // with lhs = [b, h, r], rhs = [h, d, r], out = [b, h, d] - confirm against callers.

    // cuBLASLt problem shape
    const int m = d;
    const int n = b;
    const int k = r;
    const cublasOperation_t trans_a = CUBLAS_OP_T;
    const cublasOperation_t trans_b = CUBLAS_OP_N;

    // Matrix layouts: `rhs` is the cuBLASLt "A" operand (batched over its dimension 0),
    // `lhs` and `out` are batched over their dimension 1
    const cublasLtMatrixLayout_t layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0));
    const cublasLtMatrixLayout_t layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
    const cublasLtMatrixLayout_t layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));

    // No accumulation into `out`
    call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}
static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out,
                                 const int& b, const int& h, const int& r, const int& d) {
    // Batched BF16 matmul through cuBLASLt with `h` as the batch count.
    // NOTE(review): going by the name, this is presumably the contraction "bhd, hdr -> bhr"
    // with lhs = [b, h, d], rhs = [h, d, r], out = [b, h, r] - confirm against callers.

    // cuBLASLt problem shape
    const int m = r;
    const int n = b;
    const int k = d;
    const cublasOperation_t trans_a = CUBLAS_OP_N;
    const cublasOperation_t trans_b = CUBLAS_OP_N;

    // Matrix layouts: `rhs` is the cuBLASLt "A" operand (batched over its dimension 0),
    // `lhs` and `out` are batched over their dimension 1
    const cublasLtMatrixLayout_t layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0));
    const cublasLtMatrixLayout_t layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1));
    const cublasLtMatrixLayout_t layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1));

    // No accumulation into `out`
    call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,328 @@
#pragma once
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../heuristics/sm90.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime for the `sm{arch}_fp8_mqa_logits` kernel, where `{arch}` comes from the device runtime.
// `generate_impl` produces the source that instantiates the kernel template, and
// `launch_impl` forwards the runtime arguments to the compiled kernel.
class SMXXFP8MQALogitsRuntime final: public LaunchRuntime<SMXXFP8MQALogitsRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
// Field order matters: callers fill this struct with designated initializers.
struct Args {
// Runtime arguments passed at launch
int seq_len;
int seq_len_kv;
int max_seqlen_k;
int stride_logits;
// Template arguments; `num_heads` must divide 128 (checked in `generate_impl`)
int num_heads, head_dim;
// Template argument
bool is_compressed_logits;
// Pipeline depths (template arguments)
int num_q_stages;
int num_kv_stages;
// Tile sizes (template arguments)
int block_q;
int block_kv;
// Runtime pointers passed at launch
int* cu_seq_len_k_start;
int* cu_seq_len_k_end;
void* logits;
// TMA descriptors passed at launch
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_scales;
CUtensorMap tensor_map_weights;
// Template argument (converted with `to_string`)
at::ScalarType logits_dtype;
// Thread counts (template arguments)
int num_specialized_threads;
int num_math_threads;
// Launch configuration; its first grid dimension is also baked in as a template argument
LaunchArgs launch_args;
};
// Returns source code that includes `deep_gemm/impls/sm{arch}_fp8_mqa_logits.cuh` and takes
// the address of `sm{arch}_fp8_mqa_logits` instantiated with, in order: num_heads, head_dim,
// is_compressed_logits, block_q, block_kv, num_q_stages, num_kv_stages, the first grid
// dimension, num_specialized_threads, num_math_threads and the logits type name.
// Doubled braces in the template are `fmt` escapes for literal braces.
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
DG_HOST_ASSERT(128 % args.num_heads == 0);
// Selects both the header and the kernel name
const auto arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm{}_fp8_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_mqa_logits<
{}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)", arch, arch,
args.num_heads, args.head_dim,
args.is_compressed_logits,
args.block_q, args.block_kv,
args.num_q_stages, args.num_kv_stages,
args.launch_args.grid_dim.first,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
// Launches the compiled kernel with the sequence lengths, `max_seqlen_k`, the logits stride,
// the sequence-bound pointers, the logits pointer and the four tensor maps, in that order.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.seq_len, args.seq_len_kv,
args.max_seqlen_k, args.stride_logits,
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
args.logits,
args.tensor_map_q, args.tensor_map_kv,
args.tensor_map_kv_scales, args.tensor_map_weights
));
}
};
static void smxx_fp8_mqa_logits(const torch::Tensor& q,
                                const torch::Tensor& kv, const torch::Tensor& kv_scales,
                                const torch::Tensor& weights,
                                const torch::Tensor& cu_seq_len_k_start,
                                const torch::Tensor& cu_seq_len_k_end,
                                const torch::Tensor& logits,
                                const at::ScalarType& logits_dtype,
                                const int& seq_len, const int& seq_len_kv,
                                const int& max_seqlen_k, const int& stride_logits,
                                const int& num_heads, const int& head_dim,
                                const int& block_q, const int& block_kv) {
    // Host-side launcher for the FP8 MQA logits kernel: builds the TMA descriptors,
    // sizes the shared memory and JIT-launches one block per SM.

    // Fixed kernel configuration; the math thread count depends on the architecture major
    constexpr int num_specialized_threads = 128;
    constexpr int num_q_stages = 3;
    constexpr int num_kv_stages = 3;
    const bool is_arch_major_10 = (device_runtime->get_arch_major() == 10);
    const int num_math_threads = is_arch_major_10 ? 256 : 512;

    // Use compressed logits format when max_seqlen_k is specified
    const bool is_compressed_logits = (max_seqlen_k > 0);

    // Construct TMAs
    DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
    const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
                                               head_dim, block_q * num_heads,
                                               head_dim, head_dim);
    const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
                                                head_dim, block_kv,
                                                head_dim, head_dim);
    // According to the driver API, the minimal alignment is 256 bytes
    // So it is safe for us to do a 16-byte OOB
    const int aligned_seq_len_kv = get_tma_aligned_size(seq_len_kv, static_cast<int>(kv_scales.element_size()));
    const auto tensor_map_kv_scales = make_tma_2d_desc(kv_scales, aligned_seq_len_kv, 1,
                                                       block_kv, 1, 0, 0);
    const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
                                                     num_heads, block_q, num_heads, 0);

    // Calculate shared memory size
    // Per Q stage: the Q tile plus its weights; per KV stage: the KV tile plus its scales
    const int q_rows_per_stage = block_q * num_heads;
    const int smem_per_q_stage = q_rows_per_stage * head_dim * static_cast<int>(q.element_size()) +
                                 q_rows_per_stage * static_cast<int>(weights.element_size());
    const int smem_per_kv_stage = block_kv * head_dim * static_cast<int>(kv.element_size()) +
                                  block_kv * static_cast<int>(kv_scales.element_size());
    // 8 bytes per barrier: two per Q stage, two per KV stage, two per group of 128 math threads
    const int smem_barriers = (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8;
    // Plus 4 trailing bytes
    const int smem_size = num_q_stages * smem_per_q_stage +
                          num_kv_stages * smem_per_kv_stage +
                          smem_barriers + 4;
    DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
    DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);

    // Launch
    const SMXXFP8MQALogitsRuntime::Args args = {
        .seq_len = seq_len,
        .seq_len_kv = seq_len_kv,
        .max_seqlen_k = max_seqlen_k,
        .stride_logits = stride_logits,
        .num_heads = num_heads, .head_dim = head_dim,
        .is_compressed_logits = is_compressed_logits,
        .num_q_stages = num_q_stages,
        .num_kv_stages = num_kv_stages,
        .block_q = block_q,
        .block_kv = block_kv,
        .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
        .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
        .logits = logits.data_ptr(),
        .tensor_map_q = tensor_map_q,
        .tensor_map_kv = tensor_map_kv,
        .tensor_map_kv_scales = tensor_map_kv_scales,
        .tensor_map_weights = tensor_map_weights,
        .logits_dtype = logits_dtype,
        .num_specialized_threads = num_specialized_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(device_runtime->get_num_sms(),
                                  num_specialized_threads + num_math_threads,
                                  smem_size)
    };
    const auto code = SMXXFP8MQALogitsRuntime::generate(args);
    const auto runtime = compiler->build("smxx_fp8_mqa_logits", code);
    SMXXFP8MQALogitsRuntime::launch(runtime, args);
}
// JIT runtime for the `sm100_fp4_mqa_logits` kernel.
// `generate_impl` produces the source that instantiates the kernel template, and
// `launch_impl` forwards the runtime arguments to the compiled kernel.
class SM100FP4MQALogitsRuntime final: public LaunchRuntime<SM100FP4MQALogitsRuntime> {
public:
// Everything needed to generate and launch one kernel instance.
// Field order matters: callers fill this struct with designated initializers.
struct Args {
// Runtime arguments passed at launch
int seq_len;
int seq_len_kv;
int max_seqlen_k;
int stride_logits;
// Template arguments; `num_heads` must divide 128 (checked in `generate_impl`)
int num_heads, head_dim;
// Template argument
bool is_compressed_logits;
// Pipeline depths (template arguments)
int num_q_stages;
int num_kv_stages;
// Tile sizes (template arguments)
int block_q;
int block_kv;
// Runtime pointers passed at launch
int* cu_seq_len_k_start;
int* cu_seq_len_k_end;
void* logits;
// TMA descriptors passed at launch
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_sf_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_sf_kv;
CUtensorMap tensor_map_weights;
// Template argument (converted with `to_string`)
at::ScalarType logits_dtype;
// Thread counts (template arguments)
int num_specialized_threads;
int num_math_threads;
// Launch configuration; its first grid dimension is also baked in as a template argument
LaunchArgs launch_args;
};
// Returns source code that includes the kernel header and takes the address of
// `sm100_fp4_mqa_logits` instantiated with, in order: num_heads, head_dim,
// is_compressed_logits, block_q, block_kv, num_q_stages, num_kv_stages, the first grid
// dimension, num_specialized_threads, num_math_threads and the logits type name.
// Doubled braces in the template are `fmt` escapes for literal braces.
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
DG_HOST_ASSERT(128 % args.num_heads == 0);
// NOTE(review): `arch` is not used below (the header and kernel names are hard-coded to
// sm100); it looks removable unless `get_arch` is called for a side effect - confirm
const auto arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp4_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp4_mqa_logits<
{}, {},
{},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)", args.num_heads, args.head_dim,
args.is_compressed_logits,
args.block_q, args.block_kv,
args.num_q_stages, args.num_kv_stages,
args.launch_args.grid_dim.first,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
// Launches the compiled kernel with the sequence lengths, `max_seqlen_k`, the logits stride,
// the sequence-bound pointers, the logits pointer and the five tensor maps
// (q, sf_q, kv, sf_kv, weights), in that order.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.seq_len, args.seq_len_kv,
args.max_seqlen_k, args.stride_logits,
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
args.logits,
args.tensor_map_q, args.tensor_map_sf_q,
args.tensor_map_kv, args.tensor_map_sf_kv,
args.tensor_map_weights
));
}
};
static void sm100_fp4_mqa_logits(const torch::Tensor& q, const torch::Tensor& sf_q,
                                 const torch::Tensor& kv, const torch::Tensor& sf_kv,
                                 const torch::Tensor& weights,
                                 const torch::Tensor& cu_seq_len_k_start,
                                 const torch::Tensor& cu_seq_len_k_end,
                                 const torch::Tensor& logits,
                                 const at::ScalarType& logits_dtype,
                                 const int& seq_len, const int& seq_len_kv,
                                 const int& max_seqlen_k, const int& stride_logits,
                                 const int& num_heads, const int& head_dim,
                                 const int& block_q, const int& block_kv) {
    // Host-side launcher for the SM100 FP4 MQA logits kernel: builds the TMA descriptors
    // for the packed Q/KV data and their scale factors, sizes the shared memory and
    // JIT-launches one block per SM.

    // Fixed kernel configuration
    constexpr int num_specialized_threads = 128;
    constexpr int num_math_threads = 2 * 128;
    constexpr int num_q_stages = 3;
    constexpr int num_kv_stages = 6;
    constexpr int num_tmem_stages = 3;

    // Use compressed logits format when max_seqlen_k is specified
    const bool is_compressed_logits = (max_seqlen_k > 0);

    // Construct TMAs
    // `head_dim` must be 128 for 64B swizzling
    DG_HOST_ASSERT(head_dim == 128);
    // A packed row takes `head_dim / 2` bytes, which is what the swizzle argument is set to
    const int packed_row_bytes = head_dim / 2;
    const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
                                               head_dim, block_q * num_heads,
                                               static_cast<int>(q.stride(1)),
                                               packed_row_bytes, 0, false, false);
    const auto tensor_map_sf_q = make_tma_2d_desc(sf_q, num_heads, seq_len,
                                                  num_heads, block_q,
                                                  static_cast<int>(sf_q.stride(0)), 0);
    const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
                                                     num_heads, block_q,
                                                     static_cast<int>(weights.stride(0)), 0);
    const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
                                                head_dim, block_kv,
                                                static_cast<int>(kv.stride(0)),
                                                packed_row_bytes, 0, false, false);
    // According to the driver API, the minimal alignment is 256 bytes
    // So it is safe for us to do a 16-byte OOB
    const int aligned_seq_len_kv = get_tma_aligned_size(seq_len_kv, static_cast<int>(sf_kv.element_size()));
    const auto tensor_map_sf_kv = make_tma_2d_desc(sf_kv, aligned_seq_len_kv, 1,
                                                   block_kv, 1, 0, 0);

    // Calculate shared memory size
    // Per Q stage: packed Q tile, its scale factors (padded to 128 ints) and the float weights
    const int q_rows_per_stage = block_q * num_heads;
    const int smem_per_q_stage = q_rows_per_stage * head_dim / 2 +
                                 align(q_rows_per_stage, 128) * static_cast<int>(sizeof(int)) +
                                 q_rows_per_stage * static_cast<int>(sizeof(float));
    // Per KV stage: packed KV tile and its scale factors (padded to 128 ints)
    const int smem_per_kv_stage = block_kv * head_dim / 2 +
                                  align(block_kv, 128) * static_cast<int>(sizeof(int));
    // Two 8-byte barriers per Q, KV and tensor-memory stage
    const int smem_barriers = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8;
    // 4 bytes for the tensor-memory pointer
    const int smem_tmem_ptr = 4;
    const int smem_size = num_q_stages * smem_per_q_stage +
                          num_kv_stages * smem_per_kv_stage +
                          smem_barriers + smem_tmem_ptr;
    DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);

    // Launch
    const SM100FP4MQALogitsRuntime::Args args = {
        .seq_len = seq_len,
        .seq_len_kv = seq_len_kv,
        .max_seqlen_k = max_seqlen_k,
        .stride_logits = stride_logits,
        .num_heads = num_heads, .head_dim = head_dim,
        .is_compressed_logits = is_compressed_logits,
        .num_q_stages = num_q_stages,
        .num_kv_stages = num_kv_stages,
        .block_q = block_q,
        .block_kv = block_kv,
        .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
        .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
        .logits = logits.data_ptr(),
        .tensor_map_q = tensor_map_q,
        .tensor_map_sf_q = tensor_map_sf_q,
        .tensor_map_kv = tensor_map_kv,
        .tensor_map_sf_kv = tensor_map_sf_kv,
        .tensor_map_weights = tensor_map_weights,
        .logits_dtype = logits_dtype,
        .num_specialized_threads = num_specialized_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(device_runtime->get_num_sms(),
                                  num_specialized_threads + num_math_threads,
                                  smem_size)
    };
    const auto code = SM100FP4MQALogitsRuntime::generate(args);
    const auto runtime = compiler->build("sm100_fp4_mqa_logits", code);
    SM100FP4MQALogitsRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,463 @@
#pragma once
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `sched::smxx_paged_mqa_logits_metadata` kernel.
// `generate_impl` renders the source snippet that instantiates the kernel template,
// and `launch_impl` forwards the runtime arguments to `launch_kernel`.
// NOTE(review): the `generate`/`launch` entry points used by callers are presumably
// provided by the `LaunchRuntime` base — confirm in `kernel_runtime.hpp`.
class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQALogitsMetadataRuntime> {
public:
// Kernel configuration. Callers fill this struct with designated initializers,
// so the declaration order below must not be changed.
struct Args {
// Compile-time values: substituted into the kernel's template argument list by `generate_impl`
int aligned_batch_size;
int split_kv;
int num_sms;
bool is_varlen;
// Runtime values: passed to the kernel by `launch_impl`, in this order
int batch_size;
int next_n;
bool is_context_lens_2d;
int* context_lens;
int* indices;
int* schedule_metadata;
// Not read by the two methods below
// NOTE(review): presumably consumed by the base class to configure the launch — confirm
LaunchArgs launch_args;
};
// Builds the source text that forces instantiation of the kernel template.
// `{{` / `}}` are fmt escapes for literal braces; the four `{}` placeholders are filled,
// in order, with `aligned_batch_size`, `split_kv`, `num_sms` and `is_varlen`
// (the latter rendered explicitly as the literal `true`/`false`).
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sched::smxx_paged_mqa_logits_metadata<
{}, {}, {}, {}
>);
}};
)", args.aligned_batch_size, args.split_kv, args.num_sms, args.is_varlen ? "true" : "false");
}
// Launches the compiled kernel; the result of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
// NOTE(review): the argument order presumably has to match the kernel's parameter list — confirm
// against `paged_mqa_logits.cuh`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
args.next_n,
args.is_context_lens_2d,
args.context_lens,
args.indices,
args.schedule_metadata
));
}
};
// Host entry point that JIT-builds and launches the paged MQA logits scheduling-metadata
// kernel on a single block of 32 threads.
//
// `context_lens` and `schedule_metadata` are accessed through `data_ptr<int>()`;
// `indices_ptr` is forwarded to the kernel as a mutable `int*`.
// `block_kv` must divide the fixed KV split of 256.
static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
                                           const torch::Tensor& schedule_metadata,
                                           const int& batch_size, const int& next_n,
                                           const int& block_kv, const int& num_sms,
                                           const bool& is_context_lens_2d,
                                           const bool& is_varlen, const int* indices_ptr) {
    using Runtime = SMXXPagedMQALogitsMetadataRuntime;

    // Fixed kernel configuration
    constexpr int split_kv = 256;
    constexpr int threads_per_block = 32;

    // The batch size is rounded up to a multiple of 32
    const int padded_batch_size = align(batch_size, 32);
    DG_HOST_ASSERT(split_kv % block_kv == 0);

    // Shared memory layout (in `int`s):
    //   prefix_sum[kAlignedBatchSize]
    // + varlen_atom_token_start/context_len[kAlignedBatchSize] (two arrays)
    // + varlen_num_atoms (one scalar)
    const int smem_size = (3 * padded_batch_size + 1) * static_cast<int>(sizeof(int));
    DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
    DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);

    // Assemble kernel arguments (designated initializers follow `Args` declaration order)
    const Runtime::Args kernel_args = {
        .aligned_batch_size = padded_batch_size,
        .split_kv = split_kv,
        .num_sms = num_sms,
        .is_varlen = is_varlen,
        .batch_size = batch_size,
        .next_n = next_n,
        .is_context_lens_2d = is_context_lens_2d,
        .context_lens = context_lens.data_ptr<int>(),
        .indices = const_cast<int*>(indices_ptr),
        .schedule_metadata = schedule_metadata.data_ptr<int>(),
        .launch_args = LaunchArgs(1, threads_per_block, smem_size)
    };

    // Generate source, build (or fetch) the kernel, then launch
    const auto kernel_source = Runtime::generate(kernel_args);
    const auto built_kernel = compiler->build("smxx_paged_mqa_logits_metadata", kernel_source);
    Runtime::launch(built_kernel, kernel_args);
}
// JIT runtime wrapper for the `sm{arch}_fp8_paged_mqa_logits` kernel.
// `generate_impl` renders the source snippet that instantiates the kernel template for the
// current device architecture, and `launch_impl` forwards the runtime arguments to `launch_kernel`.
// NOTE(review): the `generate`/`launch` entry points used by callers are presumably
// provided by the `LaunchRuntime` base — confirm in `kernel_runtime.hpp`.
class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime<SMXXFP8PagedMQALogitsRuntime> {
public:
// Kernel configuration. Callers fill this struct with designated initializers,
// so the declaration order below must not be changed.
struct Args {
// `batch_size` is a runtime kernel argument; `next_n` .. `block_kv` are template arguments
int batch_size;
int next_n;
int num_heads;
int head_dim;
int block_kv;
// Template arguments
bool is_context_lens_2d;
bool is_varlen;
// Runtime kernel arguments
int block_table_stride;
int logits_stride;
// Template arguments
int num_q_stages;
int num_kv_stages;
int split_kv;
// Raw pointers passed to the kernel at launch
int* context_lens;
void* logits;
int* block_table;
int* indices;
int* schedule_meta;
// TMA descriptors passed to the kernel at launch
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_scales;
CUtensorMap tensor_map_weights;
// Template arguments (`logits_dtype` is rendered through `to_string`)
at::ScalarType logits_dtype;
int num_specialized_threads;
int num_math_threads;
// Not read by the two methods below
// NOTE(review): presumably consumed by the base class to configure the launch — confirm
LaunchArgs launch_args;
};
// Builds the source text that forces instantiation of the kernel template.
// The architecture string is inserted twice (header file name and kernel name), followed by
// twelve template arguments in the order listed after the format string.
// `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
// `num_heads` must evenly divide 128
DG_HOST_ASSERT(128 % args.num_heads == 0);
const auto arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_paged_mqa_logits<
{}, {},
{}, {},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)", arch, arch,
args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.is_context_lens_2d, args.is_varlen ? "true" : "false",
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
// Launches the compiled kernel; the result of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
// Note that `logits_stride` is passed before `block_table_stride`.
// NOTE(review): the argument order presumably has to match the kernel's parameter list — confirm
// against the `.cuh` implementation.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
args.logits_stride, args.block_table_stride,
args.context_lens, args.logits,
args.block_table, args.indices, args.schedule_meta,
args.tensor_map_q, args.tensor_map_kv,
args.tensor_map_kv_scales, args.tensor_map_weights
));
}
};
// Host entry point that JIT-builds and launches the FP8 paged MQA logits kernel.
//
// Thread counts, pipeline depth and the shared-memory budget depend on the device's major
// architecture: major 9 uses the swizzle-aligned layout below, everything else uses the
// layout guarded by `SM100ArchSpec::smem_capacity`.
// `indices` is only dereferenced when `is_varlen` is set; otherwise a null pointer is passed.
static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
                                      const torch::Tensor& kv_cache,
                                      const torch::Tensor& kv_cache_scales,
                                      const torch::Tensor& weights,
                                      const torch::Tensor& context_lens,
                                      const torch::Tensor& logits,
                                      const torch::Tensor& block_table,
                                      const torch::Tensor& indices,
                                      const torch::Tensor& schedule_meta,
                                      const at::ScalarType& logits_dtype,
                                      const int& batch_size, const int& next_n,
                                      const int& num_heads, const int& head_dim,
                                      const int& num_kv_blocks, const int& block_kv,
                                      const bool& is_context_lens_2d,
                                      const bool& is_varlen,
                                      const int& logits_stride,
                                      const int& block_table_stride,
                                      const int& num_sms,
                                      const int& split_kv) {
    using Runtime = SMXXFP8PagedMQALogitsRuntime;

    const auto arch_major = device_runtime->get_arch_major();
    const bool is_arch_10 = (arch_major == 10);

    // Thread layout and pipeline depth
    const int num_specialized_threads = 128;
    const int mma_m = is_arch_10 ? 128 : 64;
    const int num_math_warp_groups = split_kv / mma_m;
    const int num_math_threads = num_math_warp_groups * 128;
    const int num_q_stages = 3;
    const int num_kv_stages = is_arch_10 ? 4 : 3;
    DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);

    // Construct TMAs
    const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1;
    const int num_tokens = batch_size * next_n;
    const auto q_desc = make_tma_2d_desc(q, head_dim, num_tokens * num_heads,
                                         head_dim, next_n_atom * num_heads,
                                         static_cast<int>(q.stride(2)),
                                         head_dim);
    const auto kv_desc = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
                                          head_dim, block_kv, 1,
                                          static_cast<int>(kv_cache.stride(1)),
                                          static_cast<int>(kv_cache.stride(0)),
                                          head_dim);
    const auto kv_scales_desc = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks,
                                                 block_kv, 1,
                                                 static_cast<int>(kv_cache_scales.stride(0)), 0);
    const auto weights_desc = make_tma_2d_desc(weights, num_heads, num_tokens,
                                               num_heads, next_n_atom,
                                               static_cast<int>(weights.stride(0)), 0);

    // Calculate shared memory size
    int smem_size = 0;
    if (arch_major == 9) {
        const int swizzle_alignment = head_dim * 8;

        // Q pipeline: per-stage Q tile + aligned weights, plus aligned barrier storage
        const int q_bytes = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
        const int weight_bytes = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
        const int q_barrier_bytes = align(num_q_stages * 8 * 2, swizzle_alignment);
        const int q_pipe_bytes = num_q_stages * (q_bytes + weight_bytes) + q_barrier_bytes;

        // KV pipeline: per-stage KV tile + aligned scales, plus aligned barrier storage
        const int kv_bytes = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
        const int kv_scale_bytes = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
        const int kv_barrier_bytes = align(num_kv_stages * 8 * 2, swizzle_alignment);
        const int kv_pipe_bytes = num_kv_stages * (kv_bytes + kv_scale_bytes) + kv_barrier_bytes;

        // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
        const int umma_barrier_bytes = num_math_warp_groups * 2 * 8;
        const int tmem_ptr_bytes = 4;

        smem_size = q_pipe_bytes + num_math_warp_groups * kv_pipe_bytes + umma_barrier_bytes + tmem_ptr_bytes;
        DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
        DG_HOST_ASSERT(next_n == 1 or next_n == 2);
    } else {
        // Per-stage payloads
        const int q_bytes = next_n_atom * num_heads * head_dim * static_cast<int>(q.element_size());
        const int kv_bytes = split_kv * head_dim * static_cast<int>(kv_cache.element_size());
        const int kv_scale_bytes = split_kv * static_cast<int>(kv_cache_scales.element_size());
        const int weight_bytes = next_n_atom * num_heads * static_cast<int>(weights.element_size());

        // Barriers and tensor memory pointer
        const int barrier_bytes = (num_q_stages + num_kv_stages) * 2 * 8;
        const int umma_barrier_bytes = num_math_warp_groups * 2 * 8;
        const int tmem_ptr_bytes = 4;

        smem_size = num_q_stages * (q_bytes + weight_bytes) +
                    num_kv_stages * (kv_bytes + kv_scale_bytes) +
                    barrier_bytes + umma_barrier_bytes + tmem_ptr_bytes;
        DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
    }

    // Launch (designated initializers follow `Args` declaration order)
    const Runtime::Args kernel_args = {
        .batch_size = batch_size,
        .next_n = next_n,
        .num_heads = num_heads,
        .head_dim = head_dim,
        .block_kv = block_kv,
        .is_context_lens_2d = is_context_lens_2d,
        .is_varlen = is_varlen,
        .block_table_stride = block_table_stride,
        .logits_stride = logits_stride,
        .num_q_stages = num_q_stages,
        .num_kv_stages = num_kv_stages,
        .split_kv = split_kv,
        .context_lens = context_lens.data_ptr<int>(),
        .logits = logits.data_ptr(),
        .block_table = block_table.data_ptr<int>(),
        .indices = is_varlen ? indices.data_ptr<int>() : nullptr,
        .schedule_meta = schedule_meta.data_ptr<int>(),
        .tensor_map_q = q_desc,
        .tensor_map_kv = kv_desc,
        .tensor_map_kv_scales = kv_scales_desc,
        .tensor_map_weights = weights_desc,
        .logits_dtype = logits_dtype,
        .num_specialized_threads = num_specialized_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(num_sms,
                                  num_specialized_threads + num_math_threads,
                                  smem_size)
    };
    const auto kernel_source = Runtime::generate(kernel_args);
    const auto built_kernel = compiler->build("smxx_fp8_paged_mqa_logits", kernel_source);
    Runtime::launch(built_kernel, kernel_args);
}
// JIT runtime wrapper for the `sm100_fp4_paged_mqa_logits` kernel.
// `generate_impl` renders the source snippet that instantiates the kernel template,
// and `launch_impl` forwards the runtime arguments to `launch_kernel`.
// NOTE(review): the `generate`/`launch` entry points used by callers are presumably
// provided by the `LaunchRuntime` base — confirm in `kernel_runtime.hpp`.
class SM100FP4PagedMQALogitsRuntime final: public LaunchRuntime<SM100FP4PagedMQALogitsRuntime> {
public:
// Kernel configuration. Callers fill this struct with designated initializers,
// so the declaration order below must not be changed.
struct Args {
// `batch_size` is a runtime kernel argument; `next_n` .. `block_kv` are template arguments
int batch_size;
int next_n;
int num_heads;
int head_dim;
int block_kv;
// Template arguments
bool is_context_lens_2d;
bool is_varlen;
// Runtime kernel arguments
int block_table_stride;
int logits_stride;
// Template arguments
int num_q_stages;
int num_kv_stages;
int split_kv;
// Raw pointers passed to the kernel at launch
int* context_lens;
void* logits;
int* block_table;
int* indices;
int* schedule_meta;
// TMA descriptors passed to the kernel at launch (including the two scale-factor maps)
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_sf_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_sf_kv;
CUtensorMap tensor_map_weights;
// Template arguments (`logits_dtype` is rendered through `to_string`)
at::ScalarType logits_dtype;
int num_specialized_threads;
int num_math_threads;
// Not read by the two methods below
// NOTE(review): presumably consumed by the base class to configure the launch — confirm
LaunchArgs launch_args;
};
// Builds the source text that forces instantiation of the kernel template with twelve
// template arguments, in the order listed after the format string.
// `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm100_fp4_paged_mqa_logits<
{}, {},
{}, {},
{}, {},
{}, {},
{},
{}, {},
{}
>);
}};
)", args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.is_context_lens_2d, args.is_varlen ? "true" : "false",
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads,
to_string(args.logits_dtype));
}
// Launches the compiled kernel; the result of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
// Note that `logits_stride` is passed before `block_table_stride`, and that the tensor maps are
// passed as q, sf_q, kv, sf_kv, weights.
// NOTE(review): the argument order presumably has to match the kernel's parameter list — confirm
// against `sm100_fp4_paged_mqa_logits.cuh`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
args.logits_stride, args.block_table_stride,
args.context_lens, args.logits,
args.block_table, args.indices, args.schedule_meta,
args.tensor_map_q, args.tensor_map_sf_q,
args.tensor_map_kv, args.tensor_map_sf_kv,
args.tensor_map_weights
));
}
};
// Host entry point that JIT-builds and launches the SM100 FP4 paged MQA logits kernel.
//
// Requires `split_kv == 256`, `logits_stride` to be a multiple of `split_kv`, and
// `head_dim == 128`. `indices` is only dereferenced when `is_varlen` is set; otherwise a
// null pointer is passed to the kernel.
static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q,
                                       const torch::Tensor& sf_q,
                                       const torch::Tensor& kv_cache,
                                       const torch::Tensor& kv_cache_sf,
                                       const torch::Tensor& weights,
                                       const torch::Tensor& context_lens,
                                       const torch::Tensor& logits,
                                       const torch::Tensor& block_table,
                                       const torch::Tensor& indices,
                                       const torch::Tensor& schedule_meta,
                                       const at::ScalarType& logits_dtype,
                                       const int& batch_size, const int& next_n,
                                       const int& num_heads, const int& head_dim,
                                       const int& num_kv_blocks, const int& block_kv,
                                       const bool& is_context_lens_2d,
                                       const bool& is_varlen,
                                       const int& logits_stride,
                                       const int& block_table_stride,
                                       const int& num_sms,
                                       const int& split_kv) {
    using Runtime = SM100FP4PagedMQALogitsRuntime;

    // Thread layout
    const int num_specialized_threads = 128;
    const int num_math_threads = 2 * 128;
    DG_HOST_ASSERT(split_kv == 256 and logits_stride % split_kv == 0);

    // TODO: tuning num_stages
    const int num_q_stages = 3;
    const int num_kv_stages = 10;
    const int num_tmem_stages = 3;
    const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1;

    // `head_dim` must be 128 for 64B swizzling
    DG_HOST_ASSERT(head_dim == 128);

    const int num_tokens = batch_size * next_n;
    const int q_rows_per_box = next_n_atom * num_heads;

    // Using 2D TMA as tensor q is asserted contiguous
    const auto q_desc = make_tma_2d_desc(q, head_dim, num_tokens * num_heads,
                                         head_dim, q_rows_per_box,
                                         static_cast<int>(q.stride(2)),
                                         head_dim / 2, 0, false, false);
    // NOTES: `sf_q` is a 3D tensor, while `weights` is a 2D tensor
    const auto sf_q_desc = make_tma_2d_desc(sf_q, num_heads, num_tokens,
                                            num_heads, next_n_atom,
                                            static_cast<int>(sf_q.stride(1)), 0);
    const auto weights_desc = make_tma_2d_desc(weights, num_heads, num_tokens,
                                               num_heads, next_n_atom,
                                               static_cast<int>(weights.stride(0)), 0);
    const auto kv_desc = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
                                          head_dim, block_kv, 1,
                                          static_cast<int>(kv_cache.stride(1)),
                                          static_cast<int>(kv_cache.stride(0)),
                                          head_dim / 2, 0, false, false);
    const auto sf_kv_desc = make_tma_2d_desc(kv_cache_sf, block_kv, num_kv_blocks,
                                             block_kv, 1,
                                             static_cast<int>(kv_cache_sf.stride(0)), 0);

    // Calculate shared memory size: per-stage payloads first ...
    const int q_bytes = next_n_atom * num_heads * head_dim / 2;
    const int sf_q_bytes = align(next_n_atom * num_heads, 128) * sizeof(int);
    const int weight_bytes = next_n_atom * num_heads * sizeof(float);
    const int kv_bytes = split_kv * head_dim / 2;
    const int sf_kv_bytes = align(split_kv, 128) * sizeof(int);
    // ... then barriers and the tensor memory pointer
    const int barrier_bytes = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8;
    const int tmem_ptr_bytes = 4;
    const int smem_size = num_q_stages * (q_bytes + sf_q_bytes + weight_bytes) +
                          num_kv_stages * (kv_bytes + sf_kv_bytes) +
                          barrier_bytes + tmem_ptr_bytes;
    DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);

    // Launch (designated initializers follow `Args` declaration order)
    const Runtime::Args kernel_args = {
        .batch_size = batch_size,
        .next_n = next_n,
        .num_heads = num_heads,
        .head_dim = head_dim,
        .block_kv = block_kv,
        .is_context_lens_2d = is_context_lens_2d,
        .is_varlen = is_varlen,
        .block_table_stride = block_table_stride,
        .logits_stride = logits_stride,
        .num_q_stages = num_q_stages,
        .num_kv_stages = num_kv_stages,
        .split_kv = split_kv,
        .context_lens = context_lens.data_ptr<int>(),
        .logits = logits.data_ptr(),
        .block_table = block_table.data_ptr<int>(),
        .indices = is_varlen ? indices.data_ptr<int>() : nullptr,
        .schedule_meta = schedule_meta.data_ptr<int>(),
        .tensor_map_q = q_desc,
        .tensor_map_sf_q = sf_q_desc,
        .tensor_map_kv = kv_desc,
        .tensor_map_sf_kv = sf_kv_desc,
        .tensor_map_weights = weights_desc,
        .logits_dtype = logits_dtype,
        .num_specialized_threads = num_specialized_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(num_sms,
                                  num_specialized_threads + num_math_threads,
                                  smem_size)
    };
    const auto kernel_source = Runtime::generate(kernel_args);
    const auto built_kernel = compiler->build("sm100_fp4_paged_mqa_logits", kernel_source);
    Runtime::launch(built_kernel, kernel_args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,164 @@
#pragma once
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../heuristics/sm90.hpp"
#include "../heuristics/sm100.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `sm{arch}_fp8_mqa_logits` kernel.
// `generate_impl` renders the source snippet that instantiates the kernel template for the
// current device architecture, and `launch_impl` forwards the runtime arguments to `launch_kernel`.
// NOTE(review): the `generate`/`launch` entry points used by callers are presumably
// provided by the `LaunchRuntime` base — confirm in `kernel_runtime.hpp`.
class SMXXFP8MQALogitsRuntime final: public LaunchRuntime<SMXXFP8MQALogitsRuntime> {
public:
// Kernel configuration. Callers fill this struct with designated initializers,
// so the declaration order below must not be changed.
struct Args {
// Runtime kernel arguments
int seq_len;
int seq_len_kv;
int max_seqlen_k;
int stride_logits;
// Template arguments
int num_heads, head_dim;
bool is_compressed_logits;
int num_q_stages;
int num_kv_stages;
int block_q;
int block_kv;
// Raw pointers passed to the kernel at launch
int* cu_seq_len_k_start;
int* cu_seq_len_k_end;
float* logits;
// Not read by `generate_impl` or `launch_impl` below
float softmax_scale;
// TMA descriptors passed to the kernel at launch
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_scales;
CUtensorMap tensor_map_weights;
// Template arguments
int num_specialized_threads;
int num_math_threads;
// Not read by the two methods below
// NOTE(review): presumably consumed by the base class to configure the launch — confirm
LaunchArgs launch_args;
};
// Builds the source text that forces instantiation of the kernel template.
// The architecture string is inserted twice (header file name and kernel name), followed by
// nine template arguments in the order listed after the format string.
// `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
// `num_heads` must evenly divide 128
DG_HOST_ASSERT(128 % args.num_heads == 0);
const auto& arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm{}_fp8_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_mqa_logits<
{}, {},
{},
{}, {},
{}, {},
{}, {}
>);
}};
)", arch, arch,
args.num_heads, args.head_dim,
args.is_compressed_logits,
args.block_q, args.block_kv,
args.num_q_stages, args.num_kv_stages,
args.num_specialized_threads, args.num_math_threads);
}
// Launches the compiled kernel; the result of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
// `stride_logits` is widened to `int64_t` before being passed.
// NOTE(review): the argument order and types presumably have to match the kernel's parameter
// list — confirm against the `.cuh` implementation.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.seq_len, args.seq_len_kv,
args.max_seqlen_k, static_cast<int64_t>(args.stride_logits),
args.cu_seq_len_k_start, args.cu_seq_len_k_end,
args.logits,
args.tensor_map_q, args.tensor_map_kv,
args.tensor_map_kv_scales, args.tensor_map_weights
));
}
};
// Host entry point that JIT-builds and launches the (non-paged) FP8 MQA logits kernel.
//
// Parameters:
//   q, kv, kv_scales, weights       - input tensors, described to the kernel through TMA descriptors
//   cu_seq_len_k_start / _end       - accessed through `data_ptr<int>()`
//   logits                          - output, accessed through `data_ptr<float>()`
//   seq_len, seq_len_kv             - extents used for the TMA descriptors and passed to the kernel
//   max_seqlen_k                    - a positive value selects the compressed logits format
//   stride_logits                   - forwarded to the kernel
//   num_heads                       - must be positive and evenly divide 128
//   head_dim                        - must be 32, 64 or 128
//   seq_len_alignment               - must be a multiple of `128 / num_heads`
//
// Any violated requirement is reported through `DG_HOST_ASSERT`.
static void smxx_fp8_mqa_logits(const torch::Tensor& q,
                                const torch::Tensor& kv, const torch::Tensor& kv_scales,
                                const torch::Tensor& weights,
                                const torch::Tensor& cu_seq_len_k_start,
                                const torch::Tensor& cu_seq_len_k_end,
                                const torch::Tensor& logits,
                                const int& seq_len, const int& seq_len_kv,
                                const int& max_seqlen_k, const int& stride_logits,
                                const int& num_heads, const int& head_dim,
                                const int& seq_len_alignment) {
    constexpr int block_qh = 128;
    constexpr int block_kv = 256;
    constexpr int num_specialized_threads = 128;
    constexpr int num_q_stages = 3, num_kv_stages = 3;
    const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512);

    // Validate `num_heads` BEFORE it is used as a divisor: a zero value would make both the
    // division and the modulo below undefined behavior, and a negative value would slip through
    // the modulo check and yield a negative `block_q`. The `and` short-circuits, so the modulo
    // is only evaluated for a positive divisor.
    DG_HOST_ASSERT(num_heads > 0 and block_qh % num_heads == 0);
    const int block_q = block_qh / num_heads;
    // `block_q >= 1` is guaranteed here, so this modulo is safe
    DG_HOST_ASSERT(seq_len_alignment % block_q == 0);

    // Use compressed logits format when max_seqlen_k is specified
    const bool is_compressed_logits = (max_seqlen_k > 0);

    // Construct TMAs
    DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
    const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads,
                                                head_dim, block_qh, head_dim, head_dim);
    const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv,
                                                 head_dim, block_kv, head_dim, head_dim);
    // According to the driver API, the minimal alignment is 256 bytes
    // So it is safe for us to do a 16-byte OOB
    const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales,
                                                        get_tma_aligned_size(seq_len_kv, static_cast<int>(kv_scales.element_size())),
                                                        1, block_kv, 1, 0, 0);
    const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len,
                                                      num_heads, block_q, num_heads, 0);

    // Calculate shared memory size
    int smem_size = 0;
    const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast<int>(q.element_size());
    const int smem_weight_size_per_stage = block_q * num_heads * static_cast<int>(weights.element_size());
    const int smem_kv_size_per_stage = block_kv * head_dim * static_cast<int>(kv.element_size());
    const int kv_scale_size_per_stage = block_kv * static_cast<int>(kv_scales.element_size());
    smem_size += num_q_stages * smem_q_size_per_stage;
    smem_size += num_kv_stages * smem_kv_size_per_stage;
    smem_size += num_q_stages * smem_weight_size_per_stage;
    smem_size += num_kv_stages * kv_scale_size_per_stage;
    // 8-byte barriers: two per Q stage, two per KV stage, two per 128 math threads
    smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8;
    // Extra 4 bytes
    // NOTE(review): presumably the tensor memory pointer, as in the sibling launchers — confirm
    smem_size += 4;
    DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
    DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);

    // Launch (designated initializers follow `Args` declaration order; `softmax_scale` is left unset)
    const SMXXFP8MQALogitsRuntime::Args& args = {
        .seq_len = seq_len,
        .seq_len_kv = seq_len_kv,
        .max_seqlen_k = max_seqlen_k,
        .stride_logits = stride_logits,
        .num_heads = num_heads, .head_dim = head_dim,
        .is_compressed_logits = is_compressed_logits,
        .num_q_stages = num_q_stages,
        .num_kv_stages = num_kv_stages,
        .block_q = block_q,
        .block_kv = block_kv,
        .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr<int>(),
        .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr<int>(),
        .logits = logits.data_ptr<float>(),
        .tensor_map_q = tensor_map_q,
        .tensor_map_kv = tensor_map_kv,
        .tensor_map_kv_scales = tensor_map_kv_scales,
        .tensor_map_weights = tensor_map_weights,
        .num_specialized_threads = num_specialized_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(device_runtime->get_num_sms(),
                                  num_specialized_threads + num_math_threads,
                                  smem_size)
    };
    const auto& code = SMXXFP8MQALogitsRuntime::generate(args);
    const auto& runtime = compiler->build("smxx_fp8_mqa_logits", code);
    SMXXFP8MQALogitsRuntime::launch(runtime, args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,265 @@
#pragma once
#include "../../jit/compiler.hpp"
#include "../../jit/device_runtime.hpp"
#include "../../jit/kernel_runtime.hpp"
#include "../heuristics/sm90.hpp"
#include "runtime_utils.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `smxx_paged_mqa_logits_metadata` kernel, which is pulled in from
// the architecture-specific `sm{arch}_fp8_paged_mqa_logits.cuh` header.
// `generate_impl` renders the source snippet that instantiates the kernel template,
// and `launch_impl` forwards the runtime arguments to `launch_kernel`.
// NOTE(review): the `generate`/`launch` entry points used by callers are presumably
// provided by the `LaunchRuntime` base — confirm in `kernel_runtime.hpp`.
class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime<SMXXPagedMQALogitsMetadataRuntime> {
public:
// Kernel configuration. Callers fill this struct with designated initializers,
// so the declaration order below must not be changed.
struct Args {
// Compile-time values: substituted into the kernel's template argument list by `generate_impl`
int aligned_batch_size;
int split_kv;
int num_sms;
// Runtime values: passed to the kernel by `launch_impl`, in this order
int batch_size;
int next_n;
bool is_context_lens_2d;
int* context_lens;
int* schedule_metadata;
// Not read by the two methods below
// NOTE(review): presumably consumed by the base class to configure the launch — confirm
LaunchArgs launch_args;
};
// Builds the source text that forces instantiation of the kernel template.
// The first `{}` receives the architecture string (header file name); the remaining three
// receive `aligned_batch_size`, `split_kv` and `num_sms`.
// `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
const auto& arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&smxx_paged_mqa_logits_metadata<
{}, {}, {}
>);
}};
)", arch, args.aligned_batch_size, args.split_kv, args.num_sms);
}
// Launches the compiled kernel; the result of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
// NOTE(review): the argument order presumably has to match the kernel's parameter list — confirm
// against the `.cuh` implementation.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
args.next_n,
args.is_context_lens_2d,
args.context_lens,
args.schedule_metadata
));
}
};
// Host entry point that JIT-builds and launches the paged MQA logits scheduling-metadata
// kernel on a single block of 32 threads.
//
// The KV split handed to the kernel is `block_kv * 4`. `context_lens` and
// `schedule_metadata` are accessed through `data_ptr<int>()`.
static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens,
                                           const torch::Tensor& schedule_metadata,
                                           const int& batch_size, const int& next_n,
                                           const int& block_kv, const int& num_sms,
                                           const bool& is_context_lens_2d) {
    using Runtime = SMXXPagedMQALogitsMetadataRuntime;

    // Fixed kernel configuration
    constexpr int warpgroups_per_split = 4;
    constexpr int threads_per_block = 32;

    // Derived sizes: batch rounded up to a multiple of 32, KV split covering 4 KV blocks
    const int padded_batch_size = align(batch_size, 32);
    const int kv_split = block_kv * warpgroups_per_split;

    // Calculate shared memory size: one `int` per (padded) batch entry
    const int smem_size = padded_batch_size * static_cast<int>(sizeof(int));
    DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
    DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);

    // Assemble kernel arguments (designated initializers follow `Args` declaration order)
    const Runtime::Args kernel_args = {
        .aligned_batch_size = padded_batch_size,
        .split_kv = kv_split,
        .num_sms = num_sms,
        .batch_size = batch_size,
        .next_n = next_n,
        .is_context_lens_2d = is_context_lens_2d,
        .context_lens = context_lens.data_ptr<int>(),
        .schedule_metadata = schedule_metadata.data_ptr<int>(),
        .launch_args = LaunchArgs(1, threads_per_block, smem_size)
    };

    // Generate source, build (or fetch) the kernel, then launch
    const auto kernel_source = Runtime::generate(kernel_args);
    const auto built_kernel = compiler->build("smxx_paged_mqa_logits_metadata", kernel_source);
    Runtime::launch(built_kernel, kernel_args);
}
// JIT runtime wrapper for the `sm{arch}_fp8_paged_mqa_logits` kernel.
// `generate_impl` renders the source snippet that instantiates the kernel template for the
// current device architecture, and `launch_impl` forwards the runtime arguments to `launch_kernel`.
// NOTE(review): the `generate`/`launch` entry points used by callers are presumably
// provided by the `LaunchRuntime` base — confirm in `kernel_runtime.hpp`.
class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime<SMXXFP8PagedMQALogitsRuntime> {
public:
// Kernel configuration. Callers fill this struct with designated initializers,
// so the declaration order below must not be changed.
struct Args {
// `batch_size` is a runtime kernel argument; `next_n` .. `is_context_lens_2d` are template arguments
int batch_size;
int next_n;
int num_heads;
int head_dim;
int block_kv;
bool is_context_lens_2d;
// Runtime kernel arguments (widened to `uint64_t` at launch)
int block_table_stride;
int logits_stride;
// Template arguments
int num_q_stages;
int num_kv_stages;
int split_kv;
// Raw pointers passed to the kernel at launch
int* context_lens;
float* logits;
int* block_table;
int* schedule_meta;
// TMA descriptors passed to the kernel at launch
CUtensorMap tensor_map_q;
CUtensorMap tensor_map_kv;
CUtensorMap tensor_map_kv_scales;
CUtensorMap tensor_map_weights;
// Template arguments
int num_specialized_threads;
int num_math_threads;
// Not read by the two methods below
// NOTE(review): presumably consumed by the base class to configure the launch — confirm
LaunchArgs launch_args;
};
// Builds the source text that forces instantiation of the kernel template.
// The architecture string is inserted twice (header file name and kernel name), followed by
// ten template arguments in the order listed after the format string.
// `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
// TODO: optimize performance by tuning args
// Block sizes are fixed in this kernel
// `num_heads` must evenly divide 128
DG_HOST_ASSERT(128 % args.num_heads == 0);
const auto& arch = device_runtime->get_arch(true);
return fmt::format(R"(
#include <deep_gemm/impls/sm{}_fp8_paged_mqa_logits.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&sm{}_fp8_paged_mqa_logits<
{}, {},
{}, {},
{},
{}, {},
{},
{}, {}
>);
}};
)", arch, arch,
args.next_n, args.num_heads,
args.head_dim, args.block_kv,
args.is_context_lens_2d,
args.num_q_stages, args.num_kv_stages,
args.split_kv,
args.num_specialized_threads, args.num_math_threads);
}
// Launches the compiled kernel; the result of `launch_kernel` is checked by `DG_CUDA_UNIFIED_CHECK`.
// Both strides are cast to `uint64_t`, and `logits_stride` is passed before `block_table_stride`.
// NOTE(review): the argument order and types presumably have to match the kernel's parameter
// list — confirm against the `.cuh` implementation.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.batch_size,
static_cast<uint64_t>(args.logits_stride),
static_cast<uint64_t>(args.block_table_stride),
args.context_lens, args.logits,
args.block_table, args.schedule_meta,
args.tensor_map_q, args.tensor_map_kv,
args.tensor_map_kv_scales, args.tensor_map_weights
));
}
};
// Host entry point that JIT-builds and launches the FP8 paged MQA logits kernel.
//
// Thread counts, pipeline depth and the shared-memory budget depend on the device's major
// architecture: major 9 uses the swizzle-aligned layout below, everything else uses the
// layout guarded by `SM100ArchSpec::smem_capacity`.
// `head_dim` must be 32, 64 or 128; `split_kv` must be a multiple of the MMA M size and
// `logits_stride` a multiple of `split_kv`.
static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q,
                                      const torch::Tensor& kv_cache,
                                      const torch::Tensor& kv_cache_scales,
                                      const torch::Tensor& weights,
                                      const torch::Tensor& context_lens,
                                      const torch::Tensor& logits,
                                      const torch::Tensor& block_table,
                                      const torch::Tensor& schedule_meta,
                                      const int& batch_size, const int& next_n,
                                      const int& num_heads, const int& head_dim,
                                      const int& num_kv_blocks, const int& block_kv,
                                      const bool& is_context_lens_2d,
                                      const int& kv_cache_stride_bytes,
                                      const int& logits_stride,
                                      const int& block_table_stride,
                                      const int& num_sms,
                                      const int& split_kv) {
    using Runtime = SMXXFP8PagedMQALogitsRuntime;

    const auto arch_major = device_runtime->get_arch_major();
    const bool is_arch_10 = (arch_major == 10);

    // Thread layout and pipeline depth
    const int num_specialized_threads = 128;
    const int mma_m = is_arch_10 ? 128 : 64;
    const int num_math_warp_groups = split_kv / mma_m;
    const int num_math_threads = num_math_warp_groups * 128;
    const int num_q_stages = 3;
    const int num_kv_stages = is_arch_10 ? 4 : 3;
    DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0);

    // Construct TMAs
    DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128);
    const auto q_desc = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads,
                                         head_dim, next_n * num_heads, head_dim, head_dim);
    const auto kv_desc = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks,
                                          head_dim, block_kv, 1,
                                          head_dim, kv_cache_stride_bytes, head_dim);
    // TODO: use 1D TMA
    const auto kv_scales_desc = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks,
                                                 block_kv, 1,
                                                 kv_cache_stride_bytes / static_cast<int>(sizeof(float)), 0);
    const auto weights_desc = make_tma_2d_desc(weights, next_n * num_heads, batch_size,
                                               next_n * num_heads, 1, next_n * num_heads, 0);

    // Calculate shared memory size
    int smem_size = 0;
    if (arch_major == 9) {
        const int swizzle_alignment = head_dim * 8;

        // Q pipeline: per-stage Q tile + aligned weights, plus aligned barrier storage
        const int q_bytes = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
        const int weight_bytes = align(next_n * num_heads * static_cast<int>(weights.element_size()), swizzle_alignment);
        const int q_barrier_bytes = align(num_q_stages * 8 * 2, swizzle_alignment);
        const int q_pipe_bytes = num_q_stages * (q_bytes + weight_bytes) + q_barrier_bytes;

        // KV pipeline: per-stage KV tile + aligned scales, plus aligned barrier storage
        const int kv_bytes = block_kv * head_dim * static_cast<int>(kv_cache.element_size());
        const int kv_scale_bytes = align(block_kv * static_cast<int>(kv_cache_scales.element_size()), swizzle_alignment);
        const int kv_barrier_bytes = align(num_kv_stages * 8 * 2, swizzle_alignment);
        const int kv_pipe_bytes = num_kv_stages * (kv_bytes + kv_scale_bytes) + kv_barrier_bytes;

        // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90
        const int umma_barrier_bytes = num_math_warp_groups * 2 * 8;
        const int tmem_ptr_bytes = 4;

        smem_size = q_pipe_bytes + num_math_warp_groups * kv_pipe_bytes + umma_barrier_bytes + tmem_ptr_bytes;
        DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity);
    } else {
        // Per-stage payloads
        const int q_bytes = next_n * num_heads * head_dim * static_cast<int>(q.element_size());
        const int kv_bytes = split_kv * head_dim * static_cast<int>(kv_cache.element_size());
        const int kv_scale_bytes = split_kv * static_cast<int>(kv_cache_scales.element_size());
        const int weight_bytes = next_n * num_heads * static_cast<int>(weights.element_size());

        // Barriers and tensor memory pointer
        const int barrier_bytes = (num_q_stages + num_kv_stages) * 2 * 8;
        const int umma_barrier_bytes = num_math_warp_groups * 2 * 8;
        const int tmem_ptr_bytes = 4;

        smem_size = num_q_stages * (q_bytes + weight_bytes) +
                    num_kv_stages * (kv_bytes + kv_scale_bytes) +
                    barrier_bytes + umma_barrier_bytes + tmem_ptr_bytes;
        DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity);
    }

    // Launch (designated initializers follow `Args` declaration order)
    const Runtime::Args kernel_args = {
        .batch_size = batch_size,
        .next_n = next_n,
        .num_heads = num_heads,
        .head_dim = head_dim,
        .block_kv = block_kv,
        .is_context_lens_2d = is_context_lens_2d,
        .block_table_stride = block_table_stride,
        .logits_stride = logits_stride,
        .num_q_stages = num_q_stages,
        .num_kv_stages = num_kv_stages,
        .split_kv = split_kv,
        .context_lens = context_lens.data_ptr<int>(),
        .logits = logits.data_ptr<float>(),
        .block_table = block_table.data_ptr<int>(),
        .schedule_meta = schedule_meta.data_ptr<int>(),
        .tensor_map_q = q_desc,
        .tensor_map_kv = kv_desc,
        .tensor_map_kv_scales = kv_scales_desc,
        .tensor_map_weights = weights_desc,
        .num_specialized_threads = num_specialized_threads,
        .num_math_threads = num_math_threads,
        .launch_args = LaunchArgs(num_sms,
                                  num_specialized_threads + num_math_threads,
                                  smem_size)
    };
    const auto kernel_source = Runtime::generate(kernel_args);
    const auto built_kernel = compiler->build("smxx_fp8_paged_mqa_logits", kernel_source);
    Runtime::launch(built_kernel, kernel_args);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,267 @@
#pragma once
#include <torch/python.h>
#include "../../jit/kernel_runtime.hpp"
#include "../../jit/compiler.hpp"
#include "../../utils/exception.hpp"
#include "../../utils/format.hpp"
#include "../../utils/math.hpp"
#include "../../utils/layout.hpp"
namespace deep_gemm {
// JIT runtime wrapper for the `transpose_fp32` kernel from `smxx_layout.cuh`.
// `generate_impl` renders the source snippet that instantiates the kernel template,
// and `launch_impl` forwards the runtime arguments to `launch_kernel`.
// NOTE(review): the `generate`/`launch` entry points used by callers are presumably
// provided by the `LaunchRuntime` base — confirm in `kernel_runtime.hpp`.
class TransposeFP32Runtime final: public LaunchRuntime<TransposeFP32Runtime> {
public:
struct Args {
// `mn` is passed at launch (as `uint32_t`); `sf_k` is a template argument
int mn, sf_k;
// Template argument
int block_mn;
// Input and output buffers, passed to the kernel as untyped pointers
void *sf, *out;
// `launch_args.num_threads` is used as the first template argument
LaunchArgs launch_args;
};
// Builds the source text that forces instantiation of the kernel template with
// `<num_threads, block_mn, sf_k>`. `{{` / `}}` are fmt escapes for literal braces.
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&transpose_fp32<
{}, {}, {}
>);
}};
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
}
// Launches the compiled kernel with `(sf, out, mn)`; the result of `launch_kernel` is checked
// by `DG_CUDA_UNIFIED_CHECK`.
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast<uint32_t>(args.mn)));
}
};
// JIT runtime wrapper for the `transpose_and_pack_fp32_into_ue8m0` kernel.
// Same shape as `TransposeFP32Runtime`: generate the instantiation source, then launch with runtime pointers.
class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime<TransposeAndPackFP32IntoUE8M0Runtime> {
public:
// Arguments shared by code generation and launch
struct Args {
int mn, sf_k;
int block_mn;
void *sf, *out;
LaunchArgs launch_args;
};
// Emit code referencing `transpose_and_pack_fp32_into_ue8m0<num_threads, block_mn, sf_k>` (values substituted in that order)
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&transpose_and_pack_fp32_into_ue8m0<
{}, {}, {}
>);
}};
)", args.launch_args.num_threads, args.block_mn, args.sf_k);
}
// Launch with `(sf, out, mn)`; `mn` is passed as `uint32_t`
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast<uint32_t>(args.mn)));
}
};
// JIT runtime wrapper for the `pack_fp32_into_ue8m0` kernel.
// Unlike the transpose wrappers, the problem sizes (`mn`, `sf_k`, `packed_sf_k`, `gran_k`) are runtime
// kernel arguments; only the group count, thread count and block sizes are compiled in.
class PackFP32IntoUE8M0Runtime final: public LaunchRuntime<PackFP32IntoUE8M0Runtime> {
public:
// Arguments shared by code generation and launch; `ks` may be `nullptr` (see the single-group caller in this file)
struct Args {
int num_groups, mn, sf_k, packed_sf_k, gran_k;
int block_mn, block_packed_sf_k;
void *sf, *out, *ks;
LaunchArgs launch_args;
};
// Emit code referencing `pack_fp32_into_ue8m0<num_groups, num_threads, block_mn, block_packed_sf_k>`
static std::string generate_impl(const Args& args) {
return fmt::format(R"(
#include <deep_gemm/impls/smxx_layout.cuh>
using namespace deep_gemm;
static void __instantiate_kernel() {{
auto ptr = reinterpret_cast<void*>(&pack_fp32_into_ue8m0<
{}, {}, {}, {}
>);
}};
)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k);
}
// Launch with `(sf, out, ks, mn, sf_k, packed_sf_k, gran_k)`, all passed with their declared `int`/pointer types
static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) {
DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config,
args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k, args.gran_k));
}
};
// Validate a 2D/3D FP32 scaling-factor tensor and normalize it to a batched (3D) view.
// Returns `(original dim, num_groups, mn, sf_k, TMA-aligned mn, batched tensor)`.
static std::tuple<int, int, int, int, int, torch::Tensor> preprocess_sf(const torch::Tensor& sf) {
    // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
    const auto dim = sf.dim();
    DG_HOST_ASSERT(dim == 2 or dim == 3);
    DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat);

    // A 2D input is treated as a single group
    torch::Tensor batched_sf = sf;
    if (dim == 2)
        batched_sf = sf.unsqueeze(0);

    const auto shape = get_shape<3>(batched_sf);
    const int mn = std::get<1>(shape);
    const auto tma_aligned_mn = get_tma_aligned_size(mn, static_cast<int>(sf.element_size()));
    return {dim, std::get<0>(shape), mn, std::get<2>(shape), tma_aligned_mn, batched_sf};
}
// Return `sf` in an MN-major layout whose MN stride is TMA-aligned.
// - If the input strides already match that layout, the input is returned as-is (no copy).
// - A non-contiguous input is converted with `Tensor::copy_`.
// - A contiguous input is converted with the JIT-compiled `transpose_fp32` kernel.
// The result has the same number of dimensions as the input (2D in, 2D out).
static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) {
const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
// The last kernel already gives a column-major TMA aligned layout
// (for 2D inputs the group stride is ignored, as there is only one group)
if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn)
return (dim == 2) ? batched_sf.squeeze(0) : batched_sf;
// Logical shape stays `(num_groups, mn, sf_k)`, but each column is padded to `tma_aligned_mn` elements
const auto out = torch::empty_strided({num_groups, mn, sf_k},
{tma_aligned_mn * sf_k, 1, tma_aligned_mn},
batched_sf.options());
if (not batched_sf.is_contiguous()) {
// Fallback to PyTorch's slow copy if not contiguous
// ReSharper disable once CppExpressionWithoutSideEffects
out.copy_(batched_sf);
} else {
constexpr int block_mn = 64;
constexpr int num_threads = 512;
// Row width is rounded up to an odd number of floats (adds 1 when `sf_k` is even)
// NOTE(review): presumably to avoid shared-memory bank conflicts — confirm against the kernel
const auto smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast<int>(sizeof(float));
const TransposeFP32Runtime::Args& args = {
.mn = mn,
.sf_k = sf_k,
.block_mn = block_mn,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size)
};
const auto code = TransposeFP32Runtime::generate(args);
const auto runtime = compiler->build("transpose_fp32", code);
TransposeFP32Runtime::launch(runtime, args);
}
return (dim == 2) ? out.squeeze(0) : out;
}
// Pure-PyTorch fallback producing the packed UE8M0, MN-major, TMA-aligned scaling-factor tensor.
// Used by the kernel-based path whenever its shape/stride requirements are not met.
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) {
    const bool is_2d = sf.dim() == 2;
    const auto sf_reshaped = is_2d ? sf.unsqueeze(0) : sf;

    // First, convert into UE8M0 `uint8_t`
    const auto raw_bits = sf_reshaped.view(torch::kInt32);
    const auto ue8m0_tensor = raw_bits.bitwise_right_shift(23).to(torch::kUInt8);

    // Second, make padded packed tensors
    const auto [num_groups, mn, k] = get_shape<3>(sf_reshaped);
    const auto aligned_mn = get_tma_aligned_size(mn, 4);
    const auto aligned_k = align(k, 4);
    const auto options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8);
    auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options);
    auto valid_region = padded.slice(1, 0, mn).slice(2, 0, k);
    // ReSharper disable once CppExpressionWithoutSideEffects
    valid_region.copy_(ue8m0_tensor);
    // Reinterpret every 4 consecutive bytes along K as one `int32`
    padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4});

    // Finally, transpose
    auto out = torch::empty_strided({num_groups, aligned_mn, aligned_k / 4},
                                    {aligned_mn * (aligned_k / 4), 1, aligned_mn},
                                    at::TensorOptions().device(sf.device()).dtype(torch::kInt32));
    out.copy_(padded);
    out = out.slice(1, 0, mn);
    if (is_2d)
        return out.squeeze(0);
    return out;
}
// Convert an FP32 scaling-factor tensor into packed `int` storage (4 values along K per element,
// `packed_sf_k = ceil_div(sf_k, 4)`) in an MN-major, TMA-aligned layout.
// Dispatch:
// - contiguous input: `transpose_and_pack_fp32_into_ue8m0` kernel, except for multi-group inputs
//   whose per-group element count is not a multiple of 4 (PyTorch fallback);
// - non-contiguous input: `pack_fp32_into_ue8m0` kernel, only for a single group with `mn % 4 == 0`
//   and an exactly MN-major input; everything else uses the PyTorch fallback.
static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) {
const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf);
const auto packed_sf_k = ceil_div(sf_k, 4);
const auto out = torch::empty_strided({num_groups, mn, packed_sf_k},
{packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn},
at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt));
// Launch the kernel
if (batched_sf.is_contiguous()) {
// NOTE(review): presumably the kernel needs each group's base to stay 4-element aligned — confirm
if ((mn * sf_k) % 4 != 0 and num_groups > 1)
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
constexpr int block_mn = 48;
constexpr int num_threads = 512;
// Shared memory request: `block_mn * sf_k` 4-byte values
const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = {
.mn = mn,
.sf_k = sf_k,
.block_mn = block_mn,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4)
};
const auto code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args);
const auto runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code);
TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args);
} else {
if (mn % 4 != 0 or num_groups > 1)
return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf);
// The pack-only kernel requires the input to be already MN-major without padding
DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn);
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
// `gran_k` is not listed below, so it is value-initialized to 0; no `ks` array is passed either
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = 1,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = batched_sf.data_ptr(),
.out = out.data_ptr(),
.ks = nullptr,
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto code = PackFP32IntoUE8M0Runtime::generate(args);
const auto runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
}
return (dim == 2) ? out.squeeze(0) : out;
}
// K-grouped variant: `sf` is a contiguous `(sf_k, mn)` tensor holding the scaling factors of all groups
// concatenated along K, where group `i` contributes `ceil_div(ks[i], gran_k)` rows.
// Returns a contiguous `(packed_sf_k, mn)` `int` tensor, where each group contributes
// `ceil_div(ks[i], gran_k * 4)` packed rows.
// `ks_tensor` supplies the device-side pointer passed to the kernel, `ks` the host-side values
// NOTE(review): the two are assumed to hold the same numbers — this is not checked here.
static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf,
const torch::Tensor& ks_tensor,
const std::vector<int>& ks,
const int gran_k) {
DG_HOST_ASSERT(gran_k == 32 or gran_k == 128);
const auto [sf_k, mn] = get_shape<2>(sf);
const auto num_groups = static_cast<int>(ks.size());
// Accumulate the expected unpacked and packed row counts over all groups
int ref_sf_k = 0, packed_sf_k = 0;
for (const auto k: ks)
ref_sf_k += ceil_div(k, gran_k), packed_sf_k += ceil_div(k, gran_k * 4);
DG_HOST_ASSERT(sf.is_contiguous());
DG_HOST_ASSERT(ref_sf_k == sf_k);
DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0);
const auto out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt));
constexpr int block_mn = 128;
constexpr int block_packed_sf_k = 16;
constexpr int num_threads = 512;
// `num_groups` becomes a template argument of the generated kernel (see `PackFP32IntoUE8M0Runtime::generate_impl`)
const PackFP32IntoUE8M0Runtime::Args& args = {
.num_groups = num_groups,
.mn = mn,
.sf_k = sf_k,
.packed_sf_k = packed_sf_k,
.gran_k = gran_k,
.block_mn = block_mn,
.block_packed_sf_k = block_packed_sf_k,
.sf = sf.data_ptr(),
.out = out.data_ptr(),
.ks = ks_tensor.data_ptr(),
.launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads)
};
const auto code = PackFP32IntoUE8M0Runtime::generate(args);
const auto runtime = compiler->build("pack_fp32_into_ue8m0", code);
PackFP32IntoUE8M0Runtime::launch(runtime, args);
return out;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,28 @@
#include <pybind11/pybind11.h>
#include <torch/python.h>
#include "apis/attention.hpp"
#include "apis/einsum.hpp"
#include "apis/hyperconnection.hpp"
#include "apis/gemm.hpp"
#include "apis/layout.hpp"
#include "apis/mega.hpp"
#include "apis/runtime.hpp"
#ifndef TORCH_EXTENSION_NAME
#define TORCH_EXTENSION_NAME _C
#endif
// Python extension entry point: creates the module named by `TORCH_EXTENSION_NAME`
// (defaults to `_C`, see the `#ifndef` above) and lets every API group register its bindings on it.
// ReSharper disable once CppParameterMayBeConstPtrOrRef
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.doc() = "DeepGEMM C++ library";
// TODO: make SM80 incompatible issues raise errors
// One `register_apis` call per header under `apis/`
deep_gemm::attention::register_apis(m);
deep_gemm::einsum::register_apis(m);
deep_gemm::hyperconnection::register_apis(m);
deep_gemm::gemm::register_apis(m);
deep_gemm::layout::register_apis(m);
deep_gemm::mega::register_apis(m);
deep_gemm::runtime::register_apis(m);
}

View File

@@ -0,0 +1,17 @@
#pragma once
#include <torch/version.h>
#include <cuda.h>
#include <cuda_runtime.h>
// `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1
#define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1))
// `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1
#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010)
// `cublasGetStatusString` is supported since CUDA Runtime API 11.4.2 (a fallback is defined in `exception.hpp` otherwise)
#define DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE (CUDART_VERSION >= 11042)
// `CUBLASLT_MATMUL_DESC_FAST_ACCUM` and `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` are supported since CUDA Runtime API 11.8
#define DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE (CUDART_VERSION >= 11080)

View File

@@ -0,0 +1,109 @@
#pragma once
#include <cublasLt.h>
#include <exception>
#include <string>
#include <sstream>
#include "compatibility.hpp"
namespace deep_gemm {
// Exception type thrown by all `DG_*` host-side check macros.
// The message has the form "<name> error (<file>:<line>): <error>".
class DGException final : public std::exception {
    std::string message = {};

public:
    explicit DGException(const char *name, const char* file, const int line, const std::string& error) {
        message.append(name).append(" error (")
               .append(file).append(":").append(std::to_string(line))
               .append("): ").append(error);
    }

    // The returned pointer stays valid for the lifetime of the exception object
    const char *what() const noexcept override { return message.c_str(); }
};
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
#ifndef DG_HOST_ASSERT
#define DG_HOST_ASSERT(cond) \
do { \
if (not (cond)) { \
throw DGException("Assertion", __FILE__, __LINE__, #cond); \
} \
} while (0)
#endif
#ifndef DG_HOST_UNREACHABLE
#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason))
#endif
#ifndef DG_NVRTC_CHECK
#define DG_NVRTC_CHECK(cmd) \
do { \
const auto e = (cmd); \
if (e != NVRTC_SUCCESS) { \
throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \
} \
} while (0)
#endif
#ifndef DG_CUDA_DRIVER_CHECK
#define DG_CUDA_DRIVER_CHECK(cmd) \
do { \
const auto e = (cmd); \
if (e != CUDA_SUCCESS) { \
std::stringstream ss; \
const char *name, *info; \
lazy_cuGetErrorName(e, &name), lazy_cuGetErrorString(e, &info); \
ss << static_cast<int>(e) << " (" << name << ", " << info << ")"; \
throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \
} \
} while (0)
#endif
#ifndef DG_CUDA_RUNTIME_CHECK
#define DG_CUDA_RUNTIME_CHECK(cmd) \
do { \
const auto e = (cmd); \
if (e != cudaSuccess) { \
std::stringstream ss; \
ss << static_cast<int>(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \
throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \
} \
} while (0)
#endif
#ifndef DG_CUBLASLT_CHECK
#if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE
// Fallback for toolkits that do not ship `cublasGetStatusString`:
// maps a status code to its enumerator name, or to a generic string for unknown codes.
inline const char* cublasGetStatusString(cublasStatus_t status) {
    struct StatusName {
        cublasStatus_t code;
        const char* name;
    };
    static constexpr StatusName kStatusNames[] = {
        {CUBLAS_STATUS_SUCCESS,          "CUBLAS_STATUS_SUCCESS"},
        {CUBLAS_STATUS_NOT_INITIALIZED,  "CUBLAS_STATUS_NOT_INITIALIZED"},
        {CUBLAS_STATUS_ALLOC_FAILED,     "CUBLAS_STATUS_ALLOC_FAILED"},
        {CUBLAS_STATUS_INVALID_VALUE,    "CUBLAS_STATUS_INVALID_VALUE"},
        {CUBLAS_STATUS_ARCH_MISMATCH,    "CUBLAS_STATUS_ARCH_MISMATCH"},
        {CUBLAS_STATUS_MAPPING_ERROR,    "CUBLAS_STATUS_MAPPING_ERROR"},
        {CUBLAS_STATUS_EXECUTION_FAILED, "CUBLAS_STATUS_EXECUTION_FAILED"},
        {CUBLAS_STATUS_INTERNAL_ERROR,   "CUBLAS_STATUS_INTERNAL_ERROR"},
        {CUBLAS_STATUS_NOT_SUPPORTED,    "CUBLAS_STATUS_NOT_SUPPORTED"},
        {CUBLAS_STATUS_LICENSE_ERROR,    "CUBLAS_STATUS_LICENSE_ERROR"},
    };
    for (const auto& entry: kStatusNames) {
        if (entry.code == status)
            return entry.name;
    }
    return "Unknown cuBLAS error";
}
#endif
#define DG_CUBLASLT_CHECK(cmd) \
do { \
const auto e = (cmd); \
if (e != CUBLAS_STATUS_SUCCESS) { \
std::ostringstream ss; \
ss << static_cast<int>(e) << " (" << cublasGetStatusString(e) << ")"; \
throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \
} \
} while (0)
#endif
} // namespace deep_gemm

View File

@@ -0,0 +1,6 @@
#pragma once
// Just a wrapper for the `fmt` headers
#define FMT_HEADER_ONLY
#include <fmt/base.h>
#include <fmt/format.h>

View File

@@ -0,0 +1,39 @@
#pragma once
#include <string>
namespace deep_gemm {
// FNV-1a style 64-bit hash over `data`, starting from the caller-provided `seed`:
// every byte is XOR-ed into the state, which is then multiplied by the 64-bit FNV prime.
static uint64_t fnv1a(const std::vector<char>& data, const uint64_t& seed) {
    constexpr uint64_t kPrime = 0x100000001b3ull;
    uint64_t hash = seed;
    for (size_t i = 0; i < data.size(); ++ i) {
        // Go through `uint8_t` so that negative `char` values are not sign-extended
        hash = (hash ^ static_cast<uint8_t>(data[i])) * kPrime;
    }
    return hash;
}
// 128-bit digest of `data` as 32 lowercase hex characters:
// two differently-seeded `fnv1a` passes, each finalized with a split-mix 64 mixer.
static std::string get_hex_digest(const std::vector<char>& data) {
    // Split-mix 64
    const auto split_mix = [](uint64_t z) {
        z ^= z >> 30;
        z *= 0xbf58476d1ce4e5b9ull;
        z ^= z >> 27;
        z *= 0x94d049bb133111ebull;
        return z ^ (z >> 31);
    };
    const uint64_t digest_words[] = {
        split_mix(fnv1a(data, 0xc6a4a7935bd1e995ull)),
        split_mix(fnv1a(data, 0x9e3779b97f4a7c15ull)),
    };

    std::ostringstream oss;
    oss << std::hex << std::setfill('0');
    for (const uint64_t word: digest_words)
        oss << std::setw(16) << word;
    return oss.str();
}
static std::string get_hex_digest(const std::string& data) {
return get_hex_digest(std::vector<char>{data.begin(), data.end()});
}
} // namespace deep_gemm

View File

@@ -0,0 +1,119 @@
#pragma once
#include <cute/arch/mma_sm100_umma.hpp>
#include <torch/python.h>
#include "math.hpp"
#include "exception.hpp"
#include "../jit/device_runtime.hpp"
namespace deep_gemm {
// Major-ness stuffs
// Validate that `t` has a layout the library can classify:
// 2D or 3D, densely packed along the group dimension for 3D tensors, and with unit stride
// in at least one of the last two dimensions. Throws `DGException` otherwise.
static void major_check(const torch::Tensor& t) {
const auto dim = t.dim();
DG_HOST_ASSERT(dim == 2 or dim == 3);
// For grouped tensors, consecutive groups must be exactly one matrix apart
if (dim == 3)
DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1));
DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1);
}
// Classify an A/B operand: K-major if the last dimension has unit stride, MN-major otherwise
static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) {
    major_check(t);
    if (t.stride(-1) == 1)
        return cute::UMMA::Major::K;
    return cute::UMMA::Major::MN;
}
// Validate a C/D (accumulator/output) tensor: passes `major_check` and has unit stride in the last dimension.
// Throws `DGException` otherwise.
static void check_major_type_cd(const torch::Tensor& t) {
// NOTES: the library only supports row-major output layouts
major_check(t);
DG_HOST_ASSERT(t.stride(-1) == 1);
}
// Whether FP8 operands must be K-major on the current device (true only for architecture major 9)
static bool fp8_requires_k_major() {
    const auto arch_major = device_runtime->get_arch_major();
    return arch_major == 9;
}
// Tensor utils
// Return the sizes of an `N`-dimensional tensor as a `std::tuple` of `N` `int`s
// (suitable for structured bindings); throws `DGException` if the rank is not `N`.
template <int N>
static auto get_shape(const torch::Tensor& t) {
    DG_HOST_ASSERT(t.dim() == N);
    const auto to_int_tuple = [&t] <size_t... Is> (std::index_sequence<Is...>) {
        // Sizes are narrowed from 64-bit to `int`
        return std::make_tuple(static_cast<int>(t.sizes()[Is])...);
    };
    return to_int_tuple(std::make_index_sequence<N>());
}
// Return the logical `(mn, k)` of a 2D FP8 or packed-FP4 operand.
// Packed FP4 (accepted only for `arch_major == 10`) doubles the size of the dimension that is
// contiguous according to `major`: K for K-major, MN otherwise.
static std::tuple<int, int> check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
    auto [mn, k] = get_shape<2>(ab);
    const bool is_fp8 = ab.scalar_type() == torch::kFloat8_e4m3fn;
    if (not is_fp8) {
        DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
        if (major == cute::UMMA::Major::K) {
            k *= 2;
        } else {
            mn *= 2;
        }
    }
    return std::tuple<int, int>(mn, k);
}
// Grouped (3D) counterpart of `check_ab_fp8_fp4`: returns the logical `(num_groups, mn, k)`,
// doubling the contiguous dimension for packed-FP4 inputs (accepted only for `arch_major == 10`).
static std::tuple<int, int, int> check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) {
    auto [num_groups, mn, k] = get_shape<3>(ab);
    const bool is_fp8 = ab.scalar_type() == torch::kFloat8_e4m3fn;
    if (not is_fp8) {
        DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10);
        if (major == cute::UMMA::Major::K) {
            k *= 2;
        } else {
            mn *= 2;
        }
    }
    return std::tuple<int, int, int>(num_groups, mn, k);
}
// Recipe
// Default scaling-factor recipe (a triple of granularities) for the current device:
// - architecture major 9: FP32 scaling factors only, `(1, 128, 128)`;
// - architecture major 10: `(1, 128, 128)` for FP32 SFB, `(1, 1, 128)` for `int` SFB.
// Throws `DGException` for any other architecture.
static std::tuple<int, int, int>
get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) {
    const auto arch_major = device_runtime->get_arch_major();
    if (arch_major == 9) {
        DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat);
        return std::make_tuple(1, 128, 128);
    }
    if (arch_major == 10) {
        DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt);
        // Legacy format
        if (sfb_dtype == torch::kFloat)
            return std::make_tuple(1, 128, 128);
        // 1D1D kernels
        return std::make_tuple(1, 1, 128);
    }
    DG_HOST_UNREACHABLE("Unknown recipe");
}
// SF layouts
// Validate the dtype, shape and (optionally) strides of a scaling-factor tensor and return it unchanged.
// - `mn`, `k`: logical sizes of the scaled operand; `gran_mn`, `gran_k`: scaling granularities.
// - `num_groups`: when set, `sf` must have a leading group dimension of that size.
// - `tma_stride_check`: additionally require an MN-major, TMA-aligned layout.
// - `sm90_sfb_check`: additionally require a dense layout, either as-is or after swapping the last two dimensions.
// - `type_check`: when set, require exactly this dtype.
// Throws `DGException` on the first violated condition.
static torch::Tensor check_sf_layout(const torch::Tensor& sf,
const int& mn, const int& k,
const int& gran_mn, const int& gran_k,
const std::optional<int>& num_groups,
const bool& tma_stride_check = false,
const bool& sm90_sfb_check = false,
const std::optional<torch::ScalarType>& type_check = std::nullopt) {
// Type check
if (type_check.has_value())
DG_HOST_ASSERT(sf.scalar_type() == type_check.value());
// Always do shape checks
const auto sf_dtype = sf.scalar_type();
DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt);
DG_HOST_ASSERT(sf.dim() == static_cast<int>(num_groups.has_value()) + 2);
if (num_groups.has_value())
DG_HOST_ASSERT(sf.size(-3) == num_groups.value());
DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn));
// `int` tensors cover 4x as much K per element as FP32 ones
DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4)));
// TMA stride checks: TMA aligned and MN-major
if (tma_stride_check) {
if (num_groups.has_value())
DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1));
// Check contiguity in the MN direction
DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1);
DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size()));
}
// SM90 SFB must be contiguous, or contiguous after transposing the last two dimensions
if (sm90_sfb_check) {
if (num_groups.has_value())
DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1));
DG_HOST_ASSERT((sf.stride(-1) == 1 and sf.stride(-2) == sf.size(-1)) or
(sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1));
}
return sf;
}
} // namespace deep_gemm

View File

@@ -0,0 +1,27 @@
#pragma once
#include <functional>
#include <memory>
#define DG_DECLARE_STATIC_VAR_IN_CLASS(cls, name) decltype(cls::name) cls::name
namespace deep_gemm {
// Holder that constructs its object on first use.
// The factory is stored at construction and invoked the first time `operator->` is called
// (and again on later calls only if it returned a null pointer).
// No synchronization is performed here.
template <typename T>
class LazyInit {
public:
    explicit LazyInit(std::function<std::shared_ptr<T>()> factory)
        : factory(std::move(factory)) {}

    // Create the object if needed, then expose it
    T* operator -> () {
        if (not ptr)
            ptr = factory();
        return ptr.get();
    }

private:
    std::shared_ptr<T> ptr;
    std::function<std::shared_ptr<T>()> factory;
};
} // namespace deep_gemm

View File

@@ -0,0 +1,29 @@
// TODO: merge this file with `math.cuh` (the device part)
#pragma once
#include <torch/python.h>
#include "exception.hpp"
namespace deep_gemm {
// TODO: use `torch::kFloat4_e2m1fn_x2`
constexpr auto kPackedFP4 = torch::kInt8;
// Integer division rounding up: the smallest `q` with `q * b >= a` (for non-negative `a` and positive `b`).
// Declared `constexpr` so that `align`, which is `constexpr` and calls this function,
// can actually be evaluated in constant expressions.
template <typename T>
static constexpr T ceil_div(const T& a, const T& b) {
    return (a + b - 1) / b;
}
// Round `a` up to the nearest multiple of `b`
template <typename T>
static constexpr T align(const T& a, const T& b) {
    const T num_chunks = ceil_div(a, b);
    return num_chunks * b;
}
// Round an element count `x` up so that `x * element_size` is a multiple of 16 bytes.
// `element_size` must divide 16, otherwise `DGException` is thrown.
static int get_tma_aligned_size(const int& x, const int& element_size) {
    constexpr int kNumTMAAlignmentBytes = 16;
    DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0);
    const int num_elems_per_alignment = kNumTMAAlignmentBytes / element_size;
    return align(x, num_elems_per_alignment);
}
} // namespace deep_gemm

View File

@@ -0,0 +1,128 @@
#pragma once
#include <array>
#include <filesystem>
#include <functional>
#include <random>
#include <string>
#include <memory>
#include <unistd.h>
#include "exception.hpp"
#include "format.hpp"
namespace deep_gemm {
// Read environment variable `name` as a `std::string` or an `int`.
// Returns `default_value` when the variable is unset, or (for `int`) when its value does not
// start with a parsable integer. Previously a malformed value returned an uninitialized `int`.
// Any other `dtype_t` throws `DGException` once the variable is set.
// ReSharper disable once CppNotAllPathsReturnValue
template <typename dtype_t>
static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) {
    const auto c_str = std::getenv(name.c_str());
    if (c_str == nullptr)
        return default_value;

    // Read the env and convert to the desired type
    if constexpr (std::is_same_v<dtype_t, std::string>) {
        return std::string(c_str);
    } else if constexpr (std::is_same_v<dtype_t, int>) {
        int value = 0;
        // `sscanf` returns the number of converted fields: anything but 1 means "not an integer"
        if (std::sscanf(c_str, "%d", &value) != 1)
            return default_value;
        return value;
    } else {
        DG_HOST_ASSERT(false and "Unexpected type");
    }
}
// Run `command` through `popen` with stderr merged into stdout.
// Returns `(exit code, captured output)`; a child terminated by a signal reports `128 + signal number`.
static std::tuple<int, std::string> call_external_command(std::string command) {
    command.append(" 2>&1");

    // The pipe is closed automatically if anything below throws
    std::unique_ptr<FILE, int (*)(FILE*)> pipe(popen(command.c_str(), "r"), pclose);
    DG_HOST_ASSERT(pipe != nullptr);

    // Drain the child's output chunk by chunk
    std::string output;
    std::array<char, 512> buffer;
    for (;;) {
        if (fgets(buffer.data(), buffer.size(), pipe.get()) == nullptr)
            break;
        output.append(buffer.data());
    }

    // Take ownership back to obtain the wait status from `pclose`
    const auto status = pclose(pipe.release());

    // NOTES: if the child was killed by a signal (e.g., SIGINT from Ctrl+C),
    // WEXITSTATUS would incorrectly return 0. Treat signal death as failure.
    int exit_code;
    if (WIFEXITED(status)) {
        exit_code = WEXITSTATUS(status);
    } else {
        exit_code = 128 + WTERMSIG(status);
    }
    return {exit_code, output};
}
// Collect every regular `.cuh` file below `root` (recursing into sub-directories),
// returned in sorted order so that the result does not depend on directory iteration order.
static std::vector<std::filesystem::path> collect_files(const std::filesystem::path& root) {
    std::vector<std::filesystem::path> files;

    // Explicit work list of directories still to be scanned
    std::vector<std::filesystem::path> pending_dirs = {root};
    while (not pending_dirs.empty()) {
        const auto dir = pending_dirs.back();
        pending_dirs.pop_back();
        for (const auto& entry: std::filesystem::directory_iterator(dir)) {
            if (entry.is_directory()) {
                pending_dirs.push_back(entry.path());
                continue;
            }
            if (entry.is_regular_file() and entry.path().extension() == ".cuh")
                files.emplace_back(entry.path());
        }
    }

    // Be consistent
    std::sort(files.begin(), files.end());
    return files;
}
// Create `path` including missing parents and return it.
// An already existing directory is fine; any reported error throws `DGException`.
// With `DG_JIT_DEBUG` set to a non-zero integer, newly created directories are printed.
static std::filesystem::path make_dirs(const std::filesystem::path& path) {
    // OK if existed
    std::error_code capture;
    const bool created = std::filesystem::create_directories(path, capture);
    const bool failed = not created and capture.value() != 0;
    if (failed) {
        DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}",
                                        path.c_str(), created, capture.value()));
    }
    if (created) {
        if (get_env<int>("DG_JIT_DEBUG"))
            printf("Create directory: %s\n", path.c_str());
    }
    return path;
}
// Build a process-unique identifier of the form "<pid>-<8 hex>-<8 hex>-<8 hex>".
// The generator is seeded once per process from `std::random_device` mixed with the steady clock.
static std::string get_uuid() {
    static std::random_device rd;
    static std::mt19937 gen([]() {
        return rd() ^ std::chrono::steady_clock::now().time_since_epoch().count();
    }());
    static std::uniform_int_distribution<uint32_t> dist;

    std::stringstream ss;
    // The PID is written in decimal, before switching the stream to hex
    ss << getpid();
    ss << std::hex << std::setfill('0');
    for (int i = 0; i < 3; ++ i)
        ss << "-" << std::setw(8) << dist(gen);
    return ss.str();
}
// Best-effort recursive removal of `path` that never throws: all filesystem errors are swallowed.
// Symbolic links are removed themselves and never followed, so a link pointing to a directory
// outside of `path` cannot cause that directory's contents to be deleted; dangling links are
// removed as well (both cases were mishandled when the checks followed symlinks).
static void safe_remove_all(const std::filesystem::path& path) {
    std::error_code ec;

    // Inspect the entry itself, not what it may point to
    const auto status = std::filesystem::symlink_status(path, ec);
    if (ec or not std::filesystem::exists(status))
        return;

    // A single file (or a symlink, socket, etc.)
    if (not std::filesystem::is_directory(status)) {
        std::filesystem::remove(path, ec);
        return;
    }

    // Remove directory
    auto it = std::filesystem::directory_iterator(path,
        std::filesystem::directory_options::skip_permission_denied, ec);
    for (auto end = std::filesystem::directory_iterator(); it != end and not ec;) {
        const auto entry_path = it->path();

        // Increase firstly to avoid failures
        it.increment(ec);
        if (ec)
            break;

        // Recursively clean
        safe_remove_all(entry_path);
    }
    // Fails silently (via `ec`) if some entries could not be removed
    std::filesystem::remove(path, ec);
}
} // deep_gemm

View File

@@ -0,0 +1,126 @@
import os
import subprocess
import torch
# Apply the default environment variables recorded at setup time.
# Variables already present in the process environment always win.
try:
    # noinspection PyUnresolvedReferences
    from .envs import persistent_envs
except ImportError:
    # No `envs` module was generated: nothing to apply
    pass
else:
    for key, value in persistent_envs.items():
        os.environ.setdefault(key, value)
# Configs
from . import _C
from ._C import (
set_num_sms,
get_num_sms,
set_tc_util,
get_tc_util,
set_ignore_compile_dims,
set_block_size_multiple_of,
set_pdl,
get_pdl,
)
# cuBLASLt Kernels
from ._C import (
cublaslt_gemm_nt, cublaslt_gemm_nn,
cublaslt_gemm_tn, cublaslt_gemm_tt,
)
try:
# DeepGEMM Kernels
from ._C import (
# FP8 FP4 GEMMs
fp8_fp4_gemm_nt, fp8_fp4_gemm_nn,
fp8_fp4_gemm_tn, fp8_fp4_gemm_tt,
m_grouped_fp8_fp4_gemm_nt_contiguous,
m_grouped_fp8_fp4_gemm_nn_contiguous,
m_grouped_fp8_fp4_gemm_nt_masked,
# FP8 GEMMs
fp8_gemm_nt, fp8_gemm_nn,
fp8_gemm_tn, fp8_gemm_tt,
fp8_gemm_nt_skip_head_mid,
m_grouped_fp8_gemm_nt_contiguous,
m_grouped_fp8_gemm_nn_contiguous,
m_grouped_fp8_gemm_nt_masked,
k_grouped_fp8_gemm_nt_contiguous,
k_grouped_fp8_gemm_tn_contiguous,
# BF16 GEMMs
bf16_gemm_nt, bf16_gemm_nn,
bf16_gemm_tn, bf16_gemm_tt,
m_grouped_bf16_gemm_nt_contiguous,
m_grouped_bf16_gemm_nn_contiguous,
m_grouped_bf16_gemm_nt_masked,
k_grouped_bf16_gemm_tn_contiguous,
# Einsum kernels
einsum,
fp8_einsum,
# Attention kernels
fp8_fp4_mqa_logits,
get_paged_mqa_logits_metadata,
fp8_fp4_paged_mqa_logits,
# Attention kernels (legacy)
fp8_mqa_logits,
fp8_paged_mqa_logits,
# Hyperconnection kernels
tf32_hc_prenorm_gemm,
# Layout kernels
transform_sf_into_required_layout,
)
# Some alias for legacy supports
# TODO: remove these later
fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked
bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked
except ImportError:
# Expected behavior for CUDA runtime version before 12.1
pass
# Mega kernels
from .mega import (
SymmBuffer,
get_symm_buffer_for_mega_moe,
transform_weights_for_mega_moe,
fp8_fp4_mega_moe,
)
# Some utils
from . import testing
from . import utils
from .utils import *
# Legacy Triton kernels for A100
try:
from . import legacy
except Exception as e:
print(f'Failed to load legacy DeepGEMM A100 Triton kernels: {e}')
# Initialize CPP modules
def _find_cuda_home() -> str:
    """Locate the CUDA toolkit root directory without initializing CUDA.

    Lookup order:
      1. the `CUDA_HOME` environment variable, then `CUDA_PATH` (empty values are ignored);
      2. two directory levels above the `nvcc` executable found on `PATH`;
      3. `/usr/local/cuda`, if it exists.

    Returns:
        The CUDA home directory path.

    Raises:
        AssertionError: if none of the lookups succeeds.
    """
    # Local import: keeps this function self-contained without touching the module-level imports
    import shutil

    # TODO: reuse PyTorch API later
    # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks
    cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
    if not cuda_home:
        # `shutil.which` searches `PATH` in-process, no `which` subprocess is needed
        nvcc = shutil.which('nvcc')
        if nvcc is not None:
            cuda_home = os.path.dirname(os.path.dirname(nvcc))
        elif os.path.exists('/usr/local/cuda'):
            cuda_home = '/usr/local/cuda'
        else:
            cuda_home = None
    assert cuda_home is not None, \
        'Cannot locate the CUDA toolkit: set `CUDA_HOME` or `CUDA_PATH`, or add `nvcc` to `PATH`'
    return cuda_home
# Runs once when the package is imported: hands the package directory and the CUDA home to the C++ extension
_C.init(
os.path.dirname(os.path.abspath(__file__)), # Library root directory path
_find_cuda_home() # CUDA home
)
__version__ = '2.5.0'

View File

@@ -0,0 +1,83 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/layout/sym_buffer.cuh>
#include <deep_gemm/layout/mega_moe.cuh>
namespace deep_gemm::comm {
// Cluster-wide barrier: a relaxed arrive immediately followed by a wait
CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() {
// Perform cluster_sync with `barrier.cluster.arrive.relaxed`
// This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee
cute::cluster_arrive_relaxed();
cute::cluster_wait();
}
// Barrier across `kNumSMs` participants, using the counter selected by `kGridSyncIndex` in `workspace`.
// - `sync_scope` is a callable invoked before and after the counter handshake
//   NOTE(review): presumably a block-level (or narrower) sync so that all threads wait for thread 0 — confirm at call sites.
// - Only `thread_idx == 0` touches the counter.
// Protocol: participant 0 adds `kFinishSumTag - (kNumSMs - 1)` and every other participant adds 1, so one
// full round adds exactly `kFinishSumTag` and flips bit 31 of the counter. Each participant then spins
// until bit 31 differs from the value it observed right before its own add.
template <uint32_t kNumSMs, uint32_t kGridSyncIndex = 0, typename sync_scope_t>
CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace,
const uint32_t& sm_idx, const uint32_t& thread_idx,
const sync_scope_t& sync_scope) {
// NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()`
static constexpr uint32_t kFinishSumTag = 0x80000000u;
sync_scope();
if (thread_idx == 0) {
const auto count_ptr = workspace.get_grid_sync_count_ptr<kGridSyncIndex>();
// Release-add publishes this participant's arrival; `old_value` is the counter before the add
const auto old_value = ptx::atomic_add_rel(
count_ptr, sm_idx == 0 ? (kFinishSumTag - (kNumSMs - 1)) : 1);
uint32_t new_value;
// Spin with acquire loads until the tag bit has flipped relative to `old_value`
do {
new_value = ptx::ld_acq(count_ptr);
} while (((new_value ^ old_value) & kFinishSumTag) == 0);
}
sync_scope();
}
// Cross-rank barrier over `kNumRanks` ranks, optionally wrapped in local grid syncs.
// - `sync_prologue` / `sync_epilogue`: run `grid_sync` before / after the cross-rank handshake.
// - Only `sm_idx == 0` performs the handshake; it needs at least `kNumRanks` threads.
// - `kTag` is only used to identify the call site in the timeout message.
// State: the low 2 bits of the local counter select the signal slot (bit 0, "phase") and the direction
// of counting (bit 1, "sign"). With sign 0 every rank adds +1 to each rank's slot and waits for its own
// slot to reach `kNumRanks`; with sign 1 every rank adds -1 and waits for 0.
template <uint32_t kNumRanks, uint32_t kNumSMs, uint32_t kNumThreads, uint32_t kGridSyncIndex, uint32_t kTag, typename sync_scope_t>
CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace,
const layout::SymBuffer<kNumRanks>& sym_buffer,
const uint32_t& sm_idx, const uint32_t& thread_idx,
const sync_scope_t& sync_scope,
const bool& sync_prologue = true,
const bool& sync_epilogue = true) {
DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads");
// Grid sync before NVLink signaling
if (sync_prologue)
grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);
// NVLink cross-rank barrier, only SM 0 participates
if (sm_idx == 0) {
auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr();
const auto status = (*counter_ptr) & 3;
const auto signal_phase = status & 1, signal_sign = status >> 1;
auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase);
// Send signals to remote ranks
// (thread `i` signals rank `i` through the address returned by `sym_buffer.map`)
if (thread_idx < kNumRanks)
ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1);
// All threads must have read the counter and sent their signal before thread 0 bumps the counter
sync_scope();
// Update status and wait arrival (with 30s timeout, at 2 GHz)
constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll;
if (thread_idx == 0) {
ptx::red_add(counter_ptr, 1);
const int target = signal_sign ? 0 : static_cast<int>(kNumRanks);
const auto start_clock = clock64();
while (ptx::ld_acq_sys(signal_ptr) != target) {
// On timeout: print diagnostics, then trap through `DG_DEVICE_ASSERT`
if (clock64() - start_clock >= kNumTimeoutCycles) {
printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n",
sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag);
DG_DEVICE_ASSERT(false and "NVLink barrier timeout");
}
}
}
}
// Grid sync after NVLink completion
if (sync_epilogue)
grid_sync<kNumSMs, kGridSyncIndex>(workspace, sm_idx, thread_idx, sync_scope);
}
} // namespace deep_gemm::comm

View File

@@ -0,0 +1,18 @@
#pragma once
#include <cutlass/detail/helper_macros.hpp>
#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__)
#define DG_IN_CUDA_COMPILATION
#endif
#if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__))
#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__
#define CUTLASS_DEVICE_NOINLINE __device__
#elif defined(__CUDACC_RTC__)
#define CUTLASS_HOST_DEVICE_NOINLINE __device__
#define CUTLASS_DEVICE_NOINLINE __device__
#else
#define CUTLASS_HOST_DEVICE_NOINLINE
#define CUTLASS_DEVICE_NOINLINE
#endif

View File

@@ -0,0 +1,50 @@
#pragma once
#include <cute/int_tuple.hpp>
namespace cute {

// Assignment sink: anything can be assigned to `cute::ignore` and the value is discarded
struct ignore_t {
    template <typename T>
    constexpr auto operator=(T&&) const noexcept -> const ignore_t& {
        return *this;
    }
};

inline constexpr ignore_t ignore{};

} // namespace cute
// Token pasting with one level of indirection, so that macro arguments are expanded before `##`
#define CUTE_TIE_CONCAT_IMPL(A, B) A##B
#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B)
// Count the variadic arguments (1 to 10) by selecting the 11th element of the shifted list
#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N
#define CUTE_TIE_COUNT_ARGS(...) \
CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)
// Per-element operations: declare a new variable from, or assign an existing one to, element `I` of `TUPLE`
#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get<I>(TUPLE)
#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get<I>(TUPLE)
// Apply `OP` to each variable in order
// NOTES: only 1 to 5 variables are implemented; more arguments expand to an undefined `CUTE_TIE_APPLY_OP_<N>`
#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1);
#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2);
#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3);
#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4);
#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5);
// `CUTE_TIE_DECL(tuple_expr, a, b, ...)`: evaluate `tuple_expr` once and declare `a`, `b`, ... (as `auto` copies)
// from its elements. Expands to several statements in the current scope (the temporary's name embeds `__LINE__`),
// so it must not be used as the sole body of an unbraced `if`/`for`.
#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \
auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
CUTE_TIE_OP_DECL, \
CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
__VA_ARGS__ \
)
// `CUTE_TIE(tuple_expr, a, b, ...)`: evaluate `tuple_expr` once and assign its elements to the existing
// `a`, `b`, ... (use `cute::ignore` to skip an element). Wrapped in `do { } while (0)`, so it acts as one statement.
#define CUTE_TIE(TUPLE_EXPR, ...) \
do { \
auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \
CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \
CUTE_TIE_OP_ASSIGN, \
CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \
__VA_ARGS__ \
); \
} while (0)

View File

@@ -0,0 +1,27 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm {
// Epilogue index transform that leaves the N index untouched
struct EpilogueIdentity {
    template <uint32_t STORE_BLOCK_N>
    __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
        return n_idx;
    }
};

// Epilogue index transform that inserts a gap of `kMid` columns into the N index:
// the result is `n_idx` plus `kMid` for every whole `(kLeft + kRight)` contained in `n_idx + kRight`.
// All three sizes must be multiples of the store block size `STORE_BLOCK_N`.
// NOTE(review): presumably output heads are laid out as `[left | mid | right]` and the GEMM only
// produces the left/right parts — confirm against the `*_skip_head_mid` callers.
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
struct EpilogueHeadSplits: EpilogueIdentity {
    template <uint32_t STORE_BLOCK_N>
    __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) {
        DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0
                         and kRight % STORE_BLOCK_N == 0, "Invalid head splits config");
        return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid;
    }
};

// NOTES: a stray `#pragma clang diagnostic pop` used to follow here; this file contains no matching
// `push`, so it only triggered a "no matching push" diagnostic and has been removed
} // namespace deep_gemm

View File

@@ -0,0 +1,43 @@
#pragma once
#include <cuda/std/cstdint>
#include <deep_gemm/common/compile.cuh>
// IDE-only shim, active only when `__CLION_IDE__` is defined: `printf` is redirected to a
// host/device stub whose body is just a `trap` instruction. Presumably this exists so that the
// IDE's parser can resolve `printf` inside device code; it is not part of a real build.
#ifdef __CLION_IDE__
CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) {
asm volatile("trap;");
}
#define printf host_device_printf
#endif
// Runtime assertion for device code: when `cond` is false, prints the file, line and the
// stringified condition via `printf`, then executes the `trap` instruction.
// The `#ifndef` guard lets a translation unit supply its own definition before including this header.
// Wrapped in `do { ... } while (0)` so it expands to a single statement.
#ifndef DG_DEVICE_ASSERT
#define DG_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while (0)
#endif
// Same check as DG_DEVICE_ASSERT but without the `printf`: a failed condition only executes `trap`,
// so no diagnostic message is produced.
#ifndef DG_TRAP_ONLY_DEVICE_ASSERT
#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) \
asm("trap;"); \
} while (0)
#endif
// Thin wrapper over `static_assert`; the variadic part carries the message.
#ifndef DG_STATIC_ASSERT
#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__)
#endif
// Assertion usable from code shared between device and host builds: it maps to DG_DEVICE_ASSERT
// when `DG_IN_CUDA_COMPILATION` is defined and to DG_HOST_ASSERT otherwise.
// NOTE(review): DG_HOST_ASSERT is not defined in this header; it must be provided by the including
// host-side code — confirm where it comes from.
#ifndef DG_UNIFIED_ASSERT
#ifdef DG_IN_CUDA_COMPILATION
#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond)
#else
#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond)
#endif
#endif

View File

@@ -0,0 +1,149 @@
#pragma once
#include <cuda/std/cstdint>
#include <deep_gemm/common/compile.cuh>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::math {
/// Pointer operations
/// Returns `ptr` moved forward by `num_bytes` bytes, reinterpreted as `dtype_t*`.
/// The offset is always counted in bytes, regardless of `dtype_t`.
template <typename dtype_t = void>
CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) {
    uint8_t* const byte_ptr = static_cast<uint8_t*>(ptr);
    return reinterpret_cast<dtype_t*>(byte_ptr + num_bytes);
}
/// Math functions
/// Integer division of `a` by `b`, rounded up, computed as `(a + b - 1) / b`.
/// No overflow or `b == 0` checks are performed.
template <typename T>
CUTLASS_HOST_DEVICE T ceil_div(T a, T b) {
    // `auto` keeps the (possibly promoted) type of the intermediate sum
    const auto biased_numerator = a + b - 1;
    return biased_numerator / b;
}
/// `constexpr` variant of `ceil_div`: `(a + b - 1) / b`, usable in constant expressions.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) {
    // `auto` keeps the (possibly promoted) type of the intermediate sum
    const auto biased_numerator = a + b - 1;
    return biased_numerator / b;
}
/// Rounds `a` to a multiple of `b`: upwards when `kDoCeilAlignment` is true (the default),
/// downwards otherwise.
template <typename T, bool kDoCeilAlignment = true>
CUTLASS_HOST_DEVICE T align(T a, T b) {
    if constexpr (kDoCeilAlignment) {
        return ceil_div(a, b) * b;
    } else {
        return (a / b) * b;
    }
}
/// `constexpr` round-up of `a` to the next multiple of `b`.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) {
    const auto num_chunks = constexpr_ceil_div(a, b);
    return num_chunks * b;
}
/// Greatest common divisor by the Euclidean algorithm; returns `a` when `b == 0`.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) {
    // Iterative form of `gcd(a, b) = gcd(b, a % b)`
    while (not (b == 0)) {
        const T remainder = a % b;
        a = b;
        b = remainder;
    }
    return a;
}
/// Returns the smaller of the two values; `b` is returned when they compare equal.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) {
    if (a < b)
        return a;
    return b;
}
/// Exchanges the values of `a` and `b` through one copy and two copy-assignments.
template <typename T>
CUTLASS_DEVICE void swap(T& a, T& b) {
    T saved_a = a;
    a = b;
    b = saved_a;
}
#ifdef DG_IN_CUDA_COMPILATION
/// Two-lane fused multiply-add: returns `{a.x * b.x + c.x, a.y * b.y + c.y}`.
/// When compiled for `__CUDA_ARCH__ >= 1000` the packed intrinsic `__ffma2_rn` is used;
/// otherwise the result is assembled from two scalar `__fmaf_rn` calls.
CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) {
#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)
    const float2 packed_result = __ffma2_rn(a, b, c);
    return packed_result;
#else
    const float lane_x = __fmaf_rn(a.x, b.x, c.x);
    const float lane_y = __fmaf_rn(a.y, b.y, c.y);
    return make_float2(lane_x, lane_y);
#endif
}
/// Fast reciprocal `1 / x`.
///
/// In device code this issues the PTX instruction `rcp.approx.ftz.f32` (approximate,
/// flush-to-zero). The function is declared host+device, but PTX cannot be assembled for the
/// host, so the host path falls back to a plain division; this keeps the function callable
/// from host code instead of failing when the host compiler tries to emit the PTX text.
CUTLASS_HOST_DEVICE float fast_rcp(const float& x) {
#ifdef __CUDA_ARCH__
    float ret;
    asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x));
    return ret;
#else
    // Host pass: no PTX available, use the exact reciprocal
    return 1.0f / x;
#endif
}
/// Casting
/// Reinterprets the storage of `x` and `y` as `float`, converts the pair to a packed bf16x2 value
/// with `__float22bfloat162_rn`, and returns its 32 bits as an `int` (`x` first, then `y`).
/// NOTE(review): `old_t` is assumed to be a 32-bit type holding float bits — confirm at call sites.
template <typename old_t>
CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) {
    const float2 float_pair = {*reinterpret_cast<float*>(&x), *reinterpret_cast<float*>(&y)};
    auto packed_bf16 = __float22bfloat162_rn(float_pair);
    return *reinterpret_cast<int*>(&packed_bf16);
}
/// Builds the float whose bit pattern is `(x + 127) << 23`, i.e. an IEEE-754 binary32 value with
/// biased exponent `x + 127` and zero mantissa. No range check is performed on `x`.
CUTLASS_DEVICE float fast_pow2(const int& x) {
    const int biased_exponent = x + 127;
    uint32_t float_bits = biased_exponent << 23;
    return *reinterpret_cast<float*>(&float_bits);
}
/// Returns `ceil(log2(|x|))` read directly from the IEEE-754 binary32 encoding of `x`:
/// the unbiased exponent, plus one when any mantissa bit is set.
///
/// The exponent field is masked to its 8 bits so that the sign bit (bit 31) cannot leak into the
/// result; without the mask a negative input would yield an exponent that is off by 256.
/// Zero, subnormal, infinite and NaN inputs are not treated specially.
CUTLASS_DEVICE int fast_log2_ceil(float x) {
    const auto bits = *reinterpret_cast<uint32_t*>(&x);
    // Bits 23..30: biased exponent; bit 31 (sign) is discarded by the mask
    const auto exp = (bits >> 23) & 0xffu;
    // Bits 0..22: mantissa
    const auto man = bits & ((1u << 23) - 1u);
    return static_cast<int>(exp) - 127 + (man != 0);
}
/// For each lane of `amax`, computes a power-of-two scale factor and its inverse:
/// `e = fast_log2_ceil(amax / 448)`, `sf = 2^e`, `sf_inv = 2^-e`.
/// Only the UE8M0 (power-of-two) mode is implemented; `kUseUE8M0 == false` fails to compile.
/// NOTE(review): 448 is presumably the largest finite FP8 E4M3 value — confirm.
template <bool kUseUE8M0 = true>
CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) {
    DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0");

    // Scale both lanes by 1 / 448 with a single packed multiply
    const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0};
    const auto scaled = __fmul2_rn(amax, finfo_factor);

    // Lane x
    const auto exp_x = fast_log2_ceil(scaled.x);
    sf.x = fast_pow2(exp_x);
    sf_inv.x = fast_pow2(-exp_x);

    // Lane y
    const auto exp_y = fast_log2_ceil(scaled.y);
    sf.y = fast_pow2(exp_y);
    sf_inv.y = fast_pow2(-exp_y);
}
/// Reduction
/// Inclusive prefix sum across the 32 lanes of a warp, built from `__shfl_up_sync` steps with
/// offsets 1, 2, 4, 8, 16. The shuffles use the full mask `0xffffffff`.
/// `lane_idx` must be the caller's lane index within the warp; lanes below the current offset
/// skip the addition for that step.
CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) {
    #pragma unroll
    for (uint32_t offset = 1; offset < 32; offset <<= 1) {
        const uint32_t from_lower_lane = __shfl_up_sync(0xffffffff, value, offset);
        const bool has_source_lane = lane_idx >= offset;
        if (has_source_lane)
            value += from_lower_lane;
    }
    return value;
}
// Operation functors
/// Binary functor returning `a + b`.
template <typename T>
struct ReduceSum {
    CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; }
};

/// Binary functor returning the larger operand (`b` when they compare equal).
template <typename T>
struct ReduceMax {
    CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; }
};

/// Binary functor returning the smaller operand (`b` when they compare equal).
template <typename T>
struct ReduceMin {
    CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; }
};

/// Binary functor returning the bitwise AND of its operands.
template <typename T>
struct ReduceAnd {
    CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; }
};

/// Binary functor returning the bitwise OR of its operands.
template <typename T>
struct ReduceOr {
    CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; }
};
// Unified reduction function
/// Butterfly reduction over a warp using `__shfl_xor_sync` with the full mask.
///
/// - `kIntergroupReduce == false`: reduces within each aligned group of `kNumLanesPerGroup`
///   lanes, using XOR offsets `kNumLanesPerGroup / 2, ..., 2, 1`.
/// - `kIntergroupReduce == true`: combines lanes that sit at the same position in different
///   groups, using XOR offsets `kNumLanesPerGroup, ..., 8, 16`.
///
/// `kNumLanesPerGroup` must be a power of two in [1, 32]. `op` is applied as
/// `op(own_value, shuffled_value)` at every step.
template <uint32_t kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
CUTLASS_DEVICE T warp_reduce(T value, Op op) {
    DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
                     kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
                     "Invalid number of lanes");
    constexpr uint32_t mask = 0xffffffff;
    if constexpr (kIntergroupReduce) {
        // Ascending offsets: group size, then doubling up to 16
        #pragma unroll
        for (uint32_t xor_offset = kNumLanesPerGroup; xor_offset <= 16; xor_offset <<= 1)
            value = op(value, __shfl_xor_sync(mask, value, xor_offset));
    } else {
        // Descending offsets: half the group size, then halving down to 1
        #pragma unroll
        for (uint32_t xor_offset = kNumLanesPerGroup / 2; xor_offset >= 1; xor_offset >>= 1)
            value = op(value, __shfl_xor_sync(mask, value, xor_offset));
    }
    return value;
}
// Convenience aliases
/// Shorthand for `warp_reduce` with `ReduceSum<T>`; by default sums over all 32 lanes.
template <uint32_t kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
CUTLASS_DEVICE T warp_reduce_sum(T value) {
    ReduceSum<T> add_op{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, add_op);
}
#endif
} // namespace deep_gemm::math

View File

@@ -0,0 +1,44 @@
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda/std/cstdint>
#include <cuda/std/utility>
#include <deep_gemm/common/utils.cuh>
// Operation functors
/// Binary functor returning `a + b`.
template <typename T>
struct ReduceSum {
    __device__ T operator()(T a, T b) const { return a + b; }
};

/// Binary functor returning the larger operand (`b` when they compare equal).
template <typename T>
struct ReduceMax {
    __device__ T operator()(T a, T b) const { return a > b ? a : b; }
};

/// Binary functor returning the smaller operand (`b` when they compare equal).
template <typename T>
struct ReduceMin {
    __device__ T operator()(T a, T b) const { return a < b ? a : b; }
};

/// Binary functor returning the bitwise AND of its operands.
template <typename T>
struct ReduceAnd {
    __device__ T operator()(T a, T b) const { return a & b; }
};

/// Binary functor returning the bitwise OR of its operands.
template <typename T>
struct ReduceOr {
    __device__ T operator()(T a, T b) const { return a | b; }
};
// Unified reduction function
/// Butterfly reduction over a warp using `__shfl_xor_sync` with the full mask.
///
/// With `kIntergroupReduce == false` the XOR offsets are `kNumLanesPerGroup / 2, ..., 2, 1`
/// (reduction inside each aligned group of `kNumLanesPerGroup` lanes); with
/// `kIntergroupReduce == true` they are `kNumLanesPerGroup, ..., 8, 16` (reduction across groups).
/// `kNumLanesPerGroup` must be a power of two in [1, 32].
template <int kNumLanesPerGroup, bool kIntergroupReduce, typename T, typename Op>
__forceinline__ __device__ T warp_reduce(T value, Op op) {
    DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or
                     kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1,
                     "Invalid number of lanes");
    constexpr uint32_t mask = 0xffffffff;
    if constexpr (kIntergroupReduce) {
        // Ascending offsets: group size, then doubling up to 16
        #pragma unroll
        for (int lane_xor = kNumLanesPerGroup; lane_xor <= 16; lane_xor *= 2)
            value = op(value, __shfl_xor_sync(mask, value, lane_xor));
    } else {
        // Descending offsets: half the group size, then halving down to 1
        #pragma unroll
        for (int lane_xor = kNumLanesPerGroup / 2; lane_xor >= 1; lane_xor /= 2)
            value = op(value, __shfl_xor_sync(mask, value, lane_xor));
    }
    return value;
}
// Convenience aliases
/// Shorthand for `warp_reduce` with `ReduceSum<T>`; by default sums over all 32 lanes.
template <int kNumLanesPerGroup = 32, bool kIntergroupReduce = false, typename T>
__forceinline__ __device__ T warp_reduce_sum(T value) {
    ReduceSum<T> add_op{};
    return warp_reduce<kNumLanesPerGroup, kIntergroupReduce, T>(value, add_op);
}

View File

@@ -0,0 +1,288 @@
#pragma once
#include <deep_gemm/common/types.hpp>
#include <deep_gemm/common/utils.cuh>
namespace deep_gemm {
/// Selects which kind of offset `Scheduler::get_global_idx` applies for k-grouped and batched GEMMs.
enum class IndexType {
    MN = 0,    // k-grouped: offset is `current_group_idx * shape_dim`
    K = 1,     // k-grouped: offset is the running sum of previous groups' K
    SF_K = 2,  // k-grouped: offset is the running sum of scale-factor K blocks
};
template <GemmType kGemmType, uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t kNumSMs, bool kIsMulticastOnA>
static constexpr uint32_t get_num_1d_blocks_per_group() {
// Select the best from candidates
uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits<uint32_t>::max();
for (const auto& candidate: {8u, 16u}) {
const auto& usage = kIsMulticastOnA ?
candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N
candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M
if (usage < min_usage)
min_usage = usage, num_best_blocks = candidate;
}
return num_best_blocks;
}
#pragma clang diagnostic push
#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init"
// Per-CTA tile scheduler: every call to `get_next_block` hands out the next (m, n) output block
// for this CTA, stepping through the global block index space with a stride of `kNumSMs` starting
// at `blockIdx.x`. The GEMM flavour (`kGemmType`) selects how groups are walked and how global
// offsets are computed.
//
// Template parameters:
//  - kGemmType: which layout/grouping logic is compiled in
//  - BLOCK_M, BLOCK_N: block extents used to count blocks and to index `grouped_layout`
//  - kNumGroups: number of groups (or batches for `GemmType::Batched`)
//  - kNumMulticast, kIsMulticastOnA: multicast width and the axis it groups on
//  - kNumSMs: stride between consecutive blocks handled by the same CTA
//  - kNum1DBlocksPerGroup: swizzle group width, must be a multiple of `kNumMulticast`
template <GemmType kGemmType,
uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t kNumGroups,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
uint32_t SF_K_ALIGNMENT = 512u, // for k-grouped GEMM only: 128 (SM90 float SF) or 512 (SM100 UE8M0 SF)
uint32_t kNum1DBlocksPerGroup = get_num_1d_blocks_per_group<kGemmType, BLOCK_M, BLOCK_N, kNumSMs, kIsMulticastOnA>()>
struct Scheduler {
// Iteration counter of this CTA; pre-incremented by `get_next_block`, so the first call uses 0
int current_iter = -1;
// Block configs
// NOTES: `num_blocks` is only assigned by the constructor for normal, batched and
// m-grouped contiguous GEMMs
uint32_t num_blocks;
uint32_t num_m_blocks;
uint32_t num_n_blocks;
// For SM90 multicast checks
// `num_blocks_in_group` is written by `get_swizzled_block_idx`
uint32_t num_blocks_in_group;
bool is_peer_cta_alive = true;
// For grouped GEMM
// Read-only (via `__ldg`) per-type layout array; its meaning depends on `kGemmType`
int* grouped_layout;
uint32_t current_group_idx = 0;
// Only used for masked layout
uint32_t current_m_cumsum = 0;
// Only used for contiguous psum layout
uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0;
// Only used for k-grouped layout
uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0;
uint32_t next_group_idx, next_shape_k;
// Only used for k-grouped gemm
// Advances `group_idx` (starting from its current value) to the first group whose entry in
// `grouped_layout` is positive and stores that entry in `shape_k`. If no such group exists,
// `group_idx` ends at `kNumGroups` and `shape_k` holds the last value loaded (or is left
// untouched when the loop body never runs).
__device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const {
for (; group_idx < kNumGroups; ++ group_idx) {
shape_k = __ldg(grouped_layout + group_idx);
if (shape_k > 0)
break;
}
}
// Initializes the block counts from the problem shape and, for grouped types, captures
// `grouped_layout` and primes the per-type state. Members irrelevant to `kGemmType` are
// deliberately left uninitialized.
// ReSharper disable once CppPossiblyUninitializedMember
__device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k,
int* grouped_layout = nullptr) {
num_m_blocks = ceil_div(shape_m, BLOCK_M);
num_n_blocks = ceil_div(shape_n, BLOCK_N);
current_shape_k = shape_k;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
num_blocks = num_m_blocks * num_n_blocks;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
num_blocks = num_m_blocks * num_n_blocks;
this->grouped_layout = grouped_layout;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
this->grouped_layout = grouped_layout;
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
this->grouped_layout = grouped_layout;
// The first layout entry gives the M extent of group 0, overriding `shape_m`
current_psum_m = __ldg(grouped_layout);
num_m_blocks = ceil_div(current_psum_m, BLOCK_M);
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
this->grouped_layout = grouped_layout;
// Find the first non-empty group, then look ahead to the one after it
get_next_k_group(current_group_idx, current_shape_k);
next_group_idx = current_group_idx + 1;
get_next_k_group(next_group_idx, next_shape_k);
}
}
// Converts a linear block index into (m_block_idx, n_block_idx) using a grouped ("swizzled")
// order: blocks are visited in bands of `kNum1DBlocksPerGroup` along the primary axis
// (N when `kIsMulticastOnA`, M otherwise). Also updates `num_blocks_in_group`.
__device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) {
DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size");
// Swizzle for better L2 usages
const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks;
const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks;
const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup;
const auto& group_idx = block_idx / num_blocks_per_group;
auto first_block_idx = group_idx * kNum1DBlocksPerGroup;
auto in_group_idx = block_idx % num_blocks_per_group;
// The last band may be narrower than `kNum1DBlocksPerGroup`
num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx);
// Fix unaligned TMA multicast
// NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast
// while SM100 uses 2-CTA, which can not be dynamically disabled
#if __CUDA_ARCH__ < 1000
if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) {
// `num_blocks_in_group ^ 1` is the odd width rounded down to even: the band is split
// into an even-width part and a trailing part of width 1
if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) {
num_blocks_in_group = num_blocks_in_group ^ 1;
} else {
in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks;
first_block_idx += num_blocks_in_group ^ 1;
num_blocks_in_group = 1;
}
}
#endif
// Convert to final M/N block indices
// `kIsMulticastOnA == true` leads to groups on N
if constexpr (kIsMulticastOnA) {
m_block_idx = in_group_idx / num_blocks_in_group;
n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
} else {
m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group;
n_block_idx = in_group_idx / num_blocks_in_group;
}
}
// Returns `block_idx * block_size` plus a group-dependent offset:
//  - Normal: no offset
//  - MGroupedContiguous: `max(0, grouped_layout[m_block_idx * BLOCK_M]) * shape_dim` if `kWithGroupOffset`
//  - MGroupedMasked / psum layout: `current_group_idx * shape_dim` if `kWithGroupOffset`
//  - KGroupedContiguous: selected by `kIndexType` if `kWithGroupOffset`
//  - Batched: `current_group_idx * shape_dim` only for `IndexType::SF_K`
template <bool kWithGroupOffset, IndexType kIndexType = IndexType::MN>
__device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size,
const uint32_t& block_idx, const uint32_t& m_block_idx = 0) {
if constexpr (kGemmType == GemmType::Normal) {
return block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
// Negative layout entries are clamped to group 0 here
const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
const auto offset = kWithGroupOffset ? current_group_idx : 0;
return offset * shape_dim + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
auto offset = 0;
if constexpr (kWithGroupOffset) {
if constexpr (kIndexType == IndexType::MN)
offset = current_group_idx * shape_dim;
else if constexpr (kIndexType == IndexType::K)
offset = current_k_cumsum;
else if constexpr (kIndexType == IndexType::SF_K)
offset = current_sf_k_cumsum;
}
return offset + block_idx * block_size;
} else if constexpr (kGemmType == GemmType::Batched) {
// Ignore kWithGroupOffset, and apply offset for IndexType::SF_K
const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0;
return offset * shape_dim + block_idx * block_size;
}
}
// Advances to this CTA's next block. Returns false when no work is left; otherwise writes the
// block coordinates to `m_block_idx` / `n_block_idx` and updates the per-group state.
__device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) {
const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x;
if constexpr (kGemmType == GemmType::MGroupedMasked) {
while (true) {
// End of the task
if (current_group_idx == kNumGroups)
return false;
// Within current group
// NOTES: this local `current_m_block_cumsum` shadows the member of the same name
num_m_blocks = ceil_div(static_cast<uint32_t>(__ldg(grouped_layout + current_group_idx)), BLOCK_M);
const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks;
if (next_block_idx < current_m_block_cumsum * num_n_blocks)
break;
// Move to check the next group
current_group_idx ++, current_m_cumsum = current_m_block_cumsum;
}
get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx);
} else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) {
while (true) {
// Within current group
if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks)
break;
// Move to check the next group
if (++ current_group_idx == kNumGroups)
return false;
// NOTES: `num_m_blocks` varies with the increase of the group index
last_psum_m = align(current_psum_m, 128u);
current_psum_m = __ldg(grouped_layout + current_group_idx);
current_m_block_cumsum += num_m_blocks;
num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M);
}
get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx);
// NOTES: `last_psum_m` is aligned with 128
m_block_idx += last_psum_m / BLOCK_M;
DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M");
} else if constexpr (kGemmType == GemmType::KGroupedContiguous) {
while (true) {
// End of the task
if (current_group_idx == kNumGroups)
return false;
// Within current group
// NOTES: every non-empty group contributes the same `num_m_blocks * num_n_blocks` blocks
if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks)
break;
// Move to check the next group
current_k_cumsum += current_shape_k;
current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT);
current_num_valid_groups ++;
// Promote the look-ahead group to current, then search for the next non-empty one
current_group_idx = next_group_idx ++;
current_shape_k = next_shape_k;
get_next_k_group(next_group_idx, next_shape_k);
}
get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx);
} else if constexpr (kGemmType == GemmType::Batched) {
if (next_block_idx >= num_blocks * kNumGroups)
return false;
// No swizzle for batched GEMM: plain row/column decomposition within the batch
current_group_idx = next_block_idx / num_blocks;
const auto& block_idx = next_block_idx - current_group_idx * num_blocks;
if constexpr (kIsMulticastOnA) {
m_block_idx = block_idx / num_n_blocks;
n_block_idx = block_idx % num_n_blocks;
} else {
m_block_idx = block_idx % num_m_blocks;
n_block_idx = block_idx / num_m_blocks;
}
} else {
if (next_block_idx >= num_blocks)
return false;
// For SM90 only
// NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned
is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass)
num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass)
(next_block_idx ^ 1) < num_blocks; // Peer CTA in bound
get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx);
}
return true;
}
// For SM90 only
// Tells whether multicast may be used for the current block: never when the current band has
// width 1; for m-grouped contiguous GEMM without multicast on A, only when the block and its
// XOR-1 neighbour read the same value from `grouped_layout`.
__device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const {
if (num_blocks_in_group == 1)
return false;
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or
kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) {
return true;
} else {
DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type");
if constexpr (kIsMulticastOnA) {
return true;
} else {
const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M);
const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M);
return group_idx == peer_group_idx;
}
}
}
// For SM90 only
// Tells whether the rows starting at `m_offset` inside block `m_block_idx` hold real work:
// always for normal/batched GEMM; for m-grouped contiguous when the layout entry is
// non-negative; for masked when the row index is below the current group's entry.
// Any other GEMM type traps at runtime.
// ReSharper disable once CppNotAllPathsReturnValue
__device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const {
if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) {
return true;
} else if constexpr (kGemmType == GemmType::MGroupedContiguous) {
return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0;
} else if constexpr (kGemmType == GemmType::MGroupedMasked) {
return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx);
} else {
// Unreachable
DG_TRAP_ONLY_DEVICE_ASSERT(false);
}
}
};
#pragma clang diagnostic pop
} // namespace deep_gemm

View File

@@ -0,0 +1,266 @@
#pragma once
#include <cute/atom/mma_traits_sm100.hpp>
#include <cute/arch/mma_sm100_umma.hpp>
#include <cute/arch/tmem_allocator_sm100.hpp>
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_utils.cuh>
namespace deep_gemm::sm100 {
/// Builds a UMMA shared-memory descriptor for `smem_ptr`.
///
/// The start address and both byte offsets are stored with their low 4 bits dropped
/// (i.e. in 16-byte units). `version_` is set to 1 (SM100), `lbo_mode_` to 0 (legacy mode) and
/// `base_offset_` to 0.
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr,
                                          uint32_t stride_byte_offset, uint32_t leading_byte_offset) {
    const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);

    cute::UMMA::SmemDescriptor desc;

    // Fixed fields: SM100 descriptor version, legacy LBO mode, no base offset
    desc.version_ = 1;
    desc.lbo_mode_ = 0;
    desc.base_offset_ = 0;

    // Layout (swizzle) type
    desc.layout_type_ = static_cast<uint8_t>(layout);

    // Address and strides, all in 16-byte units
    desc.start_address_ = static_cast<uint16_t>(smem_addr >> 4);
    desc.stride_byte_offset_ = stride_byte_offset >> 4;
    desc.leading_byte_offset_ = leading_byte_offset >> 4;
    return desc;
}
/// Builds the shared-memory descriptor used for scale factors: no swizzle, SBO of 128 bytes,
/// LBO of 0.
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) {
    // NOTES: the UTCCP layout is K-major by default
    // Atom size: 8 x 128 bits
    // {SBO, LBO} means the byte stride between atoms on {MN, K}
    // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero
    constexpr uint32_t kStrideByteOffset = 8 * 16;
    constexpr uint32_t kLeadingByteOffset = 0;
    return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr,
                          kStrideByteOffset, kLeadingByteOffset);
}
/// Overwrites only the start-address field of `desc` with the address of `smem_ptr`
/// (stored in 16-byte units); every other field is left as it was.
__device__ __forceinline__
void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) {
    const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
    const auto addr_in_16b_units = static_cast<uint16_t>(smem_addr >> 4);
    desc.start_address_ = addr_in_16b_units;
}
/// Returns 32 for `SWIZZLE_128B_BASE32B` and 16 for every other layout type.
__device__ __forceinline__
static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) {
    if (layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B)
        return 32;
    return 16;
}
/// Maps a swizzle width in bytes to the matching `cute::UMMA::LayoutType`.
///
/// `kSwizzleMode` must be one of 0, 16, 32, 64, 128; both 0 and 16 map to `SWIZZLE_NONE`.
/// When `kUseBase32` is set — or when `dtype_t` is `float` with MN-major layout, which in turn
/// requires `kUseBase32` — the result is `SWIZZLE_128B_BASE32B` regardless of `kSwizzleMode`.
// ReSharper disable once CppNotAllPathsReturnValue
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, bool kUseBase32, typename dtype_t>
constexpr static cute::UMMA::LayoutType to_umma_layout_type() {
    DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
                     kSwizzleMode == 32 or kSwizzleMode == 64 or
                     kSwizzleMode == 128, "Invalid swizzling mode");
    using LayoutType = cute::UMMA::LayoutType;
    if constexpr ((cute::is_same_v<dtype_t, float> and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) {
        // A special case
        DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base");
        return LayoutType::SWIZZLE_128B_BASE32B;
    } else if constexpr (kSwizzleMode == 0 or kSwizzleMode == 16) {
        // Normal cases
        return LayoutType::SWIZZLE_NONE;
    } else if constexpr (kSwizzleMode == 32) {
        return LayoutType::SWIZZLE_32B;
    } else if constexpr (kSwizzleMode == 64) {
        return LayoutType::SWIZZLE_64B;
    } else if constexpr (kSwizzleMode == 128) {
        return LayoutType::SWIZZLE_128B;
    }
}
/// Element stride applied per K index when advancing a UMMA descriptor: 1 for K-major layouts,
/// otherwise the inner block atom size returned by `get_inner_block_atom_size`.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
constexpr uint32_t get_umma_desc_stride_k() {
    constexpr bool kIsKMajor = (kMajorMode == cute::UMMA::Major::K);
    return kIsKMajor ? 1 : get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
}
/// Advances the low 32 bits of a UMMA descriptor (`base`) by
/// `offset + k_idx * stride_k` elements, converted to bytes and then to 16-byte units.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) {
    const uint32_t num_elems = offset + k_idx * get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
    const uint32_t num_bytes = num_elems * static_cast<uint32_t>(sizeof(dtype_t));
    return base + (num_bytes >> 4u);
}
/// Builds the UMMA shared-memory descriptor for an operand tile stored at `base_smem_ptr`.
///
/// The descriptor points at element `mn_idx * BLOCK_K + k_idx * stride_k` and carries the
/// SBO/LBO byte strides that match the major-ness and swizzle mode:
///  - K-major: requires `kSwizzleMode == BLOCK_K * sizeof(dtype_t)`; LBO is 0.
///  - MN-major: requires `kSwizzleMode > 0` and `mn_idx` aligned to the MN atom (runtime check);
///    for `kSwizzleMode == 16` the SBO and LBO values are exchanged.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, bool kUseBase32 = false, typename dtype_t>
__device__ __forceinline__
cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
    const uint32_t stride_k = get_umma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
    const auto layout_type = to_umma_layout_type<kMajorMode, kSwizzleMode, kUseBase32, dtype_t>();
    const uint32_t num_non_contiguous = 128 / get_atom_base(layout_type);

    // Both layouts start the descriptor at the same element
    dtype_t* const tile_ptr = base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k;

    if constexpr (kMajorMode == cute::UMMA::Major::K) {
        // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)`
        // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis
        DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");

        // Atom size: 8 x `kSwizzleMode` (in bytes, on K)
        // {SBO, LBO} means the byte stride between atoms on {MN, K}
        // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
        const uint32_t sbo_bytes = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
        const uint32_t lbo_bytes = 0;
        return make_smem_desc(layout_type, tile_ptr, sbo_bytes, lbo_bytes);
    } else {
        constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();

        // Must have no in-atom MN-idx
        // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
        DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
        DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");

        // Atom size: `kSwizzleMode` (in bytes, on MN) x 8
        // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving
        // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
        // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
        const uint32_t swizzled_sbo = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
        const uint32_t swizzled_lbo = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);

        // The two strides trade places in the non-swizzling (16-byte) case
        constexpr bool kExchangeStrides = (kSwizzleMode == 16);
        const uint32_t sbo_bytes = kExchangeStrides ? swizzled_lbo : swizzled_sbo;
        const uint32_t lbo_bytes = kExchangeStrides ? swizzled_sbo : swizzled_lbo;
        return make_smem_desc(layout_type, tile_ptr, sbo_bytes, lbo_bytes);
    }
}
/// Fills the A/B scale-factor IDs into a copy of `desc` and returns the 32-bit descriptor
/// placed in the upper half of a 64-bit value (lower 32 bits are zero).
__device__ __forceinline__
uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) {
    desc.a_sf_id_ = sfa_id;
    desc.b_sf_id_ = sfb_id;
    const uint64_t desc_bits = static_cast<uint32_t>(desc);
    return desc_bits << 32;
}
/// Rounds `kNumCols` up to the next value in {32, 64, 128, 256, 512}.
/// Values above 512 are rejected at compile time.
template <uint32_t kNumCols>
__device__ constexpr uint32_t get_num_aligned_tmem_cols() {
    DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
    uint32_t num_aligned_cols = 32;
    while (num_aligned_cols < kNumCols)
        num_aligned_cols *= 2;
    return num_aligned_cols;
}
// Emits the PTX instruction `tcgen05.fence::before_thread_sync`.
// NOTE(review): intended to be placed before a thread synchronization point that orders
// tcgen05 operations — confirm usage against the PTX ISA and the call sites.
__device__ __forceinline__ void tcgen05_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;");
}
// Emits the PTX instruction `tcgen05.fence::after_thread_sync`, the counterpart of the fence above.
__device__ __forceinline__ void tcgen05_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;");
}
// Issues one `cp.async.bulk.tensor.2d ... tile::gather4` instruction (cta_group::1, with an L2
// cache hint).
//  - desc_ptr:   tensor-map descriptor, passed as a 64-bit operand
//  - mbarrier:   barrier whose shared-memory address receives the `complete_tx::bytes` signal
//  - smem_ptr:   destination in shared memory
//  - col_idx:    first coordinate passed to the instruction
//  - row_idxs:   the four remaining coordinates (x, y, z, w), one per gathered row
//  - cache_hint: 64-bit L2 cache hint operand
// The asm statement declares a "memory" clobber.
__device__ __forceinline__
void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) {
// Convert both generic pointers to 32-bit shared-memory addresses
uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr);
uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier);
asm volatile(
"cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n"
:
: "r"(smem_addr), "l"(desc_ptr), "r"(col_idx),
"r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w),
"r"(mbarrier_addr), "l"(cache_hint)
: "memory"
);
}
// UMMA versions with relaxed assertions
// Wrapper around `tcgen05.mma.cta_group::1.kind::f16` with both operands given as 64-bit
// shared-memory descriptors.
//  - tmem_c:  tensor-memory address of the accumulator (operand %0)
//  - desc:    only its upper 32 bits are passed to the instruction (operand %3)
//  - scale_c: turned into the predicate operand `p = (scale_c != 0)`
// NOTE(review): per the PTX ISA the predicate presumably selects whether the existing
// accumulator is added to the product — confirm.
struct SM100_MMA_F16BF16_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
// Same as SM100_MMA_F16BF16_SS but issues the `cta_group::2` form of
// `tcgen05.mma ... kind::f16`. Operands are mapped identically: accumulator address, the two
// 64-bit descriptors, the upper 32 bits of `desc`, and the predicate `p = (scale_c != 0)`.
struct SM100_MMA_F16BF16_2x1SM_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
// Wrapper around `tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale`.
// In addition to the operands of the f16 variants it takes two tensor-memory addresses,
// `tmem_sfa` and `tmem_sfb`, which are passed as the bracketed operands %5 and %6
// (the scale factors for A and B, going by the parameter names).
// As before, only the upper 32 bits of `desc` are used and `p = (scale_c != 0)`.
struct SM100_MMA_MXF8F6F4_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
// Same as SM100_MMA_MXF8F6F4_SS but issues the `cta_group::2` form of
// `tcgen05.mma ... kind::mxf8f6f4.block_scale`. Operand mapping is identical, including the two
// tensor-memory scale-factor addresses `tmem_sfa` / `tmem_sfb`.
struct SM100_MMA_MXF8F6F4_2x1SM_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc,
uint32_t const& tmem_sfa,
uint32_t const& tmem_sfb) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c),
"r"(tmem_sfa), "r"(tmem_sfb));
}
};
// Wrapper around the `.ws` form `tcgen05.mma.ws.cta_group::1.kind::f16`.
// Parameters and operand mapping match SM100_MMA_F16BF16_SS: accumulator address, two 64-bit
// descriptors, the upper 32 bits of `desc`, and the predicate `p = (scale_c != 0)`.
struct SM100_MMA_F16BF16_WS_SS {
__device__ static void
fma(uint64_t const& desc_a,
uint64_t const& desc_b,
uint32_t const& tmem_c,
uint32_t const& scale_c,
uint64_t const& desc) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t"
"}\n"
:: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast<uint32_t>(desc >> 32)), "r"(scale_c));
}
};
} // namespace `deep_gemm::sm100`

View File

@@ -0,0 +1,332 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/mma_sm90_desc.hpp>
#include <cute/arch/mma_sm90_gmma.hpp>
#include <cute/arch/mma_sm90_gmma_ext.hpp>
#include <cute/arch/mma_sm100_desc.hpp>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
#include <deep_gemm/common/tma_utils.cuh>
namespace deep_gemm::sm90 {
/// Adapter around a CuTe SM90 FP8 GMMA atom (`MMA`) with tile shape 64 x `N_` x 32.
/// `wgmma` forwards the first `N_ / 2` floats of `d` as individual accumulator arguments to
/// `MMA::fma`; `scale_d` selects `ScaleOut::One` (true) or `ScaleOut::Zero` (false).
template <int N_, typename MMA>
struct FP8MMA {
    static constexpr int M = 64;
    static constexpr int N = N_;
    static constexpr int K = 32;
    static constexpr int kNumAccum = M * N / 128;

    /// Expands `d[0], d[1], ...` according to the index pack and calls `MMA::fma`.
    template <size_t ...kAccumIdx>
    __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<kAccumIdx...>) {
        using namespace cute::SM90::GMMA;
        const auto scale_out = scale_d ? ScaleOut::One : ScaleOut::Zero;
        MMA::fma(desc_a, desc_b, d[kAccumIdx]..., scale_out);
    }

    /// Issues one MMA with descriptors `desc_a` / `desc_b` accumulating into `d`.
    __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
        constexpr auto kAccumIndices = cute::make_index_sequence<N_/2>{};
        call_fma_impl(desc_a, desc_b, d, scale_d, kAccumIndices);
    }
};
/// Picks the CuTe SM90 FP8 (E4M3 x E4M3 -> F32, SS, TN) GMMA atom for a given `N` and wraps it
/// in `FP8MMA`. Supported values of `N` are the multiples of 8 from 8 to 256; for any other
/// value `select_mma` returns nothing.
template <int N>
struct FP8MMASelector {
    static constexpr auto select_mma() {
        using namespace cute::SM90::GMMA;
        if constexpr (N == 8) { return MMA_64x8x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 16) { return MMA_64x16x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 24) { return MMA_64x24x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 32) { return MMA_64x32x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 40) { return MMA_64x40x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 48) { return MMA_64x48x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 56) { return MMA_64x56x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 64) { return MMA_64x64x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 72) { return MMA_64x72x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 80) { return MMA_64x80x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 88) { return MMA_64x88x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 96) { return MMA_64x96x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 104) { return MMA_64x104x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 112) { return MMA_64x112x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 120) { return MMA_64x120x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 128) { return MMA_64x128x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 136) { return MMA_64x136x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 144) { return MMA_64x144x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 152) { return MMA_64x152x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 160) { return MMA_64x160x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 168) { return MMA_64x168x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 176) { return MMA_64x176x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 184) { return MMA_64x184x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 192) { return MMA_64x192x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 200) { return MMA_64x200x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 208) { return MMA_64x208x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 216) { return MMA_64x216x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 224) { return MMA_64x224x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 232) { return MMA_64x232x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 240) { return MMA_64x240x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 248) { return MMA_64x248x32_F32E4M3E4M3_SS_TN(); }
        else if constexpr (N == 256) { return MMA_64x256x32_F32E4M3E4M3_SS_TN(); }
    }

    /// Wraps the selected atom type in `FP8MMA<N, ...>`.
    static constexpr auto select_type() {
        using SelectedAtom = decltype(select_mma());
        return FP8MMA<N, SelectedAtom>();
    }

    using type = decltype(select_type());
};
// Thin wrapper around one GMMA instruction type `MMA` whose two operands are passed as
// 64-bit descriptors and whose FP32 accumulators are passed as a flat `float` array.
// `N_` is the N extent of the instruction; M is fixed to 64 and K to 16 (see constants below).
template <int N_, typename MMA>
struct BF16MMA {
// Expands `d[Idx]...` into one `MMA::fma` argument per index in `Idx`.
// `scale_d == true` passes `ScaleOut::One`, otherwise `ScaleOut::Zero`
// (presumably accumulate-into-D vs. overwrite-D; confirm against the CuTe GMMA definitions).
template <size_t ...Idx>
__forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
// Entry point: forwards `d[0] .. d[N_/2 - 1]` to `MMA::fma`.
// `N_/2` equals `kNumAccum` (`M * N / 128` with `M == 64`).
__forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 16;
// Number of accumulator floats passed to `MMA::fma` by `wgmma`
static constexpr int kNumAccum = M * N / 128;
};
// Translates a `cute::UMMA::Major` value into the `cute::SM90::GMMA::Major` value of the
// same name (K -> K, MN -> MN). Any other value is rejected at compile time.
template <cute::UMMA::Major kMajor>
constexpr cute::SM90::GMMA::Major to_sm90_major() {
    DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness");
    if constexpr (kMajor == cute::UMMA::Major::K) {
        return cute::SM90::GMMA::Major::K;
    } else {
        return cute::SM90::GMMA::Major::MN;
    }
}
// Compile-time selector: maps an N extent (multiples of 8 from 8 to 256) and the major-ness of
// A and B to the matching `MMA_64xNx16_F32BF16BF16_SS` type, and exposes it wrapped in
// `BF16MMA<N, ...>` as `type`.
template <int N,
cute::UMMA::Major kMajorA = cute::UMMA::Major::K,
cute::UMMA::Major kMajorB = cute::UMMA::Major::K>
struct BF16MMASelector {
// Returns a default-constructed instance of the selected instruction type; only the type is
// used (through `decltype`). If `N` matches none of the listed values, no `return` is
// instantiated and the deduced return type is `void`.
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
// Convert the UMMA major-ness enums to their SM90 GMMA counterparts
constexpr auto kGMMAMajorA = to_sm90_major<kMajorA>();
constexpr auto kGMMAMajorB = to_sm90_major<kMajorB>();
if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS<kGMMAMajorA, kGMMAMajorB>();
}
// Wraps the selected instruction type into the `BF16MMA` helper
static constexpr auto select_type() {
return BF16MMA<N, decltype(select_mma())>();
}
// The resulting `BF16MMA<N, MMA_64xNx16_...>` type
using type = decltype(select_type());
};
// Thin wrapper around one GMMA instruction type `MMA` whose A operand is passed as four
// 32-bit values and whose B operand is passed as a 64-bit descriptor.
// `N_` is the N extent of the instruction; M is fixed to 64 and K to 8 (see constants below).
template <int N_, typename MMA>
struct TF32MMARS {
// Passes `a[0..3]`, `desc_b`, and one accumulator per index in `Idx` to `MMA::fma`.
// `scale_d == true` passes `ScaleOut::One`, otherwise `ScaleOut::Zero`.
template <size_t ...Idx>
__forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence<Idx...>) {
using namespace cute::SM90::GMMA;
MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero));
}
// Entry point: reinterprets the `float` A fragment as raw 32-bit words (exactly 4 are read)
// and forwards `d[0] .. d[N_/2 - 1]` as accumulators.
__forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) {
call_fma_impl(reinterpret_cast<uint32_t*>(a), desc_b, d, scale_d, cute::make_index_sequence<N_/2>{});
}
static constexpr int M = 64;
static constexpr int N = N_;
static constexpr int K = 8;
// Number of accumulator floats passed to `MMA::fma` by `wgmma` (`N_/2` when `M == 64`)
static constexpr int kNumAccum = M * N / 128;
};
// Compile-time selector for the TF32 `MMA_64xNx8_F32TF32TF32_RS_TN` instruction types.
// Only `kUseRS == true` is supported, and only for N in {8, 16, 32, 64, 128, 256};
// anything else fails a static assertion. The result is exposed as `type`.
template <int N, bool kUseRS = true>
struct TF32MMASelector {
// Returns a default-constructed instance of the selected instruction type; only the type is
// used (through `decltype`). With `kUseRS == false` nothing is returned (deduced `void`).
static constexpr auto select_mma() {
using namespace cute::SM90::GMMA;
if constexpr (kUseRS) {
if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN();
if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN();
if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN();
if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN();
if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN();
if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN();
// Reject every N that is not listed above
DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N");
}
}
// Wraps the selected instruction type into `TF32MMARS`; the non-RS mode is a compile error
static constexpr auto select_type() {
if constexpr (kUseRS) {
return TF32MMARS<N, decltype(select_mma())>();
} else {
DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now");
}
}
// The resulting `TF32MMARS<N, MMA_64xNx8_...>` type
using type = decltype(select_type());
};
// Wrapper around the PTX `stmatrix.sync.aligned.x2.m8n8.shared.b16` instruction with two
// 32-bit source registers.
template <typename dtype_t>
struct SM90_U32x2_STSM_N {
// `src_0` / `src_1`: values whose first 32 bits are used as the two source registers
// (assumes `sizeof(dtype_t) >= 4`; presumably exactly 4 - confirm against callers).
// `smem_dst`: generic pointer converted to a shared-space address for the instruction.
__device__ __forceinline__ static void
copy(dtype_t src_0, dtype_t src_1, void* smem_dst) {
// Reinterpret the by-value arguments as raw 32-bit words
const uint32_t src[2] = {*reinterpret_cast<uint32_t*>(&src_0), *reinterpret_cast<uint32_t*>(&src_1)};
asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n"
:: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1]));
}
};
// Wrapper around the PTX `ldmatrix.sync.aligned.x2.m8n8.shared.b16` instruction with two
// 32-bit destination registers.
struct SM90_U32x2_LDSM_N {
// `dst_0` / `dst_1`: receive the two 32-bit results.
// `smem_src`: generic pointer converted to a shared-space address for the instruction.
__device__ __forceinline__ static void
copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n"
: "=r"(dst_0), "=r"(dst_1)
: "l"(__cvta_generic_to_shared(smem_src)));
}
};
// Wrapper around the PTX `ldmatrix.sync.aligned.x4.m8n8.shared.b16` instruction with four
// 32-bit destination registers.
struct SM90_U32x4_LDSM_N {
// `dst_0` .. `dst_3`: receive the four 32-bit results.
// `smem_src`: generic pointer converted to a shared-space address for the instruction.
__device__ __forceinline__ static void
copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) {
asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n"
: "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3)
: "l"(__cvta_generic_to_shared(smem_src)));
}
};
// Emits the PTX `wgmma.fence.sync.aligned` instruction.
// The "memory" clobber also stops the compiler from moving memory accesses across it.
__forceinline__ __device__ void warpgroup_arrive() {
asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory");
}
// Emits the PTX `wgmma.commit_group.sync.aligned` instruction.
// The "memory" clobber also stops the compiler from moving memory accesses across it.
__forceinline__ __device__ void warpgroup_commit_batch() {
asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
}
// Compiler-only fence on one accumulator register: the empty asm statement emits no
// instruction, but declares `reg` as read and written ("+f") and clobbers memory, so the
// compiler cannot reorder or cache uses of `reg` across this point.
__forceinline__ __device__ void warpgroup_fence_operand(float& reg) {
asm volatile("" : "+f"(reg) :: "memory");
}
// Emits the PTX `wgmma.wait_group.sync.aligned N` instruction with `N` as an immediate.
// `N` must be in [0, 7]; anything else is rejected at compile time.
template <int N>
__forceinline__ __device__ void warpgroup_wait() {
DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]");
asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
}
// Builds a `cute::GmmaDescriptor` for a shared-memory operand.
//   smem_ptr            - generic pointer to the operand; converted to a shared-space address
//   layout_type         - value stored verbatim in the `layout_type_` field
//   leading_byte_offset - LBO in bytes (stored divided by 16)
//   stride_byte_offset  - SBO in bytes (stored divided by 16)
// Addresses and offsets are stored in 16-byte units, so their low 4 bits are dropped.
template <class PointerType>
__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type,
const int& leading_byte_offset = 0,
const int& stride_byte_offset = 1024) {
// NOTES: the default LBO and SBO are for K-major types
cute::GmmaDescriptor desc;
// Shared-space address truncated to 32 bits
const auto& uint_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
desc.bitfield.start_address_ = uint_ptr >> 4;
desc.bitfield.layout_type_ = layout_type;
desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4;
desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4;
desc.bitfield.base_offset_ = 0;
return desc;
}
// Number of elements along the inner dimension covered by one atom:
// the whole `BLOCK_INNER` when no swizzling is used (`kSwizzleMode == 0`),
// otherwise the number of `dtype_t` elements that fit in `kSwizzleMode` bytes.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if (kSwizzleMode == 0)
        return BLOCK_INNER;
    return static_cast<uint32_t>(kSwizzleMode / sizeof(dtype_t));
}
// Element stride between consecutive K indices used when building GMMA descriptors:
// 1 for K-major operands, otherwise the inner-block atom size of the MN dimension.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
constexpr uint32_t get_gmma_desc_stride_k() {
    if (kMajorMode == cute::UMMA::Major::K)
        return 1;
    return get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
}
// Maps a swizzle width in bytes to the GMMA descriptor layout type:
//   0 or 16 -> INTERLEAVE, 32 -> B32, 64 -> B64, 128 -> B128.
// Any other swizzle width is rejected at compile time.
template <cute::UMMA::Major kMajorMode, uint32_t kSwizzleMode, typename dtype_t>
constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() {
    DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or
                     kSwizzleMode == 32 or kSwizzleMode == 64 or
                     kSwizzleMode == 128, "Invalid swizzling mode");
    using cute::SM90::GMMA::LayoutType;
    if constexpr (kSwizzleMode == 128) {
        return LayoutType::B128;
    } else if constexpr (kSwizzleMode == 64) {
        return LayoutType::B64;
    } else if constexpr (kSwizzleMode == 32) {
        return LayoutType::B32;
    } else {
        // Remaining valid modes: 0 (no swizzle) and 16
        return LayoutType::INTERLEAVE;
    }
}
// Advances the low 32 bits of a GMMA descriptor (`base`) to element position
// (`mn_idx`, `k_idx`), plus an optional extra element `offset`.
// The element offset is converted to bytes and then to 16-byte units before being added.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) {
    const uint32_t k_stride = get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
    const uint32_t num_elems = offset + mn_idx * BLOCK_K + k_idx * k_stride;
    const uint32_t num_bytes = num_elems * static_cast<uint32_t>(sizeof(dtype_t));
    return base + (num_bytes >> 4u);
}
// Builds the GMMA shared-memory descriptor for the tile position (`mn_idx`, `k_idx`) inside
// the buffer starting at `base_smem_ptr`. The layout type is derived from `kSwizzleMode`,
// and the SBO/LBO byte offsets depend on whether the operand is K-major or MN-major.
template <cute::UMMA::Major kMajorMode, uint32_t BLOCK_MN, uint32_t BLOCK_K, uint32_t kSwizzleMode, typename dtype_t>
__device__ __forceinline__
cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) {
const uint32_t stride_k = get_gmma_desc_stride_k<kMajorMode, BLOCK_MN, kSwizzleMode, dtype_t>();
const auto& layout_type = to_gmma_layout_type<kMajorMode, kSwizzleMode, dtype_t>();
// 8 rows per atom on the non-contiguous dimension
constexpr uint32_t num_non_contiguous = 128 / 16;
if constexpr (kMajorMode == cute::UMMA::Major::K) {
// NOTES: for K-major layout, the swizzle must span exactly one `BLOCK_K` row in bytes (also, atom index must be 0)
DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value");
// Atom size: 8 x `kSwizzleMode` (in bytes, on K)
// {SBO, LBO} means the byte stride between atoms on {MN, K}
// NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0
const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t);
const uint32_t leading_byte_offset = 0;
return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
leading_byte_offset, stride_byte_offset);
} else {
constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size<BLOCK_MN, kSwizzleMode, dtype_t>();
// Must have no in-atom MN-idx
// NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time
DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0);
DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling");
// Atom size: `kSwizzleMode` (in bytes, on MN) x 8
// NOTES: `kSwizzleMode == 16` means non-swizzling but interleaving
// {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling
// {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling
uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t);
uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t);
// For the interleaved (16-byte) mode the roles of SBO and LBO are exchanged
if constexpr (kSwizzleMode == 16)
swap(stride_byte_offset, leading_byte_offset);
return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast<uint32_t>(layout_type),
leading_byte_offset, stride_byte_offset);
}
}
} // namespace `deep_gemm::sm90`

View File

@@ -0,0 +1,92 @@
#pragma once
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/arch/copy_sm100_tma.hpp>
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::tma {
// Number of elements along the inner dimension covered by one TMA atom:
// the whole `BLOCK_INNER` when no swizzling is used (`kSwizzleMode == 0`),
// otherwise the number of `dtype_t` elements that fit in `kSwizzleMode` bytes.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if (kSwizzleMode == 0)
        return BLOCK_INNER;
    return static_cast<uint32_t>(kSwizzleMode / sizeof(dtype_t));
}
// Issues the TMA loads for one `BLOCK_INNER x BLOCK_OUTER` tile into shared memory.
// The tile is split into `BLOCK_INNER / BLOCK_INNER_ATOM` atoms along the inner dimension;
// atom `i` is written at `smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM` and read from inner
// coordinate `inner_idx + i * BLOCK_INNER_ATOM`.
//   desc_ptr          - TMA descriptor passed to the CuTe copy operation
//   barrier_ptr       - barrier passed (as `uint64_t*`) to the CuTe copy operation
//   inner_idx/outer_idx - tile coordinates; `batch_idx` is used only when `kIs3DTMA`
//   num_tma_multicast - 1 selects the plain load; any other value selects the
//                       architecture-specific multicast / 2-CTA path below
// NOTE: on the multicast path nothing is emitted when `__CUDA_ARCH__` is undefined or < 900.
template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
uint32_t kSwizzleMode,
typename dtype_t, bool kIs3DTMA = false>
CUTLASS_DEVICE void
copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
// The SM90 and SM100 cache hints are used interchangeably below, so they must be equal
DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();
if constexpr (not kIs3DTMA) {
// 2D descriptor
if (num_tma_multicast == 1) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
} else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
// 2-CTA function will send signals to the leader CTA only
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
// Only cluster rank 0 issues the multicast; the mask has the low `num_tma_multicast` bits set
if (cute::block_rank_in_cluster() == 0) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
(1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
}
#endif
}
} else {
// 3D descriptor: same structure as above with `batch_idx` as the third coordinate
if (num_tma_multicast == 1) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
} else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
// 2-CTA function will send signals to the leader CTA only
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
// Only cluster rank 0 issues the multicast; the mask has the low `num_tma_multicast` bits set
if (cute::block_rank_in_cluster() == 0) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
(1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
}
#endif
}
}
}
} // namespace deep_gemm::tma

View File

@@ -0,0 +1,116 @@
#pragma once
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/arch/copy_sm100_tma.hpp>
#include <cutlass/arch/barrier.h>
namespace deep_gemm {
// Number of elements along the inner dimension covered by one TMA atom:
// the whole `BLOCK_INNER` when no swizzling is used (`kSwizzleMode == 0`),
// otherwise the number of `dtype_t` elements that fit in `kSwizzleMode` bytes.
template <uint32_t BLOCK_INNER, uint32_t kSwizzleMode, typename dtype_t>
constexpr uint32_t get_inner_block_atom_size() {
    if (kSwizzleMode == 0)
        return BLOCK_INNER;
    return static_cast<uint32_t>(kSwizzleMode / sizeof(dtype_t));
}
// Issues the TMA loads for one `BLOCK_INNER x BLOCK_OUTER` tile into shared memory.
// The tile is split into `BLOCK_INNER / BLOCK_INNER_ATOM` atoms along the inner dimension;
// atom `i` is written at `smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM` and read from inner
// coordinate `inner_idx + i * BLOCK_INNER_ATOM`.
//   desc_ptr          - TMA descriptor passed to the CuTe copy operation
//   barrier_ptr       - barrier passed (as `uint64_t*`) to the CuTe copy operation
//   inner_idx/outer_idx - tile coordinates; `batch_idx` is used only when `kIs3DTMA`
//   num_tma_multicast - 1 selects the plain load; any other value selects the
//                       architecture-specific multicast / 2-CTA path below
// NOTE: on the multicast path nothing is emitted when `__CUDA_ARCH__` is undefined or < 900.
template <uint32_t BLOCK_INNER, uint32_t BLOCK_OUTER,
uint32_t kSwizzleMode,
typename dtype_t, bool kIs3DTMA = false>
__device__ __forceinline__ void
tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr,
dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx,
const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) {
// The SM90 and SM100 cache hints are used interchangeably below, so they must be equal
DG_STATIC_ASSERT(static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL) ==
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint");
constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size<BLOCK_INNER, kSwizzleMode, dtype_t>();
if constexpr (not kIs3DTMA) {
// 2D descriptor
if (num_tma_multicast == 1) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
} else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
// 2-CTA function will send signals to the leader CTA only
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
// Only cluster rank 0 issues the multicast; the mask has the low `num_tma_multicast` bits set
if (cute::block_rank_in_cluster() == 0) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
(1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx);
}
}
#endif
}
} else {
// 3D descriptor: same structure as above with `batch_idx` as the third coordinate
if (num_tma_multicast == 1) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
} else {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000))
// 2-CTA function will send signals to the leader CTA only
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
#elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900))
// Only cluster rank 0 issues the multicast; the mask has the low `num_tma_multicast` bits set
if (cute::block_rank_in_cluster() == 0) {
#pragma unroll
for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) {
cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast<uint64_t*>(barrier_ptr),
(1 << num_tma_multicast) - 1, static_cast<uint64_t>(cute::TMA::CacheHintSm90::EVICT_NORMAL),
smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM,
inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx);
}
}
#endif
}
}
}
// Tensormap related
// Emits the PTX `fence.proxy.tensormap::generic.release.cta` instruction.
// The "memory" clobber (matching the acquire counterpart) keeps the compiler from sinking
// earlier memory writes below the release fence; `volatile` alone does not order ordinary
// memory accesses relative to the asm statement.
__device__ __forceinline__ void tensor_map_release_cta() {
    asm volatile ("fence.proxy.tensormap::generic.release.cta;" ::: "memory");
}
// Emits the PTX `fence.proxy.tensormap::generic.acquire.cta` instruction on the 128 bytes
// at `gmem_desc_ptr`. The "memory" clobber prevents compiler reordering across the fence.
__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) {
// The instruction takes the address as a 64-bit integer operand
auto gmem_int_desc = reinterpret_cast<uint64_t>(gmem_desc_ptr);
asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory");
}
// Replaces the global address field of the tensor map stored at `smem_desc` with `new_addr`,
// using the PTX `tensormap.replace.tile.global_address.shared::cta.b1024.b64` instruction.
__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) {
// Shared-space address truncated to 32 bits ("r" operand)
auto smem_int_desc = static_cast<uint32_t>(__cvta_generic_to_shared(smem_desc));
const auto new_int64_addr = reinterpret_cast<uint64_t>(new_addr);
asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr));
}
// Replaces entry 0 of the global dimension (`new_dim`) and entry 0 of the global stride
// (`new_stride`) of the tensor map stored at `smem_desc`, using the PTX
// `tensormap.replace.tile.global_dim` / `global_stride` instructions.
// Compiling with a CUDA toolkit older than 12.3 triggers a static assertion.
__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) {
// Shared-space address kept at full width ("l" operand)
auto smem_int_desc = __cvta_generic_to_shared(smem_desc);
asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim));
#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3)))
asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride));
#else
DG_STATIC_ASSERT(false, "Invalid CUDA version");
#endif
}
} // namespace `deep_gemm`

View File

@@ -0,0 +1,43 @@
#pragma once
#include <cute/arch/mma_sm100_desc.hpp>
namespace deep_gemm {
// Kind of MMA input data; `get_element_size` gives the per-element size for each kind.
enum class MmaKind {
BF16 = 0,
MXFP8FP4 = 1,
};
// Returns the element size associated with `mma_kind`: 2 for BF16, 1 for MXFP8FP4,
// and 0 for any value that is not a known enumerator.
constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) {
    if (mma_kind == MmaKind::BF16)
        return 2;
    if (mma_kind == MmaKind::MXFP8FP4)
        return 1;
    return 0;
}
// GEMM problem variants. `is_m_grouped_contiguous` treats `MGroupedContiguous` and
// `MGroupedContiguousWithPsumLayout` as the M-grouped contiguous family.
enum class GemmType {
Normal = 0,
MGroupedContiguous = 1,
MGroupedMasked = 2,
KGroupedContiguous = 3,
Batched = 4,
MGroupedContiguousWithPsumLayout = 5,
};
// True for the two M-grouped contiguous variants (`MGroupedContiguous` and
// `MGroupedContiguousWithPsumLayout`), false for every other GEMM type.
constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) {
    return gemm_type == GemmType::MGroupedContiguous or
           gemm_type == GemmType::MGroupedContiguousWithPsumLayout;
}
// Kernel variant selector.
// NOTE(review): the enumerator meanings are not defined in this header; presumably
// 1D1D / 1D2D describe the scaling-factor layout and NoSF means "no scaling factors" - confirm.
enum class KernelType {
Kernel1D1D = 0,
Kernel1D2D = 1,
KernelNoSF = 2
};
} // namespace deep_gemm

View File

@@ -0,0 +1,41 @@
#pragma once
namespace deep_gemm {
// Kind of MMA input data; `get_element_size` gives the per-element size for each kind.
enum class MmaKind {
BF16 = 0,
MXFP8FP4 = 1,
};
// Returns the element size associated with `mma_kind`: 2 for BF16, 1 for MXFP8FP4,
// and 0 for any value that is not a known enumerator.
constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) {
    if (mma_kind == MmaKind::BF16)
        return 2;
    if (mma_kind == MmaKind::MXFP8FP4)
        return 1;
    return 0;
}
// GEMM problem variants. `is_m_grouped_contiguous` treats `MGroupedContiguous` and
// `MGroupedContiguousWithPsumLayout` as the M-grouped contiguous family.
enum class GemmType {
Normal = 0,
MGroupedContiguous = 1,
MGroupedMasked = 2,
KGroupedContiguous = 3,
Batched = 4,
MGroupedContiguousWithPsumLayout = 5,
};
// True for the two M-grouped contiguous variants (`MGroupedContiguous` and
// `MGroupedContiguousWithPsumLayout`), false for every other GEMM type.
constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) {
    return gemm_type == GemmType::MGroupedContiguous or
           gemm_type == GemmType::MGroupedContiguousWithPsumLayout;
}
// Kernel variant selector.
// NOTE(review): the enumerator meanings are not defined in this header; presumably
// 1D1D / 1D2D describe the scaling-factor layout and NoSF means "no scaling factors" - confirm.
enum class KernelType {
Kernel1D1D = 0,
Kernel1D2D = 1,
KernelNoSF = 2
};
} // namespace deep_gemm

View File

@@ -0,0 +1,50 @@
#pragma once
#include <cuda/std/cstdint>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::utils {
// Adapts a callable to array-like syntax: `visitor[i]` evaluates `func(i)`.
template <typename FuncT>
struct PatternVisitor {
// The wrapped callable, stored by value
FuncT func;
// Takes the callable by rvalue reference and forwards it into `func`
CUTLASS_HOST_DEVICE
explicit PatternVisitor(FuncT&& func): func(std::forward<FuncT>(func)) {}
// Returns whatever `func(i)` returns
CUTLASS_HOST_DEVICE
auto operator [](const uint32_t& i) const {
return func(i);
}
};
// Picks a vector type for `kNumBytes` bytes and exposes it as `vec_t`:
//   multiple of 16 -> the type returned by `make_uint4`
//   multiple of 8  -> the type returned by `make_uint2`
//   multiple of 4  -> `int` (the type of the literal `0` returned below)
// Any other size (including 0) fails a static assertion.
template <uint32_t kNumBytes>
struct Vectorized {
// Returns a zero value of the selected vector type; `vec_t` is derived from its return type
static auto zeros() {
// TODO: add `ulonglong4` for SM100 once `__ldg` support this
if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) {
return make_uint4(0, 0, 0, 0);
} else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) {
return make_uint2(0, 0);
} else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) {
return 0;
} else {
DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization");
}
}
using vec_t = decltype(zeros());
};
// Rounds a tensor-memory column count up to the next value in {32, 64, 128, 256, 512}.
// Counts above 512 are rejected at compile time.
template <uint32_t kNumCols>
CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() {
    DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns");
    return kNumCols <= 32  ? 32u  :
           kNumCols <= 64  ? 64u  :
           kNumCols <= 128 ? 128u :
           kNumCols <= 256 ? 256u : 512u;
}
} // namespace deep_gemm::utils

View File

@@ -0,0 +1,137 @@
#pragma once
#include <cute/atom/copy_traits_sm100.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
namespace deep_gemm::epilogue {
// Epilogue store: moves one `BLOCK_M x BLOCK_N` accumulator tile from tensor memory to the
// output via shared memory and TMA.
//
// The tile is processed as `BLOCK_M / STORE_BLOCK_M` M waves times `BLOCK_N / STORE_BLOCK_N`
// stores. For every store the function:
//   1. waits until the shared-memory stage `smem_cd[tma_stage_idx]` is free again;
//   2. loads accumulator values from tensor memory and writes them (cast to BF16 if
//      `cd_dtype_t` is not `float`) into shared memory with an XOR swizzle;
//   3. on the very last store, arrives on `tmem_empty_barrier`;
//   4. lets one elected thread of epilogue warp 0 issue the TMA store (or reduce-add when
//      `kWithAccumulation`), 3D for `GemmType::Batched` and 2D otherwise;
//   5. advances `tma_stage_idx` (modulo `kNumTMAStoreStages`).
//
// Parameters:
//   smem_cd            - visitor yielding the shared-memory buffer of a store stage
//   tma_stage_idx      - in/out: current store stage, shared between successive calls
//   tmem_base_addr     - tensor-memory address of the accumulator
//   base_m_idx/base_n_idx/batch_idx - output coordinates of the tile
//   epilogue_warp_idx, lane_idx     - caller-provided warp and lane indices
//   tmem_empty_barrier - barrier signalled once the accumulator has been fully read
//   tensor_map_cd      - TMA descriptor of the output tensor
template <uint32_t BLOCK_M, uint32_t BLOCK_N,
          uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
          uint32_t kSwizzleCDMode,
          uint32_t kNumTMAStoreStages,
          uint32_t kNumUMMAStoreThreads,
          GemmType kGemmType, bool kWithAccumulation,
          typename cd_dtype_t,
          typename epilogue_type_t,
          typename pattern_cd_t>
CUTLASS_DEVICE void
sm100_store_cd(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
               const uint32_t& tmem_base_addr,
               const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
               const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
               const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
               const cute::TmaDescriptor& tensor_map_cd) {
    // TMA checks
    constexpr uint32_t kNumBankGroupBytes = 16;
    constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
    DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
    DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
    DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
    DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");

    // Share store pipeline between blocks
    auto advance_store_pipeline = [&]() {
        tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
    };

    // Iterate over M waves
    constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M;
    #pragma unroll
    for (uint32_t w = 0; w < kNumMWaves; ++ w) {
        // Issue every swizzled atom and pipeline STSM and TMA store
        constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
        #pragma unroll
        for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
            auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]);

            // Wait shared memory to be released
            if (epilogue_warp_idx == 0)
                cute::tma_store_wait<kNumTMAStoreStages - 1>();
            cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);

            // The pipeline stage
            const auto m_idx = base_m_idx + w * STORE_BLOCK_M;
            // NOTES: `apply_index_n` is a member template of a dependent type, so the
            // `template` disambiguator is required by the C++ standard
            const auto n_idx = epilogue_type_t::template apply_index_n<STORE_BLOCK_N>(base_n_idx + s * STORE_BLOCK_N);

            // Store into shared memory
            #pragma unroll
            for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
                // Calculate the index of the bank group to be written in the atom
                auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);

                // Reshape the atom in another view and swizzle
                //  - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
                //  - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
                // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
                constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
                auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
                auto col = kHasShortcut ? (i) : (bank_group_index % 8);
                col ^= row % (kSwizzleCDMode / 16);

                // Source and destination memory address
                uint32_t tmem_addr = tmem_base_addr +                                   // Accumulator offset
                                     w * BLOCK_N +                                      // Wave offset
                                     s * STORE_BLOCK_N + i * kNumElemsPerBankGroup;     // In-block offset
                auto smem_ptr = smem_base_ptr +                                         // Base pointer
                                epilogue_warp_idx * 32 * kSwizzleCDMode +               // Warp offset
                                row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes;  // In-atom offset

                // Load from tensor memory, store into shared memory
                uint32_t values[kNumElemsPerBankGroup];
                if constexpr (cute::is_same_v<cd_dtype_t, float>) {
                    // For FP32 output, read and store
                    DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
                    cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
                        values[0], values[1], values[2], values[3]);
                    cutlass::arch::fence_view_async_tmem_load();
                    ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
                } else {
                    // For BF16 output, read, cast and store
                    DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
                    cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
                        values[0], values[1], values[2], values[3],
                        values[4], values[5], values[6], values[7]);
                    cutlass::arch::fence_view_async_tmem_load();
                    ptx::st_shared(
                        smem_ptr,
                        math::cast_into_bf16_and_pack(values[0], values[1]),
                        math::cast_into_bf16_and_pack(values[2], values[3]),
                        math::cast_into_bf16_and_pack(values[4], values[5]),
                        math::cast_into_bf16_and_pack(values[6], values[7])
                    );
                }
            }

            // Notify tensor memory empty (only at the leader CTA) arrival ASAP
            // NOTES: only the last stage needs to do this
            if (w == kNumMWaves - 1 and s == kNumStores - 1) {
                ptx::tcgen05_before_thread_sync();
                tmem_empty_barrier->arrive(0u);
            }

            // Synchronize all threads and issue TMA
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
            if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
                if constexpr (kGemmType == GemmType::Batched) {
                    using cute_tma_t = cute::conditional_t<kWithAccumulation,
                        cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
                    cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx);
                } else {
                    using cute_tma_t = cute::conditional_t<kWithAccumulation,
                        cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
                    cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx);
                }
                cute::tma_store_arrive();
            }
            __syncwarp();
        }
    }
}
} // namespace deep_gemm::epilogue

View File

@@ -0,0 +1,144 @@
#pragma once
#include <cute/atom/copy_traits_sm100.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
namespace deep_gemm::epilogue {
template <uint32_t BLOCK_M, uint32_t BLOCK_N,
uint32_t STORE_BLOCK_M, uint32_t STORE_BLOCK_N,
uint32_t kSwizzleCDMode,
uint32_t kNumTMAStoreStages,
uint32_t kNumUMMAStoreThreads,
GemmType kGemmType, bool kWithAccumulation,
typename cd_dtype_t,
typename epilogue_type_t,
typename pattern_cd_t>
CUTLASS_DEVICE void
sm100_store_cd_swap_ab(const utils::PatternVisitor<pattern_cd_t>& smem_cd, uint32_t& tma_stage_idx,
const uint32_t& tmem_base_addr,
const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx,
const uint32_t& effective_m,
const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx,
const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier,
const cute::TmaDescriptor& tensor_map_cd) {
// NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows,
// implying STORE_BLOCK_N must be 128.
DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows");
// TMA checks
constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t);
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumSwizzleAtomRows = 8;
DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled");
DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling");
DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling");
// Share store pipeline between blocks
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Iterate over M blocks
const auto num_stores = effective_m / STORE_BLOCK_M;
for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) {
uint32_t tmem_addr = tmem_base_addr +
s * STORE_BLOCK_M + // Store stage offset
i * kNumSwizzleAtomRows; // In-block offset
uint32_t values[kNumSwizzleAtomRows];
// Warps cooperatively write an atomic block to shared memory
DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes");
constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32;
uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode;
uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode;
auto smem_base_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset;
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// NOTES: Swizzling is not required in this case, but used here for consistency with other cases
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
uint32_t col = lane_idx / 4;
#pragma unroll
for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) {
auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
+ (col ^ row) * kNumBankGroupBytes
+ (lane_idx % 4) * sizeof(float);
ptx::st_shared(reinterpret_cast<uint32_t*>(smem_ptr), values[row]);
}
} else {
// Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements
// Start from lane index 0
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
// Start from lane index 16
cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000,
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
// Destination shared memory address
uint32_t row = lane_idx % 8;
uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8;
auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8)
+ (col ^ row) * kNumBankGroupBytes;
// Store matrix with transposition
ptx::SM90_U32x4_STSM_T<int>::copy(math::cast_into_bf16_and_pack(values[0], values[1]),
math::cast_into_bf16_and_pack(values[2], values[3]),
math::cast_into_bf16_and_pack(values[4], values[5]),
math::cast_into_bf16_and_pack(values[6], values[7]),
smem_ptr);
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
if (s == num_stores - 1) {
ptx::tcgen05_before_thread_sync();
tmem_empty_barrier->arrive(0u);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) {
auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM;
uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M;
uint32_t n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N_ATOM>(base_n_idx + i * STORE_BLOCK_N_ATOM);
// Issue 2D or 3D TMA store
if constexpr (kGemmType == GemmType::Batched) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx);
} else {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx);
}
}
cute::tma_store_arrive();
}
__syncwarp();
}
}
} // namespace deep_gemm::epilogue

View File

@@ -0,0 +1,24 @@
#pragma once
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::epilogue::transform {
// Pass-through epilogue index transform: the N index used for the output store
// is exactly the N index that was computed for the tile.
struct EpilogueIdentity {
    // Returns `n_idx` unchanged. `STORE_BLOCK_N` is unused here; it is only part of
    // the signature so that every transform exposes the same template interface.
    template <uint32_t STORE_BLOCK_N>
    CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
        const uint32_t mapped_n_idx = n_idx;
        return mapped_n_idx;
    }
};
// Epilogue index transform that shifts the N index past `kMid`-wide gaps.
// The number of gaps skipped for a given `n_idx` is
// `(n_idx + kRight) / (kLeft + kRight)` (integer division), i.e. the count goes
// up by one each time `n_idx` reaches `kLeft + j * (kLeft + kRight)`, j = 0, 1, ...
template <uint32_t kLeft, uint32_t kMid, uint32_t kRight>
struct EpilogueHeadSplits: EpilogueIdentity {
    // Maps `n_idx` to `n_idx + num_gaps * kMid`.
    // All three split widths must be multiples of `STORE_BLOCK_N` (checked at compile time).
    template <uint32_t STORE_BLOCK_N>
    CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) {
        constexpr bool kIsLeftAligned = kLeft % STORE_BLOCK_N == 0;
        constexpr bool kIsMidAligned = kMid % STORE_BLOCK_N == 0;
        constexpr bool kIsRightAligned = kRight % STORE_BLOCK_N == 0;
        DG_STATIC_ASSERT(kIsLeftAligned and kIsMidAligned and kIsRightAligned, "Invalid head splits config");

        // Number of `kMid` gaps that precede this index
        const uint32_t num_gaps = (n_idx + kRight) / (kLeft + kRight);
        return n_idx + num_gaps * kMid;
    }
};
} // namespace deep_gemm::epilogue::transform

View File

@@ -0,0 +1,437 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/scheduler/gemm.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/epilogue/sm100_store_cd.cuh>
#include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
// Persistent, warp-specialized BF16 GEMM kernel for SM100.
// The body is only compiled when `__CUDA_ARCH__ >= 1000`; otherwise the kernel just asserts.
//
// Warp roles (see the role dispatch in the body):
//  - warp 0: one elected thread issues TMA loads of A/B tiles into the shared-memory stages;
//  - warp 1 (leader CTA only): issues UMMA instructions that accumulate into tensor memory;
//  - warps `[kNumNonEpilogueThreads / 32, (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32)`:
//    epilogue, delegated to `epilogue::sm100_store_cd` / `epilogue::sm100_store_cd_swap_ab`;
//  - during setup, warp 1 initializes the barriers and warp 2 allocates tensor memory.
//
// Template parameters:
//  - kMajorA, kMajorB: major-ness (K or MN) of operands A and B
//  - SHAPE_M/N/K: compile-time shapes; a value of 0 means "use the runtime `shape_*` argument"
//  - BLOCK_M, BLOCK_N: tile sizes; BLOCK_K_: K tile size before stage merging (must be 64)
//  - kNumGroups, kNumSMs: forwarded to the block scheduler
//  - kSwizzleAMode/BMode/CDMode: swizzle modes forwarded to the TMA copies, UMMA descriptors and the store
//  - kNumStages_: number of load pipeline stages before stage merging
//  - kNumNonEpilogueThreads, kNumEpilogueThreads: thread counts; their sum is the launch bound
//  - kNumMulticast: 1 or 2; selects the 1-SM or 2-SM TMEM allocator and MMA instruction
//  - kIsMulticastOnA: whether the A (true) or B (false) load block is split across the multicast CTAs
//  - kSwapAB: issue the UMMA with A and B exchanged (UMMA N then follows the M tile)
//  - kGemmType: GEMM variant, forwarded to the scheduler and the store
//  - kWithAccumulation: forwarded to the store; requires `cd_dtype_t == float`
//  - cd_dtype_t: element type of the C/D shared-memory staging buffer
//  - kTensorCoreUtilControl: must be > 0; values below 100 insert a busy-wait after every K block
//
// Runtime parameters:
//  - grouped_layout: forwarded to the block scheduler
//  - shape_m, shape_n, shape_k: runtime shapes (overridden by non-zero SHAPE_* constants)
//  - tensor_map_a / tensor_map_b / tensor_map_cd: TMA descriptors for A, B and C/D
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages_,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
bool kSwapAB,
GemmType kGemmType, bool kWithAccumulation, typename cd_dtype_t,
uint64_t kTensorCoreUtilControl>
CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_bf16_gemm_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
// Enlarge `BLOCK_K` for some cases
// NOTES: this is for reducing the `umma_arrive()` overhead
constexpr bool kDoMergeStages =
kNumStages_ >= 8 and kGemmType == GemmType::Normal and
kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K;
// Ensure there are at least `kNumMinStages` stages after merge
// NOTES: `kNumStages_ / kNumMinStages` original stages are fused into one, so `BLOCK_K` grows
// and `kNumStages` shrinks by the same factor
constexpr uint32_t kNumMinStages = 8;
constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1;
constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge;
constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
// MMA Configs
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
constexpr uint32_t UMMA_K = 16;
// Per-CTA load tile: the multicast operand is split evenly across the `kNumMulticast` CTAs
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K");
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
(not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
// Epilogue configs
// Always enable pipeline for better performance
constexpr uint32_t kNumEpilogueStages = 2;
constexpr uint32_t kNumTMAStoreStages = 2;
// NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
// per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
// Number of epilogue threads that actually take part in the store (one per stored row when not swapped)
constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
// NOTES: the check below also covers the total C/D staging size, not only A/B
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
"Shared memory of A/B must be aligned to 1024 bytes");
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
// NOTES: Make sure we have enough shared memory for UMMA padding
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA");
// Real tensor memory size and offsets
// NOTES: one accumulator of `UMMA_N` columns per epilogue stage
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols>();
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Synchronize the cluster before 2-CTA TMEM allocation
kNumMulticast > 1 ? cute::cluster_sync() : void();
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_cd);
}
// Override the runtime shapes with the compile-time constants when those are given (non-zero)
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// D/A/B shared memory
// Layout of `smem_buffer`: [C/D store stages][A load stages][B load stages][barriers][TMEM pointer]
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// Fill barriers
// Barrier slots, in units of `Barrier` from `barrier_start_ptr`:
//  - `[0, kNumStages)`: full barriers (TMA -> MMA)
//  - `[kNumStages, 2 * kNumStages)`: empty barriers (MMA -> TMA)
//  - `[2 * kNumStages, 2 * kNumStages + kNumEpilogueStages)`: TMEM full barriers (MMA -> epilogue)
//  - next `kNumEpilogueStages` slots: TMEM empty barriers (epilogue -> MMA)
//  - slot `3 * kNumStages + 2 * kNumEpilogueStages`: tensor core utilization control barrier
// NOTE(review): the `kNumStages` slots between the TMEM empty barriers and the utilization control
// barrier look unused in this kernel; the host-side shared memory sizing is not visible here -- confirm
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); });
auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2;
// Fill the tensor memory pointer
// NOTES: the pointer slot is placed right after the utilization control barrier
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1);
// NOTE(review): this repeats the identical check made where `kNumTmemCols` is defined
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive only at the leader CTA
full_barriers[i]->init(kNumMulticast);
// Arrive at all CTAs
empty_barriers[i]->init(1);
}
#pragma unroll
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
// Arrive at all CTAs
tmem_full_barriers[i]->init(1);
// Arrive only at the leader CTA
// NOTES: the expected count is one arrival per store thread per multicast CTA
tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
}
// The utilization control barrier is only used (and thus only initialized) when throttling is enabled
if constexpr (kTensorCoreUtilControl < 100)
tensor_core_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
// Make the barrier initialization and the TMEM allocation visible to all threads (of the whole cluster if multicast)
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
// NOTES: every warp role below owns a copy of the scheduler and iterates it independently
uint32_t m_block_idx, n_block_idx;
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(
shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
// NOTES: `stage_idx` / `phase` persist across tiles, so the load pipeline is shared between blocks
uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// Flip phases only if reach the next first stage
stage_idx = (stage_idx + 1) % kNumStages;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Use dynamic load block M, when swap-AB is enabled
const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
// For k-grouped layout, the number of block K is variable
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
// NOTES: waiting on the inverted phase lets the first pass over the stages proceed immediately
empty_barriers[stage_idx]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
kMajorA == cute::UMMA::Major::K, "Invalid major");
// NOTE(review): `k_idx` is not read anywhere below; only `k_a_idx` / `k_b_idx` are passed to the TMA copies
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
// NOTES: each CTA of the cluster loads its own slice of the multicast operand
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
// NOTES: exactly one of the two `if constexpr` branches per operand is compiled, depending on its major-ness
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx);
// Arrive at full barriers
// NOTES: the leader CTA registers the expected bytes of all multicast CTAs, the other CTA only arrives
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (is_leader_cta) {
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast);
} else {
full_barriers[stage_idx]->arrive(0u);
}
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
// NOTES: with swap-AB, the A/B major-ness arguments are exchanged as well
auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
UMMA_M, UMMA_N, kMajorB, kMajorA>()
: cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
// NOTES: lane `i` caches the descriptor low word of stage `i` (see `a_desc_lo` / `b_desc_lo` below),
// which is later broadcast by `__shfl_sync`, hence at most 32 stages
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
// Merged stages only happens in NT normal GEMM cases
// NOTES: `BLOCK_ATOM_K` equals the unmerged `BLOCK_K_`
constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Wait tensor memory empty barrier arrival
// NOTES: accumulators alternate between the `kNumEpilogueStages` TMEM stages, one per scheduled tile
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
ptx::tcgen05_after_thread_sync();
// UMMA and empty barrier arrival alias
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
// Releases the current load stage; on the last K block it also signals the epilogue that the accumulator is ready
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
__syncwarp();
};
// Dynamic update of UMMA N based on effective M, when swap-AB is enabled
if constexpr (kSwapAB) {
uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
}
// Launch MMAs
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
ptx::tcgen05_after_thread_sync();
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, ptx::SM100_MMA_F16BF16_SS, ptx::SM100_MMA_F16BF16_2x1SM_SS>;
const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
// Fetch the descriptor base of the current stage from the lane that cached it
const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
// Index of the (unmerged) K atom that this UMMA step falls into
uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K;
a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t>(
a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t>(
b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K);
// NOTES: the fourth argument is false only for the very first UMMA of a tile (`k_block_idx == 0 and k == 0`)
if (kSwapAB) {
mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc);
} else {
mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc);
}
}
}
__syncwarp();
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
// Let tensor cores relax for lower possibility of frequency drop
DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control");
if constexpr (kTensorCoreUtilControl < 100) {
// For utilization control
umma_arrive(reinterpret_cast<uint64_t*>(tensor_core_full_barrier));
__syncwarp();
// Wait for last UMMA to be done
tensor_core_full_barrier->wait(tensor_core_phase);
tensor_core_phase ^= 1;
// Sleep for certain cycles
// NOTES: `kNumUMMACycles` is an estimate of the cycles of one K block, and the dummy cycles
// scale it by `(100 - kTensorCoreUtilControl) / kTensorCoreUtilControl`
constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull;
constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl;
const auto start_clock = clock64();
if (cute::elect_one_sync())
while (clock64() - start_clock < kNumDummyCycles) {}
__syncwarp();
}
}
}
// To safely deconstruct barriers, we need another round of waits
// NOTE(review): `iter_idx >= 0` is only meaningful if `scheduler.current_iter` is a signed type; confirm in the scheduler
const auto iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
// Epilogue warp groups
// NOTES: only the first `kNumUMMAStoreThreads` epilogue threads enter this branch
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
ptx::tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
// NOTES: `tmem_base_addr` is the column offset of this accumulator stage (the TMEM base is asserted to be 0 above)
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
const auto base_m_idx = scheduler.template get_global_idx<
(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto base_n_idx = n_block_idx * BLOCK_N;
// The store routines also receive the TMEM empty barrier of this accumulator stage
if constexpr (kSwapAB) {
const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
epilogue::sm100_store_cd_swap_ab<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue::transform::EpilogueIdentity>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
effective_m,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
} else {
epilogue::sm100_store_cd<BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue::transform::EpilogueIdentity>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
}
}
}
// TODO: Remove redundant synchronization
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Deallocate tensor memory
// NOTES: the allocation base is 0 (asserted in the epilogue branch), hence `free(0, ...)`
if (warp_idx == 0)
Allocator().free(0, kNumTmemCols);
#else
// Compiled for an architecture below SM100: fail loudly from a single thread
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,271 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cute/util/type_traits.hpp>
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kSwizzleABMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages, uint32_t kNumThreads>
CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1)
sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Configs
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t kNumTMAStoreStages = 2;
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = ptx::get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size");
DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode");
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t);
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_N>();
// Fill D/A/B
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE));
});
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2);
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 2 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(1);
}
tmem_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
// Block indices
const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (warp_idx == 0) {
// TMA load warp
for (uint32_t s = 0; s < num_total_stages; ++ s) {
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
uint32_t m_idx = BLOCK_M * m_block_idx;
uint32_t n_idx = BLOCK_N * n_block_idx;
uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
uint32_t k_idx = sk_idx % SHAPE_K;
uint32_t s_idx = sk_idx / SHAPE_K;
// Issue TMAs
if (cute::elect_one_sync()) {
tma::copy<BLOCK_K, BLOCK_M, kSwizzleABMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleABMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N);
}
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
if (cute::elect_one_sync())
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
} else if (warp_idx == 1) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
constexpr uint32_t UMMA_M = LAYOUT_AD_M;
constexpr uint32_t UMMA_N = BLOCK_N;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t);
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::bfloat16_t, cutlass::bfloat16_t, float, UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_M, BLOCK_K, kSwizzleABMode>(smem_a[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, BLOCK_N, BLOCK_K, kSwizzleABMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Wait tensor memory empty barrier arrival
ptx::tcgen05_after_thread_sync();
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
ptx::tcgen05_after_thread_sync();
// Issue UMMA in the leader CTA
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx);
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx);
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
a_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_M, kSwizzleABMode, cutlass::bfloat16_t>(
a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = mma::sm100::advance_umma_desc_lo<cute::UMMA::Major::K, BLOCK_N, kSwizzleABMode, cutlass::bfloat16_t>(
b_desc_base_lo, 0, k * UMMA_K);
ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc);
}
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
}
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
}
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
if (warp_idx == 2)
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Wait UMMA arrival
tmem_full_barrier->wait(0);
ptx::tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s) {
// Wait shared memory to be released
if (s >= kNumTMAStoreStages) {
if (warp_idx == 0 and cute::elect_one_sync())
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier(kNumThreads).sync();
}
// The pipeline stage
const auto tma_stage_idx = s % kNumTMAStoreStages;
const auto m_idx = m_block_idx * BLOCK_M;
const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N;
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier(kNumThreads).sync();
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx);
cute::tma_store_arrive();
}
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is doing TMA stores
if (warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
}

View File

@@ -0,0 +1,457 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
// Warp-specialized kernel computing MQA logits from FP4 (E2M1) Q/KV operands with UE8M0 scale factors.
//
// Each CTA walks Q blocks `sm_idx, sm_idx + kNumSMs, ...` (`BLOCK_Q` query rows each). For one Q block it
// visits the KV range covering its rows in tiles of `BLOCK_KV` rows, runs block-scaled UMMAs
// (A = KV tile, B = Q tile, FP32 accumulators in tensor memory), and for every (query row, KV row) pair
// stores `sum over heads of max(acc[h], 0) * weights[row][h]` into `logits`.
//
// Warp roles (`kSpecWarpStart = kNumMathWarpGroups * 4`):
//   - warps `[0, kSpecWarpStart)`: math warpgroups; read accumulators from tensor memory, reduce, store logits,
//     and finally free tensor memory (warp 0);
//   - warp `kSpecWarpStart`: issues TMA loads of Q, SF Q and weights;
//   - warp `kSpecWarpStart + 1`: initializes all barriers, then issues TMA loads of KV and SF KV;
//   - warp `kSpecWarpStart + 2`: allocates tensor memory, transposes scale factors in shared memory,
//     issues UTCCP copies and UMMAs;
//   - warp `kSpecWarpStart + 3`: only releases registers.
//
// Parameters:
//   - `seq_len`: number of query rows; row indices read from the `cu_seq_len_k_*` arrays are clamped to `seq_len - 1`;
//   - `seq_len_kv`: upper clamp applied to every per-row KV start/end;
//   - `max_seqlen_k`: not referenced by this kernel body;
//   - `logits_stride`: element stride between consecutive query rows of `logits`;
//   - `cu_seq_len_k_start` / `cu_seq_len_k_end`: per-query-row KV start / end indices;
//   - `logits`: output buffer; with `kIsCompressedLogits` each row is stored relative to its own KV start
//     and only in-range positions are written, otherwise positions are absolute KV indices;
//   - `tensor_map_*`: TMA descriptors for Q, SF Q, KV, SF KV and weights.
//
// Compile-time requirements (see the static assertions below): `kNumSpecializedThreads == 128`,
// `kNumMathThreads % 128 == 0`, `BLOCK_KV == kNumMathWarpGroups * 128`, `kHeadDim == 128`,
// `kNumHeads / 2` is 32 or 64, and the tensor memory footprint is at most 512 columns.
//
// NOTE(review): the non-compressed store path and the `(q_idx * BLOCK_Q + i)` row offset are not bounds-checked
// here; presumably the host pads the `logits` allocation accordingly -- confirm against the launcher.
template <uint32_t kNumHeads, uint32_t kHeadDim,
          bool kIsCompressedLogits,
          uint32_t BLOCK_Q, uint32_t BLOCK_KV,
          uint32_t kNumQStages, uint32_t kNumKVStages,
          uint32_t kNumSMs,
          uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
          typename logits_dtype_t,
          uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
                          const uint32_t max_seqlen_k,
                          const uint32_t logits_stride,
                          const uint32_t* cu_seq_len_k_start,
                          const uint32_t* cu_seq_len_k_end,
                          logits_dtype_t* logits,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_q,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
                          const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
    using Barrier = cutlass::arch::ClusterTransactionBarrier;

    // Utils
    const auto sm_idx = blockIdx.x;
    const auto warp_idx = cutlass::canonical_warp_idx_sync();
    const auto warpgroup_idx = warp_idx / 4;
    const auto lane_idx = ptx::get_lane_idx();
    // Index of the first specialized (non-math) warp
    constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;

    // Prefetch TMA descriptors
    if (warp_idx == kSpecWarpStart) {
        cute::prefetch_tma_descriptor(&tensor_map_q);
        cute::prefetch_tma_descriptor(&tensor_map_sf_q);
        cute::prefetch_tma_descriptor(&tensor_map_weights);
        cute::prefetch_tma_descriptor(&tensor_map_kv);
        cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
    }

    // UMMA configs
    // NOTES: `kNumSFQ` / `kNumSFKV` are the scale factor counts padded to the UTCCP chunk size (128 elements),
    // while `kRealNumSFQ` is the unpadded count actually transferred by TMA
    static constexpr uint32_t kNumTmemStages = 3;
    static constexpr uint32_t kNumUTCCPAlignedElems = 128;
    static constexpr uint32_t UMMA_M = 128;
    static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
    static constexpr uint32_t UMMA_K = 64;
    static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems);
    static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems);
    static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads;
    DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
    DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`");

    // Shared memory configs
    // NOTES: FP4 packs two elements per byte, hence `kHeadDim / 2` bytes per row
    static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
    static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2);
    static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int);
    static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2);
    static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
    static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);

    // Align to swizzling alignment bytes
    extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
    DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");

    // Q and KV data on shared memory
    // Layout: [Q stages][KV stages][SF Q stages][SF KV stages][weight stages][barriers][TMEM pointer]
    auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
        return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
    });
    auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
        return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
    });
    const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
    auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
    });
    auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
    });
    auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
                                        + SMEM_WEIGHT_SIZE_PER_STAGE * i);
    });

    // Barriers and TMEM pointer on shared memory
    // NOTES: the barrier array starts right after the last weight stage
    const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
    auto full_q_barriers   = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
    auto empty_q_barriers  = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
    auto full_kv_barriers  = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
    auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
    const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
    auto full_tmem_barriers  = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
    auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
    auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);

    // Tensor memory configs
    // Column layout: [accumulator stages][SF Q][SF KV]
    constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages;
    constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQ / 32 + kNumSFKV / 32>();
    constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
    constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32;
    DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");

    // Initialize barriers
    if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
        #pragma unroll
        for (uint32_t i = 0; i < kNumQStages; ++ i) {
            full_q_barriers[i]->init(1);
            // Released by every math thread plus all 32 lanes of the UMMA warp
            empty_q_barriers[i]->init(kNumMathThreads + 32);
        }
        #pragma unroll
        for (uint32_t i = 0; i < kNumKVStages; ++ i) {
            full_kv_barriers[i]->init(1);
            empty_kv_barriers[i]->init(1);
        }
        #pragma unroll
        for (uint32_t i = 0; i < kNumTmemStages; ++i) {
            full_tmem_barriers[i]->init(1);
            // Released by the 128 threads of the math warpgroup owning the stage
            empty_tmem_barriers[i]->init(128);
        }
        cutlass::arch::fence_barrier_init();
    }

    // Allocate tensor memory
    if (warp_idx == kSpecWarpStart + 2)
        cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
    __syncthreads();

    // Scheduler
    const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
    uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
    // Returns `(kv_start, num_kv_blocks)` for a Q block: the union of the per-row KV ranges, with the start
    // rounded down to a multiple of 4; also fills `seq_k_start` / `seq_k_end` with the clamped per-row ranges
    // NOTES: the KV TMA warp, the UMMA warp and the math warps all call this and must agree on the result
    auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple<uint32_t, uint32_t> {
        uint32_t start = cute::numeric_limits<uint32_t>::max();
        uint32_t end = cute::numeric_limits<uint32_t>::min();
        #pragma unroll
        for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
            // Rows past the end of the sequence reuse the schedule of the last valid row
            const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1);
            seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv);
            seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv);
            start = cute::min(start, seq_k_start[i]);
            end = cute::max(end, seq_k_end[i]);
        }
        // TMA alignment requirements for SF KV
        start = start / 4 * 4;
        // Guard the unsigned subtraction: if every row of the block has an empty (or inverted) range,
        // `end - start` would wrap around and schedule a huge number of bogus KV blocks
        const uint32_t num_kv_blocks = end > start ? math::ceil_div(end - start, BLOCK_KV) : 0u;
        return {start, num_kv_blocks};
    };

    // Make Q, KV and TMEM pipeline
    auto make_pipeline = [](const uint32_t& num_stages) {
        // Return current stage and phase, and advance pipeline by steps
        return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
            uint32_t current_idx = iter_idx;
            iter_idx += step;
            return {current_idx % num_stages, (current_idx / num_stages) & 1};
        };
    };
    // NOTES: every thread keeps its own private copy of the pipeline counters
    auto advance_q_pipeline = make_pipeline(kNumQStages);
    auto advance_kv_pipeline = make_pipeline(kNumKVStages);
    auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);

    // Register reconfigurations
    constexpr uint32_t kNumSpecializedRegisters = 56;
    constexpr uint32_t kNumMathRegisters = 224;

    // Wait for primary kernel completion
    cudaGridDependencySynchronize();

    if (warp_idx == kSpecWarpStart) {
        // TMA warp for loading Q
        cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();

        // Enumerate Q blocks
        if (cute::elect_one_sync()) {
            for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
                // Wait Q consumer release
                CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
                empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);

                // Issue TMA Q
                cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
                                             static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                             smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads);
                tma::copy<BLOCK_Q * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q);
                tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q);
                // Expected bytes: Q data + unpadded SF Q + weights
                full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
            }
        }
        __syncwarp();
    } else if (warp_idx == kSpecWarpStart + 1) {
        // TMA warp for loading KV cache
        cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();

        if (cute::elect_one_sync()) {
            // Enumerate Q blocks
            for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
                // Load KV block ranges
                CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);

                // Enumerate KV blocks
                for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
                    // Wait KV consumer release
                    CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
                    empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);

                    // Issue TMA KV
                    cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
                                                 static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
                                                 smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV);
                    tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
                                              smem_sf_kv[kv_stage_idx],
                                              kv_start + kv_idx * BLOCK_KV, 0);
                    full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
                }
            }
        }
    } else if (warp_idx == kSpecWarpStart + 2) {
        // UMMA warp
        cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
        // The code below addresses tensor memory from column 0, so the allocation must start there
        DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);

        // UTCCP transposer
        // In-place warp-level reorder of one 128-word chunk: the word at `row * 32 + lane` (4 rows of 32)
        // moves to `lane * 4 + row`; each lane visits its 4 rows in an order XOR-ed with `lane_idx >> 3`
        // NOTE(review): the XOR presumably avoids shared memory bank conflicts -- confirm
        auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
            DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
            uint32_t values[4];
            #pragma unroll
            for (uint32_t i = 0; i < 4; ++ i)
                values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
            // All loads must finish before any lane overwrites the chunk
            __syncwarp();
            #pragma unroll
            for (uint32_t i = 0; i < 4; ++ i)
                ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
        };

        // Make UMMA desc
        auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
                                                                   UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
        auto sf_desc = mma::sm100::make_sf_desc(nullptr);

        // Enumerate Q blocks
        for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
            // Load KV block ranges
            CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);

            // Wait TMA Q arrivals
            CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
            full_q_barriers[q_stage_idx]->wait(q_phase);

            // Transpose and copy SF Q
            #pragma unroll
            for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) {
                auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
                utccp_required_smem_warp_transpose(smem_ptr);
                cutlass::arch::fence_view_async_shared();
                mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
                if (cute::elect_one_sync())
                    cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
                __syncwarp();
            }

            // Enumerate KV blocks
            for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
                // Wait TMA KV arrivals
                CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
                full_kv_barriers[kv_stage_idx]->wait(kv_phase);

                // Transpose
                #pragma unroll
                for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
                    auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
                    utccp_required_smem_warp_transpose(smem_ptr);
                    cutlass::arch::fence_view_async_shared();
                }

                // UMMA with SF
                if (cute::elect_one_sync()) {
                    // Copy SF KV
                    #pragma unroll
                    for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
                        auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
                        mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
                        cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
                    }

                    // One accumulator stage per math warpgroup, i.e., per `UMMA_M` rows of the KV tile
                    #pragma unroll
                    for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
                        // Wait TMEM release
                        CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
                        uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
                        empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
                        ptx::tcgen05_after_thread_sync();

                        // Issue UMMA with SF
                        #pragma unroll
                        for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
                            auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
                            // TODO: generalize umma desc
                            DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
                            auto a_desc = mma::sm100::make_smem_desc(
                                cute::UMMA::LayoutType::SWIZZLE_64B,
                                smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
                                8 * (kHeadDim / 2), 0);
                            auto b_desc = mma::sm100::make_smem_desc(
                                cute::UMMA::LayoutType::SWIZZLE_64B,
                                smem_q[q_stage_idx] + k * UMMA_K / 2,
                                8 * (kHeadDim / 2), 0);
                            // The 4th argument (`k`) is zero for the first K step and non-zero afterwards
                            ptx::SM100_MMA_MXF4_SS::fma(
                                a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
                                kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
                        }
                        // Signal the math warpgroup once the UMMAs issued so far complete
                        // TODO: move this into `deep_gemm/ptx/tcgen05.cuh`
                        asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                                     ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
                    }
                }
                // Release the KV stage back to the KV TMA warp
                cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
            }

            // UMMA warp must also arrive on empty_q to prevent running ahead
            // of math warps in the Q pipeline. Without this, UMMA can consume
            // kNumQStages Q blocks before math warps release any, causing a
            // circular dependency: UMMA waits full_q -> TMA_Q waits empty_q
            // -> Math waits full_tmem -> UMMA (already moved on).
            empty_q_barriers[q_stage_idx]->arrive();
        }
    } else if (warp_idx == kSpecWarpStart + 3) {
        // Idle warp: only gives its registers back
        cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
    } else if (warp_idx < kSpecWarpStart) {
        // Math warpgroups for reduce
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        const auto math_warpgroup_idx = warpgroup_idx;
        // NOTES: math warps come first in the CTA, so `threadIdx.x` is in `[0, kNumMathThreads)`, i.e., the row inside the KV tile
        const auto math_thread_idx = threadIdx.x;

        // Helper lambda for loading tensor memory
        // Loads `N` (32 or 64) consecutive 32-bit columns starting at `tmem_addr` into `accum[0 .. N)`
        auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
            constexpr uint32_t N = decltype(num_elems_c)::value;
            DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
            using Loader = cute::conditional_t<N == 32,
                                               cute::SM100_TMEM_LOAD_32dp32b32x,
                                               cute::SM100_TMEM_LOAD_32dp32b64x>;
            [&]<size_t... Is>(cute::index_sequence<Is...>) {
                Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
            }(cute::make_index_sequence<N>{});
            cutlass::arch::fence_view_async_tmem_load();
        };

        // Math warpgroups process TMEM stages alternately
        // Advance pipeline to align with the assigned stage
        advance_tmem_pipeline(math_warpgroup_idx);

        // Local register buffers
        float accum[kNumHeads];
        float weights[BLOCK_Q][kNumHeads];

        // Enumerate Q blocks
        for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) {
            // Load KV block ranges
            CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks);

            // Wait TMA Q arrivals
            CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
            full_q_barriers[q_stage_idx]->wait(q_phase);

            // Read weights
            // TODO: optimize bank conflicts
            #pragma unroll
            for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
                #pragma unroll
                for (uint32_t j = 0; j < kNumHeads; ++ j)
                    weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
            }

            // Enumerate KV blocks
            for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) {
                // Calculate KV offset in advance
                auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx;

                // Advance pipeline by `kNumMathWarpGroups` steps
                // Wait UMMA arrival
                CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
                full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
                ptx::tcgen05_after_thread_sync();

                // Reduce over the head dim and store
                #pragma unroll
                for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
                    // Load accumulator from TMEM
                    uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
                    tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
                    tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);

                    // Release TMEM empty
                    // NOTES: done right after the last load so the next UMMA can overlap with the reduction
                    if (i == BLOCK_Q - 1) {
                        ptx::tcgen05_before_thread_sync();
                        empty_tmem_barriers[tmem_stage_idx]->arrive();
                    }

                    // Accumulate weighted ReLU in parallel
                    // NOTES: two independent packed `float2` partial sums, 4 heads per iteration
                    auto sum_0 = make_float2(0, 0);
                    auto sum_1 = make_float2(0, 0);

                    const auto transform = [&](const uint32_t& j, const float2& sum) {
                        auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
                        auto b = make_float2(weights[i][j], weights[i][j + 1]);
                        return __ffma2_rn(a, b, sum);
                    };

                    #pragma unroll
                    for (uint32_t j = 0; j < kNumHeads; j += 4) {
                        sum_0 = transform(j, sum_0);
                        sum_1 = transform(j + 2, sum_1);
                    }

                    auto sum = __fadd2_rn(sum_0, sum_1);
                    auto result = static_cast<logits_dtype_t>(sum.x + sum.y);

                    // Store into the global memory
                    // NOTES: we have redundant writes here, consider more carefully
                    // TODO: optimize performance
                    const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast<uint64_t>(logits_stride);
                    if constexpr (kIsCompressedLogits) {
                        // Row-relative position, only inside the row's own KV range
                        if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
                            logits[q_offset + kv_offset - seq_k_start[i]] = result;
                    } else {
                        logits[q_offset + kv_offset] = result;
                    }
                    __syncwarp();
                }
            }

            // Release last Q empty
            empty_q_barriers[q_stage_idx]->arrive();
        }

        // Free tensor memory
        // NOTES: all math threads must have finished reading tensor memory first
        cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
        if (warp_idx == 0)
            cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
    }
}
} // namespace deep_gemm

View File

@@ -0,0 +1,510 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
typename logits_dtype_t,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp4_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Utils
const auto sm_idx = blockIdx.x;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
// Prefetch TMA descriptors
if (warp_idx == kSpecWarpStart) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_sf_q);
cute::prefetch_tma_descriptor(&tensor_map_weights);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_sf_kv);
}
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
// UMMA configs
static constexpr uint32_t kNumTmemStages = 3;
static constexpr uint32_t kNumUTCCPAlignedElems = 128;
static constexpr uint32_t UMMA_M = 128;
static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
static constexpr uint32_t UMMA_K = 64;
static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems);
static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems);
static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads;
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`");
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2);
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2);
static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2);
static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q and KV data on shared memory
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i;
});
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i;
});
const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages);
auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i);
});
auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i);
});
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages
+ SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
// Barriers and TMEM pointer on shared memory
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; });
auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(tmem_barrier_ptr + kNumTmemStages * 2);
// Tensor memory configs
constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFQAtom / 32 + kNumSFKV / 32>();
constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
// Initialize barriers
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads + 32);
}
cutlass::arch::fence_barrier_init();
}
if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(1);
}
cutlass::arch::fence_barrier_init();
}
if (warp_idx == kSpecWarpStart + 2) {
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumTmemStages; ++i) {
full_tmem_barriers[i]->init(1);
empty_tmem_barriers[i]->init(128);
}
cutlass::arch::fence_barrier_init();
}
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Scheduler
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Make Q, KV and TMEM pipeline
auto make_pipeline = [](const uint32_t& num_stages) {
// Return current stage and phase, and advance pipeline by steps
return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple<uint32_t, uint32_t> {
uint32_t current_idx = iter_idx;
iter_idx += step;
return {current_idx % num_stages, (current_idx / num_stages) & 1};
};
};
auto advance_q_pipeline = make_pipeline(kNumQStages);
auto advance_kv_pipeline = make_pipeline(kNumKVStages);
auto advance_tmem_pipeline = make_pipeline(kNumTmemStages);
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 56;
constexpr uint32_t kNumMathRegisters = 224;
if (warp_idx == kSpecWarpStart) {
// TMA warp for loading Q
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
if (cute::elect_one_sync()) {
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
// Persistently schedule over blocks
// Initialize outside valid range to indicate no previous task
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, _, __;
while (scheduler.fetch_next_task(q_atom_idx, _, __)) {
// Issue TMA Q when (q_idx, atom_idx) changes
if (q_atom_idx != last_q_atom_idx) {
// Wait Q consumer release
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
// Issue TMA Q
const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx);
cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast<uint64_t*>(full_q_barriers[q_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_q[q_stage_idx], 0, q_token_idx * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx);
tma::copy<kNumHeads, kNextNAtom, 0>(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx);
full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE);
}
last_q_atom_idx = q_atom_idx;
}
}
__syncwarp();
} else if (warp_idx == kSpecWarpStart + 1) {
// TMA warp for loading KV cache
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
// Persistently schedule over blocks
uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage;
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, num_kv;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) {
// Reset block table cache on kv restart
if (q_atom_idx != last_q_atom_idx)
kv_block_idx_ptr = 32;
last_q_atom_idx = q_atom_idx;
// Coalesced load of block table
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
__syncwarp();
// Broadcast KV block indices
int kv_block_idx[kNumBlocksPerSplit];
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i)
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
kv_block_idx_ptr += kNumBlocksPerSplit;
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`");
// Wait KV consumer release
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
// Issue TMA KV
if (cute::elect_one_sync()) {
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
#pragma unroll
for (int i = 0; i < kNumBlocksPerSplit; ++ i) {
cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast<uint64_t*>(full_kv_barriers[kv_stage_idx]),
static_cast<uint64_t>(cute::TMA::CacheHintSm100::EVICT_NORMAL),
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i,
0, 0, kv_block_idx[i]);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx],
smem_sf_kv[kv_stage_idx] + BLOCK_KV * i,
0, kv_block_idx[i]);
}
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE);
}
}
} else if (warp_idx == kSpecWarpStart + 2) {
// UMMA warp
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// UTCCP transposer
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
// Make UMMA desc
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<cutlass::float_e2m1_t, cutlass::float_e2m1_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
// Persistently schedule over blocks
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, _;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
// Wait TMA Q arrivals
uint32_t q_stage_idx, q_phase;
if (q_atom_idx != last_q_atom_idx) {
CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase);
// Release previous Q empty (UMMA warp must participate to prevent
// running ahead of math warps in the Q pipeline)
if (last_q_atom_idx != batch_size * kNumNextNAtoms)
empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
full_q_barriers[q_stage_idx]->wait(q_phase);
// Transpose and copy SF Q
#pragma unroll
for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
if (cute::elect_one_sync())
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4);
__syncwarp();
}
}
last_q_atom_idx = q_atom_idx;
// Wait TMA KV arrivals
CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Transpose
#pragma unroll
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
utccp_required_smem_warp_transpose(smem_ptr);
cutlass::arch::fence_view_async_shared();
}
// UMMA with SF
if (cute::elect_one_sync()) {
// Copy SF KV
#pragma unroll
for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4);
}
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
// Wait TMEM release
CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase);
uint32_t tmem_addr = tmem_stage_idx * UMMA_N;
empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1);
ptx::tcgen05_after_thread_sync();
// Issue UMMA with SF
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2);
// TODO: generalize UMMA desc
DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim");
auto a_desc = mma::sm100::make_smem_desc(
cute::UMMA::LayoutType::SWIZZLE_64B,
smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2,
8 * (kHeadDim / 2), 0);
auto b_desc = mma::sm100::make_smem_desc(
cute::UMMA::LayoutType::SWIZZLE_64B,
smem_q[q_stage_idx] + k * UMMA_K / 2,
8 * (kHeadDim / 2), 0);
ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc,
kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ);
}
// TODO: move this PTX into headers
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx])));
}
}
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_kv_barriers[kv_stage_idx]));
}
} else if (warp_idx == kSpecWarpStart + 3) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
const auto math_warpgroup_idx = warpgroup_idx;
const auto math_thread_idx = warp_idx * 32 + lane_idx;
// Helper lambda for loading tensor memory
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
constexpr int N = decltype(num_elems_c)::value;
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
// Math warpgroups process TMEM stages alternately
// Advance pipeline to align with the assigned stage
advance_tmem_pipeline(math_warpgroup_idx);
// Local register buffers
float accum[kNumHeads];
float weights[kNextNAtom][kNumHeads];
// Persistently schedule over blocks
uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms;
uint32_t q_atom_idx, kv_idx, _;
bool is_paired_atom = false;
while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) {
if (q_atom_idx != last_q_atom_idx) {
CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase);
// Release last Q empty
if (last_q_atom_idx != batch_size * kNumNextNAtoms)
empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive();
// Wait TMA Q arrivals
full_q_barriers[q_stage_idx]->wait(q_phase);
// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j));
weights[i][j + 0] = raw.x;
weights[i][j + 1] = raw.y;
weights[i][j + 2] = raw.z;
weights[i][j + 3] = raw.w;
}
}
// Check if this atom pairs two tokens from the same sequence
if constexpr (kIsVarlen) {
is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2);
}
}
last_q_atom_idx = q_atom_idx;
// Calculate KV offset in advance
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx;
// Advance pipeline by `kNumMathWarpGroups` steps
// Wait UMMA arrival
CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase);
full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase);
ptx::tcgen05_after_thread_sync();
// Reduce over the head dim and store
const auto reduce_and_store = [&](auto num_iters_c) {
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
// Only loop over valid iterations
#pragma unroll
for (uint32_t i = 0; i < kNumIters; ++ i) {
// Load accumulator from TMEM
uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads;
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr, accum);
tmem_load(cute::Int<kNumHeads / 2>{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2);
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(sum.x + sum.y);
// Store into the global memory
logits[kv_offset + i * static_cast<uint64_t>(logits_stride)] = result;
__syncwarp();
}
// Release TMEM empty
ptx::tcgen05_before_thread_sync();
empty_tmem_barriers[tmem_stage_idx]->arrive();
};
if constexpr (kIsVarlen) {
if (is_paired_atom)
reduce_and_store(cute::Int<kNextNAtom>{});
else
reduce_and_store(cute::Int<1>{});
} else if constexpr (kPadOddN) {
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
reduce_and_store(cute::Int<1>{});
else
reduce_and_store(cute::Int<kNextNAtom>{});
} else {
reduce_and_store(cute::Int<kNextNAtom>{});
}
}
// Free tensor memory
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
if (warp_idx == 0)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,514 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/epilogue/sm100_store_cd.cuh>
#include <deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
// Persistent, warp-specialized block-scaled GEMM kernel (the body is only compiled for `__CUDA_ARCH__ >= 1000`).
// Scale factors (SF) are declared as `cutlass::float_ue8m0_t` in the UMMA instruction descriptor below.
//
// Warp roles, dispatched on `warp_idx`:
//   - warp 0 (one elected lane): TMA loads of A, B, SFA and SFB into the shared memory pipeline
//   - warp 1 (leader CTA only): UTCCP copies of SF into tensor memory and UMMA issue
//   - warp 2: tensor memory allocation, then in-place SF transposition in shared memory
//   - warps in [kNumNonEpilogueThreads / 32, (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32):
//     epilogue, running `epilogue::sm100_store_cd` or `epilogue::sm100_store_cd_swap_ab`
//   Threads matching no role fall through to the final block/cluster synchronization.
//
// Parameters:
//   grouped_layout  forwarded untouched to the block scheduler
//   shape_m/n/k     runtime problem shape; each is replaced by SHAPE_M/N/K when that constant is non-zero
//   tensor_map_*    TMA descriptors for A, B, the scale factors of A and B, and the C/D output
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kGranKA, uint32_t kGranKB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
bool kSwapAB,
GemmType kGemmType, bool kWithAccumulation,
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
typename epilogue_type_t>
CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// 1-SM or 2-SM tensor memory allocator, depending on whether two CTAs cooperate
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
// MMA Configs
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast;
// With swap-AB the operands are exchanged at issue time, so BLOCK_M plays the role of the UMMA N dimension
constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N;
constexpr uint32_t UMMA_K = 32;
// Per-CTA load tile: the multicast dimension is split across the CTAs of the cluster
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or
(not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size");
// SF configs
// SF buffers are padded to multiples of 128 elements, the unit handled by one transpose + UTCCP copy below
constexpr uint32_t kNumUTCCPAlignedElems = 128;
constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
// Number of K blocks served by one SF load: SF is (re)loaded when `k_block_idx % kNumSF{A,B}StagesPerLoad == 0`,
// i.e. every K block for granularity 32 and every 4th K block for granularity 128
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB");
// Epilogue configs
// Always enable pipeline for better performance
constexpr uint32_t kNumEpilogueStages = 2;
constexpr uint32_t kNumTMAStoreStages = 2;
// NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N
// per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases
constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t);
// Number of threads that actually take the epilogue branch (may be fewer than `kNumEpilogueThreads`)
constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M;
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t);
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
"Shared memory of A/B must be aligned to 1024 bytes");
// NOTES: Make sure we have enough shared memory for UMMA padding
// (the UMMA may read a full LAYOUT_AD_M-row A tile even when fewer rows were loaded; the bytes past
// the A stage must still lie inside the A/B shared memory region)
constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
// Tensor memory size and offsets
// Column layout: [accumulators: UMMA_N x kNumEpilogueStages][SFA][SFB]
constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages;
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Synchronize the cluster before 2-CTA TMEM allocation
kNumMulticast > 1 ? cute::cluster_sync() : void();
// Utils
const bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_sfb);
cute::prefetch_tma_descriptor(&tensor_map_cd);
}
// Override the runtime shapes with the compile-time constants whenever those are provided (non-zero)
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Extent of the SF tensors along K: one entry per `kGranK* * 4` elements of K
// (presumably 4 one-byte scales packed per 32-bit word -- TODO confirm against the host-side SF layout)
const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4);
const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4);
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Dynamic shared memory layout, in order:
//   [CD x kNumTMAStoreStages][A x kNumStages][B x kNumStages][SFA x kNumStages][SFB x kNumStages][barriers][TMEM pointer]
// D/A/B shared memory
auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// SFA/SFB shared memory
// Indexing a visitor at `kNumStages` yields the address one past its last stage, i.e. the start of the next region
auto sf_start_ptr = reinterpret_cast<uint8_t*>(smem_b[kNumStages]);
auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
});
// Barriers and tensor memory pointer
// Barrier array layout: [full x kNumStages][empty x kNumStages][with_sf_full x kNumStages]
//                       [tmem_full x kNumEpilogueStages][tmem_empty x kNumEpilogueStages]
//   full:          TMA loads landed (signalled by the TMA warp's expect-tx arrival)
//   with_sf_full:  SF transposition done, stage is ready for the MMA warp
//   empty:         UMMAs reading the stage have been committed, stage may be overwritten
//   tmem_full:     accumulator stage is complete, ready for the epilogue
//   tmem_empty:    accumulator stage may be overwritten by the MMA warp
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_sfb[kNumStages]);;
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
full_barriers[i]->init(1);
empty_barriers[i]->init(1);
// Arrive only at the leader CTA
// (32 arrivals per CTA: every lane of the transposer warp arrives, see the `warp_idx == 2` role)
with_sf_full_barriers[i]->init(kNumMulticast * 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
// Arrive at all CTAs
tmem_full_barriers[i]->init(1);
// Arrive only at the leader CTA
// NOTE(review): the matching arrivals are expected to happen inside the epilogue store functions,
// one per store thread -- not visible in this file, confirm there
tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
// Barrier initialization and the TMEM allocation must be visible to every role before the pipelines start
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
// Every role constructs the same scheduler and iterates it independently, so all roles visit the same block sequence
uint32_t m_block_idx, n_block_idx;
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs, kGranKA * 4>(
shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
// `stage_idx` and `phase` are per-thread state that persists across blocks: the shared memory pipeline is
// not reset between output tiles
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Use dynamic load block M, when swap-AB is enabled
const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M;
// For k-grouped layout, the number of block K is variable
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
// (`phase ^ 1` so that the very first pass over each stage does not block on a never-used buffer)
empty_barriers[stage_idx]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-majored
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
// (each CTA of the cluster loads its own half of the multicast dimension)
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
// The coordinate order passed to `tma::copy` follows the operand's major-ness: (k, mn) for K-major, (mn, k) for MN-major
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma::copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma::copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma::copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma::copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
// Expected transaction bytes for this stage: the full stage size for `float_e4m3_t` operands, half of it otherwise
// NOTE(review): the halving presumably accounts for packed 4-bit operands whose `sizeof` is still 1 -- confirm
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
uint32_t sfa_m_idx = m_block_idx * BLOCK_M;
uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>(
shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad));
tma::copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx);
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
uint32_t sfb_n_idx = n_block_idx * BLOCK_N;
uint32_t sfb_k_idx = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(
shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx);
tma::copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx);
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
}
// Arrive at full barriers
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
// (operand types and major-ness are exchanged when swap-AB is enabled)
auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled<b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorB, kMajorA>()
: cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
auto sf_desc = mma::sm100::make_sf_desc(nullptr);
// Lane `i` of this warp pre-computes the low descriptor word of pipeline stage `i` (hence at most 32 stages);
// the word for the current stage is fetched with `ptx::exchange` inside the K loop.
// NOTE(review): the `/ 16` presumably reflects the 16-byte address units of the shared memory descriptor,
// and `ptx::exchange` is presumably a warp shuffle reading from lane `stage_idx` -- confirm in the helper headers
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = mma::sm100::make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
auto b_desc = mma::sm100::make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Wait tensor memory empty barrier arrival
// The accumulator stage alternates per scheduled block; its phase flips every `kNumEpilogueStages` blocks
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
ptx::tcgen05_after_thread_sync();
// Empty barrier arrival
// Releases the current shared memory stage once the issued UMMAs complete and, on the last K block,
// also signals the epilogue that the accumulator stage is complete
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
__syncwarp();
};
// Dynamic update of UMMA N based on effective M, when swap-AB is enabled
if constexpr (kSwapAB) {
uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx);
mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n);
}
// Launch MMAs
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
#pragma unroll 4
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[stage_idx]->wait(phase);
ptx::tcgen05_after_thread_sync();
// Executed by all lanes, before the single-lane region below
const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx);
const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx);
if (cute::elect_one_sync()) {
// Do SF copy at certain stages
// TODO: process shared memory descriptor by addition
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
// Position of this K block inside its SF load group; the SF is copied to tensor memory only at position 0
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
if (sfa_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
}
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
if (sfb_stage_in_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
// Issue UMMA
using mma_t = cute::conditional_t<
kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>;
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
// SF selector: with granularity 32 each UMMA_K slice has its own scale (`k`);
// with granularity 128 the whole K block shares one scale, selected by the position in the load group
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
const auto runtime_instr_desc = kSwapAB ?
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id):
mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
a_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K);
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
// The 4th argument is false only for the very first UMMA of a block
// (presumably the accumulate flag, so the first UMMA overwrites the accumulator -- confirm in `fma`)
if constexpr (kSwapAB) {
mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc,
kTmemStartColOfSFB, kTmemStartColOfSFA);
} else {
mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N,
k_block_idx > 0 or k > 0, runtime_instr_desc,
kTmemStartColOfSFA, kTmemStartColOfSFB);
}
}
}
__syncwarp();
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
}
}
// To safely deconstruct barriers, we need another round of waits
// NOTE(review): `iter_idx >= 0` is only a meaningful guard if `scheduler.current_iter` is a signed type
// (so that "no block processed" yields a negative value) -- confirm in the scheduler definition
const auto iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx == 2) {
// UTCCP transposer
// In-place transposition of one 128-element SF chunk, viewed as 4 rows of 32 words on load and as
// 32 rows of 4 words on store. Each lane moves 4 words; the `i ^ (lane_idx >> 3)` term rotates
// the access order per group of 8 lanes (presumably to avoid shared memory bank conflicts on the store).
// The `__syncwarp()` separates all loads from all stores, which is required because the chunk is rewritten in place.
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
// Transpose for UTCCP at certain stages
// (the same stages at which the TMA warp loaded a fresh SF)
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Arrive
// All 32 lanes arrive, even on K blocks without a transposition, matching the `kNumMulticast * 32` init count
// (the `0u` argument presumably targets the leader CTA's barrier -- confirm in the barrier API)
with_sf_full_barriers[stage_idx]->arrive(0u);
}
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
// Epilogue warp groups
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
// (hence the allocated base address is asserted to be 0 and column offsets are used directly as addresses)
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Share store pipeline between blocks
uint32_t tma_stage_idx = 0;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Same stage/phase derivation as in the MMA warp, so both sides agree on the accumulator stage
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
ptx::tcgen05_after_thread_sync();
const auto tmem_base_addr = accum_stage_idx * UMMA_N;
const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
const auto base_n_idx = n_block_idx * BLOCK_N;
// The store helpers receive the `tmem_empty` barrier of this stage
// (presumably to release the accumulator once it has been read -- see the helper headers)
if constexpr (kSwapAB) {
const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx);
epilogue::sm100_store_cd_swap_ab<
BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue_type_t>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
effective_m,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
} else {
epilogue::sm100_store_cd<
BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N,
kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads,
kGemmType, kWithAccumulation,
cd_dtype_t, epilogue_type_t>
(smem_cd, tma_stage_idx, tmem_base_addr,
base_m_idx, base_n_idx, scheduler.current_group_idx,
epilogue_warp_idx, lane_idx,
tmem_empty_barriers[accum_stage_idx],
tensor_map_cd);
}
}
}
// All roles (and threads without a role) meet here before tensor memory is released
// TODO: Remove redundant synchronization
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Deallocate tensor memory
// (base address 0, consistent with the assertion in the epilogue role)
if (warp_idx == 0)
Allocator().free(0, kNumTmemCols);
#else
// Fallback for unsupported architectures: a single thread raises a device assertion
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,567 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/epilogue_utils.cuh>
#include <deep_gemm/common/scheduler.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/sm100_utils.cuh>
namespace deep_gemm {
using namespace deep_gemm::sm100;
// Persistent, warp-specialized, block-scaled GEMM kernel for SM100 (tensor memory + UMMA).
//
// Every CTA loops over the output tiles handed out by `Scheduler` and splits its warps into roles:
//   - warp 0 (one elected thread): issues the TMA loads of A, B and the scale factors (SFA/SFB)
//     into the multi-stage shared memory ring, then arrives on `full_barriers`.
//   - warp 1 (leader CTA only): copies scale factors from shared memory into tensor memory (UTCCP),
//     issues the block-scaled UMMA instructions, and commits to `empty_barriers` / `tmem_full_barriers`.
//   - warp 2: transposes the freshly loaded scale factors in shared memory into the layout required by
//     UTCCP, then arrives on `with_sf_full_barriers`.
//   - epilogue warps (`[kNumNonEpilogueThreads / 32, (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32)`):
//     read the accumulators from tensor memory, write them swizzled into shared memory and issue
//     TMA stores (or TMA reduce-adds when `kWithAccumulation`) to the output.
//
// Template parameters (as far as this file shows):
//   - kMajorA / kMajorB: memory major-ness of A and B; selects the TMA box orientation and UMMA descriptors.
//   - kGranKA / kGranKB: scale-factor granularity along K for A / B; must be 32 or 128.
//   - SHAPE_M / SHAPE_N / SHAPE_K: when non-zero, override the runtime `shape_*` arguments.
//   - BLOCK_M / BLOCK_N / BLOCK_K: tile sizes; `BLOCK_K` must be 128.
//   - kNumGroups, kNumSMs, kGemmType: forwarded to `Scheduler` (semantics defined there).
//   - kSwizzleAMode / kSwizzleBMode / kSwizzleCDMode: swizzle sizes in bytes for A, B and the output staging buffer.
//   - kNumStages: depth of the A/B/SF shared memory pipeline (at most 32).
//   - kNumNonEpilogueThreads / kNumEpilogueThreads: thread counts of the two warp groups (epilogue must be 128).
//   - kNumMulticast: 1 or 2; a value of 2 selects the 2-SM allocator, 2-CTA UTCCP/UMMA and cluster-wide syncs.
//   - kIsMulticastOnA: may only be set when `kNumMulticast == 1` (see the static assertion below).
//   - kWithAccumulation: output is reduced into the destination (FP32 only) instead of overwriting it.
//   - a_dtype_t / b_dtype_t / cd_dtype_t: element types of A, B and the output (FP32 or BF16 output).
//   - epilogue_type_t: provides `apply_index_n` to remap the output N index.
//
// Runtime parameters:
//   - grouped_layout: forwarded to `Scheduler`; its contents are not interpreted in this kernel.
//   - shape_m / shape_n / shape_k: problem shape, unless overridden by `SHAPE_*`.
//   - tensor_map_a / _b / _sfa / _sfb / _cd: TMA descriptors of A, B, their scale factors and the output.
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
uint32_t kGranKA, uint32_t kGranKB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumGroups,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumNonEpilogueThreads, uint32_t kNumEpilogueThreads,
uint32_t kNumMulticast, bool kIsMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, bool kWithAccumulation,
typename a_dtype_t, typename b_dtype_t, typename cd_dtype_t,
typename epilogue_type_t>
__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1)
sm100_fp8_gemm_1d1d_impl(int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
// The real body is only compiled for SM100+ device passes (or for IDE indexing); see the `#else` branch at the end
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Tensor memory allocator: 1-SM variant without multicast, 2-SM variant otherwise
using Allocator = cute::conditional_t<kNumMulticast == 1, cute::TMEM::Allocator1Sm, cute::TMEM::Allocator2Sm>;
// GEMM with accumulation must have FP32 output
if constexpr (kWithAccumulation)
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
// Configs
// NOTES: a tile taller than `LAYOUT_AD_M` rows is processed as `kNumMWaves` "waves" of `WAVE_BLOCK_M` rows,
// and the assertion below restricts `kNumMWaves` to 1 or 2
constexpr uint32_t LAYOUT_AD_M = 128;
constexpr uint32_t WAVE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M;
constexpr uint32_t kNumTMAStoreStages = 2;
constexpr uint32_t kNumUTCCPAlignedElems = 128;
DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K");
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M");
// Scale factors are (re)loaded every `kNumSF{A,B}StagesPerLoad` K blocks:
// every block for granularity 32, every 4th block for granularity 128
constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4;
constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4;
DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A");
DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B");
// Override the runtime shapes with the compile-time constants whenever those are given (non-zero)
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// K extent of the scale factor tensors, in 32-bit words
// NOTE(review): the `* 4` presumably reflects 4 one-byte (UE8M0) scale factors packed per word - confirm
const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4);
const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4);
// Utils
bool is_leader_cta = cute::block_rank_in_cluster() == 0;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// 2-CTA MMA
// NOTES: with multicast, each CTA only loads its own slice of the multicast operand
constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1);
constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t STORE_BLOCK_M = cute::min<uint32_t>(BLOCK_M, LAYOUT_AD_M);
// One TMA store covers exactly one swizzle atom along N
constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t);
// One epilogue thread per stored output row
constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M;
DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast");
DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D");
DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast");
DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M");
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t);
// Scale factor buffers are padded to a multiple of `kNumUTCCPAlignedElems` 32-bit words per stage
constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems);
constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems);
constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t);
constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
"Shared memory of A/B must be aligned to 1024 bytes");
DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages");
// NOTES: Make sure we have enough shared memory for UMMA padding
static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t);
DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA");
// Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size
// NOTES: two accumulator stages are used only if they fit into 512 columns together with the SF columns
// TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2`
constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32;
constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32;
constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2;
// Real tensor memory size and offsets
// NOTES: column layout is `[accumulators | SFA | SFB]`
constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N;
constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols<kNumAccumTmemCols + kNumSFATmemCols + kNumSFBTmemCols>();
constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols;
constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols;
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_sfb);
cute::prefetch_tma_descriptor(&tensor_map_cd);
}
// Shared memory layout, in order:
// `[CD stages][A stages][B stages][SFA stages][SFB stages][barriers][tensor memory pointer]`
// D/A/B shared memory
auto smem_cd = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<cd_dtype_t*>(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE);
});
auto smem_a = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<a_dtype_t*>(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<b_dtype_t*>(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
// SFA/SFB shared memory
auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = PatternVisitor([=](const uint32_t& i) {
return reinterpret_cast<uint32_t*>(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE);
});
// Fill barriers
// NOTES: barrier array layout is
// `[full x kNumStages][empty x kNumStages][with_sf_full x kNumStages][tmem_full x kNumEpilogueStages][tmem_empty x kNumEpilogueStages]`
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer +
SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) +
kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE));
auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); });
// Fill the tensor memory pointer
// NOTES: the allocator writes the allocated tensor memory base address into this shared memory word
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// With 2-CTA clusters, synchronize the whole cluster before touching barriers and tensor memory
if (kNumMulticast > 1)
cute::cluster_sync();
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Arrive at all CTAs
// NOTES: `full` gets one arrival from the TMA-issuing thread, `empty` gets one arrival from the UMMA commit
full_barriers[i]->init(1);
empty_barriers[i]->init(1);
// Arrive only at the leader CTA
// NOTES: all 32 lanes of the transposer warp of every CTA in the cluster arrive here
with_sf_full_barriers[i]->init(kNumMulticast * 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) {
// Arrive at all CTAs
tmem_full_barriers[i]->init(1);
// Arrive only at the leader CTA
// NOTES: every epilogue store thread of every CTA in the cluster arrives here
tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
// Barriers and the tensor memory allocation must be visible to all warps (and the peer CTA) before use
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumMulticast, kIsMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
// NOTES: every warp role keeps its own private copy of `stage_idx` / `phase`; the roles stay consistent
// because each of them advances once per K block
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
// Dispatch warps into different roles
if (warp_idx == 0 and cute::elect_one_sync()) {
// TMA load warp
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
// NOTES: waiting on the inverted phase lets the first pass over the ring proceed without any consumer arrival
empty_barriers[stage_idx]->wait(phase ^ 1);
// Compute offsets
// NOTES: the group is always concatenated with the outer dimension
uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> (
shape_m, BLOCK_M, m_block_idx);
uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> (
shape_n, BLOCK_N, n_block_idx, m_block_idx);
// NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major
// And for all m-grouped GEMMs, A must be K-major
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or
kMajorA == cute::UMMA::Major::K, "Invalid major");
uint32_t k_idx = k_block_idx * BLOCK_K;
uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> (
shape_k, BLOCK_K, k_block_idx, m_block_idx);
// Add 2 CTA offsets
// NOTES: each CTA of the cluster loads its own slice of the split operand
if constexpr (kNumMulticast > 1) {
m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0;
n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N);
}
// Issue TMAs
// NOTES: the coordinate order passed to `tma_copy` follows the operand's major-ness
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
if constexpr (kMajorA == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx);
if constexpr (kMajorA == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode, a_dtype_t, kIsBatchedMM>(
&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::K)
tma_copy<BLOCK_K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx);
if constexpr (kMajorB == cute::UMMA::Major::MN)
tma_copy<LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode, b_dtype_t, kIsBatchedMM>(
&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx);
// Expected transaction bytes of this stage
// NOTE(review): the byte count is halved for any A/B type other than `float_e4m3_t`, presumably because
// such types are sub-byte (e.g. FP4) and packed two elements per byte - confirm against the callers
auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v<a_dtype_t, cutlass::float_e4m3_t> ? 1 : 2) +
SMEM_B_SIZE_PER_STAGE / (std::is_same_v<b_dtype_t, cutlass::float_e4m3_t> ? 1 : 2);
// Issue SFA and SFB TMAs at certain stages
// No swizzling, so one TMA for one SF is enough
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
tma_copy<BLOCK_M, 1, 0>(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M,
scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)));
num_arrival_bytes += BLOCK_M * sizeof(uint32_t);
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
tma_copy<BLOCK_N, 1, 0>(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N,
scheduler.template get_global_idx<true, IndexType::SF_K>(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx));
num_arrival_bytes += BLOCK_N * sizeof(uint32_t);
}
// Arrive at full barriers
full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes);
}
}
} else if (warp_idx == 1 and is_leader_cta) {
// MMA issue warp
// NOTES: only the leader CTA will do this
// Make instruction descriptor
// TODO: refactor `UMMA_M` calculation
constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast);
constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1);
constexpr uint32_t UMMA_K = 32;
auto instr_desc = cute::UMMA::make_instr_desc_block_scaled<a_dtype_t, b_dtype_t, float, cutlass::float_ue8m0_t,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
// The shared memory address of the SF descriptor is patched per copy via `replace_smem_desc_addr`
auto sf_desc = make_sf_desc(nullptr);
// Lane `i` of this warp keeps the low descriptor word of pipeline stage `i` for A and B;
// the word of the current stage is fetched below with `__shfl_sync`, hence at most 32 stages
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto a_desc = make_umma_desc<kMajorA, LOAD_BLOCK_M, BLOCK_K, kSwizzleAMode>(smem_a[0], 0, 0);
auto b_desc = make_umma_desc<kMajorB, LOAD_BLOCK_N, BLOCK_K, kSwizzleBMode>(smem_b[0], 0, 0);
uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u;
uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Wait tensor memory empty barrier arrival
// NOTES: the accumulator stage and phase are derived from the scheduler iteration, the same way as in the epilogue
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1);
tcgen05_after_thread_sync();
// Empty barrier arrival
// NOTES: releases the current shared memory stage to the TMA warp once the issued UMMAs complete, and
// optionally signals the epilogue that the accumulator of this tile is complete
auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) {
auto umma_arrive = [](const uint64_t* barrier) {
if constexpr (kNumMulticast == 1) {
cutlass::arch::umma_arrive(barrier);
} else {
// Signal the barrier in every CTA of the cluster
constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1;
cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask);
}
};
umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
// NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting
if (do_tmem_full_arrive)
umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barriers[accum_stage_idx]));
};
// Launch MMAs
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA and SF-transpose arrival
with_sf_full_barriers[stage_idx]->wait(phase);
tcgen05_after_thread_sync();
// Do SF copy at certain stages
// NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves
// NOTES: each chunk of `kNumUTCCPAlignedElems` (128) words takes 4 tensor memory columns
// TODO: process shared memory descriptor by addition
using cute_utccp_t = cute::conditional_t<kNumMulticast == 1,
cute::SM100_UTCCP_4x32dp128bit_1cta, cute::SM100_UTCCP_4x32dp128bit_2cta>;
const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad;
if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4);
}
}
const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad;
if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) {
auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems;
replace_smem_desc_addr(sf_desc, smem_ptr);
cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4);
}
}
// Reconverge the warp before the full-mask shuffles below
__syncwarp();
// Issue UMMA in the leader CTA
using mma_t = cute::conditional_t<kNumMulticast == 1, SM100_MMA_MXF8F6F4_SS, SM100_MMA_MXF8F6F4_2x1SM_SS>;
// Fetch the descriptor base of the current stage from the lane that holds it
const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast<int>(stage_idx));
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
// With granularity 32 every UMMA K step has its own scale factor; with granularity 128 the whole
// K block shares one, selected by the position of the block within its 4-block SF load group
const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx);
const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx);
const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id);
b_desc.lo = advance_umma_desc_lo<kMajorB, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K);
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset");
a_desc.lo = advance_umma_desc_lo<kMajorA, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K);
// NOTE(review): the arguments after the descriptors look like: accumulator tensor memory column,
// accumulate flag (false only for the very first UMMA of a tile), runtime instruction descriptor,
// and the tensor memory columns of SFA (per wave) and SFB - confirm against `mma_t::fma`
mma_t::fma(a_desc, b_desc,
accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N,
k_block_idx > 0 or k > 0,
runtime_instr_desc,
kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32),
kTmemStartColOfSFB);
}
}
}
// Commit to the mbarrier object
// No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit`
// NOTES: the tensor memory full barrier is only signaled after the last K block of the tile
empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1);
}
}
// To safely deconstruct barriers, we need another round of waits
// NOTES: only done for 2-CTA clusters, and only if at least one tile has been processed
const auto& iter_idx = scheduler.current_iter - 1;
if (kNumMulticast > 1 and iter_idx >= 0) {
const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1;
tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx);
}
} else if (warp_idx == 2) {
// UTCCP transposer
// NOTES: in-place transpose of a 128-word chunk, viewed as a `4 x 32` matrix of 32-bit words, into `32 x 4`:
// lane `l` reads word `[r][l]` and writes it to `[l][r]`, with `r = i ^ (l >> 3)`
// NOTE(review): the XOR on the row index presumably spreads accesses over banks - confirm
auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) {
DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements");
uint32_t values[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx);
// All lanes must finish reading before any lane overwrites the chunk in place
__syncwarp();
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i)
st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]);
};
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K);
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait TMA arrival
full_barriers[stage_idx]->wait(phase);
// Transpose for UTCCP at certain stages
// NOTES: only stages that actually received new scale factors are transposed
if (k_block_idx % kNumSFAStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
if (k_block_idx % kNumSFBStagesPerLoad == 0) {
#pragma unroll
for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i)
utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems);
// TODO: figure out whether the proxy fence is valid for 2-CTA cases
cutlass::arch::fence_view_async_shared();
}
// Arrive
// NOTES: arrives for every K block (even without a transpose), as the MMA warp waits on this barrier per stage
with_sf_full_barriers[stage_idx]->arrive(0u);
}
}
} else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) {
// Epilogue warp groups
const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32);
// NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits,
// i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`.
// NOTES: we also forbid two CTAs to share the same SM and its tensor memory
// NOTES: the allocation is required to start at address 0, so plain column offsets are used as addresses below
DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0);
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Share store pipeline between blocks
// NOTES: `tma_stage_idx` keeps cycling over the output staging buffers across tiles
uint32_t tma_stage_idx = 0;
auto advance_store_pipeline = [&]() {
tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages;
};
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages;
auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1;
// Wait UMMA arrival
tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx);
tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough");
DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes");
// Iterate over M waves
#pragma unroll
for (uint32_t w = 0; w < kNumMWaves; ++ w) {
// Issue every swizzled atom and pipeline STSM and TMA store
constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N;
#pragma unroll
for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) {
// Wait shared memory to be released
// NOTES: the first epilogue warp (which issues the stores) waits until at most `kNumTMAStoreStages - 1`
// stores are in flight; the named barrier then publishes that to all store threads
if (epilogue_warp_idx == 0)
cute::tma_store_wait<kNumTMAStoreStages - 1>();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
// The pipeline stage
// NOTES: global output coordinates of this store
const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M;
const auto n_idx = epilogue_type_t::apply_index_n<STORE_BLOCK_N>(n_block_idx * BLOCK_N + s * STORE_BLOCK_N);
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)`
// - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)`
// NOTES: "8" is the number of bank groups, "16" is the swizzling pattern
// NOTES: the shortcut applies when one row holds exactly 8 bank groups, so no division is needed
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (i) : (bank_group_index % 8);
col ^= row % (kSwizzleCDMode / 16);
// Source and destination memory address
uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset
w * BLOCK_N + // Wave offset
s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd[tma_stage_idx]) + // Base pointer
epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
if constexpr (cute::is_same_v<cd_dtype_t, float>) {
// For FP32 output, read and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
} else {
// For BF16 output, read, cast and store
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr,
values[0], values[1], values[2], values[3],
values[4], values[5], values[6], values[7]);
cutlass::arch::fence_view_async_tmem_load();
st_shared(smem_ptr,
cast_into_bf16_and_pack(values[0], values[1]),
cast_into_bf16_and_pack(values[2], values[3]),
cast_into_bf16_and_pack(values[4], values[5]),
cast_into_bf16_and_pack(values[6], values[7]));
}
}
// Notify tensor memory empty (only at the leader CTA) arrival ASAP
// NOTES: only the last stage needs to do this
// NOTES: at this point all accumulator values of the tile are already in registers or shared memory,
// so the MMA warp may start overwriting this accumulator stage
if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) {
tcgen05_before_thread_sync();
tmem_empty_barriers[accum_stage_idx]->arrive(0u);
}
// Synchronize all threads and issue TMA
// NOTES: the fence is issued before the barrier so that the shared memory writes of all store threads
// are ordered before the TMA store issued by the elected thread
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0);
if (epilogue_warp_idx == 0 and cute::elect_one_sync()) {
// Use a reduce-add instead of a plain store when accumulating into the existing output
if constexpr (kGemmType == GemmType::Batched) {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_3D, cute::SM90_TMA_STORE_3D>;
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx],
n_idx, m_idx, scheduler.current_group_idx);
} else {
using cute_tma_t = cute::conditional_t<kWithAccumulation,
cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx);
}
cute::tma_store_arrive();
}
}
}
}
}
// Deallocate tensor memory
// NOTES: all warps (and the peer CTA for 2-CTA clusters) must be done before warp 0 frees the columns
kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads();
if (warp_idx == 0)
Allocator().free(0, kNumTmemCols);
#else
// Fallback when the guard above does not hold: a single thread raises a device-side assertion
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,403 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
template <uint32_t kNumHeads, uint32_t kHeadDim,
bool kIsCompressedLogits,
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t kNumSMs,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
typename logits_dtype_t,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const uint32_t max_seqlen_k, const uint32_t stride_logits,
uint32_t* cu_seq_len_k_start,
uint32_t* cu_seq_len_k_end,
logits_dtype_t* logits,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
// TODO: consider TMA multicast
// Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64`
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
// Q should be load only at once for a block
const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
// Types
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Utils
const auto sm_idx = blockIdx.x;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
// Prefetch TMA descriptors
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
if (warp_idx == kSpecWarpStart) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
// Shared memory configs
// NOTES: weight may be unaligned
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u);
// Align to 512 bytes for swizzle-64B
extern __shared__ __align__(512) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling");
// TMA configs
constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
// Data on shared memory
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
});
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages +
SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
// TMA barriers
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); });
auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); });
// Tensor memory allocation
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2);
// Initialize barriers
DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads");
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads + 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(kNumMathThreads);
}
cutlass::arch::fence_barrier_init();
}
if (warp_idx == kSpecWarpStart + 1) {
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
full_umma_barriers[i]->init(1);
empty_umma_barriers[i]->init(128);
}
cutlass::arch::fence_barrier_init();
}
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumSpecializedRegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Block scheduler
uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + kNumSMs, q_iter_idx + 1};
};
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = cu_seq_len_k_start[q_idx];
seq_k_end[i] = cu_seq_len_k_end[q_idx];
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
// TMA alignment requirements for SF KV
start = start / 4 * 4;
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
};
// KV pipeline
uint32_t num_total_kv_blocks = 0;
const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
return {
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
};
};
// UMMA settings
// Construct instruction with layout D
constexpr uint32_t UMMA_M = 128;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (warp_idx == kSpecWarpStart) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
// Prefetch
const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
};
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
issue_tma_q(0, block_q_idx);
// Only the first lane persistently schedules over blocks
if (cute::elect_one_sync()) {
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);
// Wait Q consumer release
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
// Issue TMA Q
if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
issue_tma_q(q_stage_idx, next_block_q_idx);
// Issue TMA KV
#pragma unroll
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
// Wait consumer release
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
num_total_kv_blocks += num_kv_blocks;
// Jump to the next block
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
}
} else if (warp_idx == kSpecWarpStart + 1) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
// Require full allocation
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Make UMMA desc
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
// Wait TMA Q arrival
full_q_barriers[q_stage_idx]->wait(q_phase);
// Compute over KV blocks
#pragma unroll
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Issue UMMA
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size");
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1);
ptx::tcgen05_after_thread_sync();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_q[q_stage_idx], 0, k * UMMA_K);
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
}
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(full_umma_barriers[i]));
}
}
num_total_kv_blocks += num_kv_blocks;
// UMMA warp must also arrive on empty_q to prevent running ahead
// of math warps in the Q pipeline
empty_q_barriers[q_stage_idx]->arrive();
// Jump to the next block
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
} else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
} else if (warp_idx < kSpecWarpStart) {
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// Offsets
const auto tmem_start = warpgroup_idx * UMMA_N;
const auto math_thread_idx = warp_idx * 32 + lane_idx;
// Helper lambda for loading tensor memory
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
constexpr int N = decltype(num_elems_c)::value;
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
// Local register buffers
float weights[BLOCK_Q][kNumHeads];
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
// Wait TMA Q arrival
full_q_barriers[q_stage_idx]->wait(q_phase);
// Read weights
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
// Compute over KV blocks
#pragma unroll
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
// Wait UMMA arrival
full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1);
ptx::tcgen05_after_thread_sync();
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
// Reduce over the head dim and store
const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx;
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
// Load accumulator from TMEM
float accum[kNumHeads];
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Release TMEM empty
if (i == BLOCK_Q - 1) {
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[warpgroup_idx]->arrive();
}
// Accumulate weighted ReLU in parallel
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Store into the global memory
const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
if constexpr (kIsCompressedLogits) {
if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i])
logits[q_offset + kv_offset - seq_k_start[i]] = result;
} else {
logits[q_offset + kv_offset] = result;
}
__syncwarp();
}
}
num_total_kv_blocks += num_kv_blocks;
// Release Q empty
empty_q_barriers[q_stage_idx]->arrive();
// Jump to the next block
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
// Free tensor memory
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
if (warp_idx == 0)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,439 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
namespace deep_gemm {
// FP8 paged MQA logits kernel using TMA loads, UMMA instructions and tensor memory (TMEM).
//
// Warp roles inside the thread block:
// - warps `[0, kSpecWarpStart)`: math warps, which read the UMMA accumulators from TMEM, apply ReLU,
//   reduce over the heads with the per-head weights, multiply by the per-KV scale and store into `logits`
// - warp `kSpecWarpStart`: TMA producer for Q, weights, KV and KV scales
// - warp `kSpecWarpStart + 1`: initializes the UMMA barriers, allocates tensor memory and issues UMMA
// - warps `kSpecWarpStart + 2` and `kSpecWarpStart + 3`: only release registers
//
// Parameters:
// - `batch_size`, `context_lens`, `schedule_meta`, `indices`: forwarded to `sched::PagedMQALogitsScheduler`,
//   which hands out `(q_atom_idx, kv_idx, num_kv)` tasks to every role
// - `logits_stride`: stride (in elements) between consecutive token rows of `logits`
// - `block_table_stride`: stride (in elements) between consecutive rows of `block_table`
// - `logits`: output buffer
// - `block_table`: table of KV block indices, read by the TMA warp
// - `tensor_map_*`: TMA descriptors for the Q, KV, KV scale and weight tensors
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads,
typename logits_dtype_t,
uint32_t kNumMathWarpGroups = kNumMathThreads / 128>
CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1)
void sm100_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Utils
const auto sm_idx = blockIdx.x;
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
// Math warps come first; the specialized warps start at `kSpecWarpStart`
constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4;
// Prefetch TMA descriptors
DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
if (warp_idx == kSpecWarpStart) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
// For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill.
static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3);
// A Q "atom" is the group of `kNextNAtom` tokens covered by one Q/weight stage and one UMMA N extent;
// `kNumNextNAtoms` atoms are needed to cover `kNextN` tokens
static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1;
static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom);
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q and KV data on shared memory
// Layout of `smem_buffer`: Q stages, KV stages, KV scale stages, weight stages, barriers, TMEM pointer
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i);
});
// Byte offset of the first KV scale stage (right after all Q and KV stages)
constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages;
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
// Barriers and TMEM pointer on shared memory
// The barriers start right after the last weight stage, in the order:
// full Q, empty Q, full KV, empty KV, full UMMA, empty UMMA
const auto barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; });
const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2;
auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; });
auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; });
// The TMEM base address written by the allocator lives right after the last barrier
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(umma_barrier_ptr + kNumMathWarpGroups * 2);
// One accumulator of `kNextNAtom * kNumHeads` TMEM columns per math warpgroup
constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups;
DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory");
// Initialize barriers
if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
// Arrivals come from all math threads plus the 32 lanes of the UMMA warp (see the role branches below)
empty_q_barriers[i]->init(kNumMathThreads + 32);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
// Every math thread releases a KV stage
empty_kv_barriers[i]->init(kNumMathThreads);
}
cutlass::arch::fence_barrier_init();
}
if (warp_idx == kSpecWarpStart + 1) {
if (cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) {
full_umma_barriers[i]->init(1);
// Released by the 128 threads of math warpgroup `i`
empty_umma_barriers[i]->init(128);
}
cutlass::arch::fence_barrier_init();
}
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
// Make barrier initialization and the TMEM allocation visible to all warps before any role starts
__syncthreads();
// Register reconfigurations
// Specialized warps give up registers (`warpgroup_reg_dealloc`), math warps request more (`warpgroup_reg_alloc`)
constexpr uint32_t kNumSpecializedRegisters = 56;
constexpr uint32_t kNumMathRegisters = 224;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Scheduler
// Every role constructs its own scheduler with the same arguments and walks the tasks it returns
constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV;
using Scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumBlocksPerSplit, kNumNextNAtoms>;
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`");
// Q and KV pipeline
// Both helpers map a running iteration counter to `(stage, phase)`; the phase bit flips once per
// full pass over the stages
const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
};
const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
};
// UMMA settings
// Construct instruction with layout D
constexpr uint32_t UMMA_M = 128;
constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t);
constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads;
// Each math warpgroup consumes one `UMMA_M`-row slice of a KV stage
DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
if (warp_idx == kSpecWarpStart) {
// TMA warp for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Load Q and the weights of one atom into stage `stage_idx`; both copies signal the same full-Q barrier
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) {
if (cute::elect_one_sync()) {
const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx);
tma::copy<kHeadDim, kNextNAtom * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads);
tma::copy<kNextNAtom * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
}
};
// Initialize outside valid range to indicate no previous task
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv;
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
bool fetched_next_task;
// Prefetch the first Q
if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)))
issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1;
// Each lane caches one block table entry in `kv_block_idx_storage`; `kv_block_idx_ptr` is the next
// lane to consume, and the value 32 means the cache is exhausted and must be reloaded
uint32_t kv_block_idx_ptr = 32;
uint32_t kv_block_idx_storage;
while (fetched_next_task) {
// Prefetch next Q when (q, atom) changes
const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size);
bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance);
// A new atom invalidates the cached block indices
if (q_atom_idx != next_q_atom_idx)
kv_block_idx_ptr = 32;
q_atom_idx = next_q_atom_idx;
kv_idx = next_kv_idx;
num_kv = next_num_kv;
// Read KV block index
// TODO(xuzhean): consider -1
if (kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast<uint64_t>(block_table_stride);
// Lane `l` loads the entry for KV block `kv_idx + l`; lanes past `num_kv` hold 0
kv_block_idx_storage = (kv_idx + lane_idx < num_kv)
? block_table[block_table_offset + kv_idx + lane_idx] : 0;
}
__syncwarp();
DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`");
// Wait Q consumer release and issue TMA Q
if (prefetch_q) {
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
issue_tma_q(q_stage_idx, q_atom_idx + next_advance);
}
// Broadcast the next `kNumBlocksPerSplit` cached block indices to every lane
uint32_t kv_block_idx[kNumBlocksPerSplit];
#pragma unroll
for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i)
kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i);
kv_block_idx_ptr += kNumBlocksPerSplit;
// Wait KV consumer release
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
if (cute::elect_one_sync()) {
// One KV copy and one KV scale copy per block; block `i` lands at element offset
// `BLOCK_KV * kHeadDim * i` (KV) and `BLOCK_KV * i` (scales) inside the stage
#pragma unroll
for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) {
tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i,
0, 0, 1, kv_block_idx[i]);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx] + BLOCK_KV * i,
0, kv_block_idx[i]);
}
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
// Fetch next task
fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv);
}
} else if (warp_idx == kSpecWarpStart + 1) {
// UMMA warp for issuing the matrix multiplications
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Require full allocation
DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0);
// Make UMMA desc
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::float_e4m3_t, cutlass::float_e4m3_t, float,
UMMA_M, UMMA_N, cute::UMMA::Major::K, cute::UMMA::Major::K>();
auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
// Initialize outside valid range to indicate no previous task
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
// Starts at the opposite parity of the math warps' `umma_phase`
// NOTE(review): presumably so that the first wait on a freshly initialized empty barrier does not block - confirm
uint32_t umma_phase = 1;
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
if (q_atom_idx != next_q_atom_idx) {
// Release previous Q empty (UMMA warp must participate to prevent
// running ahead of math warps in the Q pipeline)
if (q_iter_idx > 0)
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
full_q_barriers[q_stage_idx]->wait(q_phase);
}
q_atom_idx = next_q_atom_idx;
kv_idx = next_kv_idx;
// Wait KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim");
// One accumulation chain per math warpgroup: rows `[i * UMMA_M, (i + 1) * UMMA_M)` of the KV stage
// against the Q stage, accumulated into TMEM columns starting at `i * UMMA_N`
#pragma unroll
for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) {
// Wait until math warpgroup `i` has released its accumulator
empty_umma_barriers[i]->wait(umma_phase);
ptx::tcgen05_after_thread_sync();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) {
auto a_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K);
auto b_desc = mma::sm100::make_umma_desc<cute::UMMA::Major::K, 0, kHeadDim, kHeadDim>(
smem_q[q_stage_idx], 0, k * UMMA_K);
// `k` is passed as the accumulate flag, so the first K step (`k == 0`) passes 0
// NOTE(review): presumably 0 overwrites the accumulator instead of adding to it - confirm
cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc);
}
// Signal math warpgroup `i` through its full-UMMA barrier
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(full_umma_barriers[i]));
}
umma_phase ^= 1;
}
} else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) {
// Idle specialized warps: only release registers
cutlass::arch::warpgroup_reg_dealloc<kNumSpecializedRegisters>();
} else if (warp_idx < kSpecWarpStart) {
// Math warpgroups for reduce
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices);
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
// Offsets
const auto math_warpgroup_idx = warpgroup_idx;
// First TMEM column of this warpgroup's accumulator
const auto tmem_start = math_warpgroup_idx * UMMA_N;
// Index of this thread among all math threads; also its KV row within the stage
const auto math_thread_idx = warp_idx * 32 + lane_idx;
// Helper lambda for loading tensor memory
// Loads `N` (32 or 64) 32-bit values starting at `tmem_addr` into `accum`, then fences the load
auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) {
constexpr int N = decltype(num_elems_c)::value;
DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size");
using Loader = cute::conditional_t<N == 32,
cute::SM100_TMEM_LOAD_32dp32b32x,
cute::SM100_TMEM_LOAD_32dp32b64x>;
// Expand `accum[0..N)` into the individual output arguments expected by `Loader::copy`
[&]<size_t... Is>(cute::index_sequence<Is...>) {
Loader::copy(tmem_addr, reinterpret_cast<uint32_t*>(accum)[Is]...);
}(cute::make_index_sequence<N>{});
cutlass::arch::fence_view_async_tmem_load();
};
// Local register buffers
float weights[kNextNAtom][kNumHeads];
// Initialize outside valid range to indicate no previous task
uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx;
uint32_t next_q_atom_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
uint32_t umma_phase = 0;
// Only updated in the varlen case: whether the current atom advances by 2
bool is_paired_atom = false;
while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) {
// Q or atom changes
if (q_atom_idx != next_q_atom_idx) {
// Release last Q empty
if (q_iter_idx > 0)
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
// Wait TMA Q arrival
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
full_q_barriers[q_stage_idx]->wait(q_phase);
// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextNAtom; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j);
}
if constexpr (kIsVarlen) {
is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2);
}
}
// Get current task indices
q_atom_idx = next_q_atom_idx;
kv_idx = next_kv_idx;
// Calculate KV offset in advance
// Offset of the atom's first token row plus the first KV column of this task
auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast<uint64_t>(logits_stride) + kv_idx * BLOCK_KV;
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx);
// Wait UMMA arrival
full_umma_barriers[math_warpgroup_idx]->wait(umma_phase);
ptx::tcgen05_after_thread_sync();
umma_phase ^= 1;
// Release KV empty
// NOTE(review): presumably safe here because the scale is already in a register and the UMMA
// reading this KV stage has signalled completion - confirm
empty_kv_barriers[kv_stage_idx]->arrive();
// Reduce over the head dim and store
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
// `num_iters_c` is the number of tokens of the current atom to reduce and store
const auto reduce_and_store = [&](auto num_iters_c) {
constexpr uint32_t kNumIters = decltype(num_iters_c)::value;
float accum[kNumHeads];
#pragma unroll
for (uint32_t i = 0; i < kNumIters; ++ i) {
// Load accumulator from TMEM
tmem_load(cute::Int<kNumHeads>{}, tmem_start + i * kNumHeads, accum);
// Accumulate weighted ReLU in parallel
// Two independent `float2` partial sums, each step consuming two heads
auto sum_0 = make_float2(0, 0);
auto sum_1 = make_float2(0, 0);
const auto transform = [&](const uint32_t& j, const float2& sum) {
auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0));
auto b = make_float2(weights[i][j], weights[i][j + 1]);
return __ffma2_rn(a, b, sum);
};
#pragma unroll
for (uint32_t j = 0; j < kNumHeads; j += 4) {
sum_0 = transform(j, sum_0);
sum_1 = transform(j + 2, sum_1);
}
auto sum = __fadd2_rn(sum_0, sum_1);
auto result = static_cast<logits_dtype_t>(scale_kv * (sum.x + sum.y));
// Store into the global memory
// Token `i` of the atom selects the row, `math_thread_idx` the column
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + math_thread_idx] = result;
__syncwarp();
}
// Release TMEM empty
ptx::tcgen05_before_thread_sync();
empty_umma_barriers[math_warpgroup_idx]->arrive();
};
// Pick the number of tokens to store: varlen atoms use 2 only when paired; with odd-N padding the
// last atom of each group of `kNumNextNAtoms` stores a single token
if constexpr (kIsVarlen) {
if (is_paired_atom)
reduce_and_store(cute::Int<kNextNAtom>{});
else
reduce_and_store(cute::Int<1>{});
} else if constexpr (kPadOddN) {
if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1)
reduce_and_store(cute::Int<1>{});
else
reduce_and_store(cute::Int<kNextNAtom>{});
} else {
reduce_and_store(cute::Int<kNextNAtom>{});
}
}
// Free tensor memory
// All math threads synchronize first, so no math warp is still reading TMEM when warp 0 frees it
cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync();
if (warp_idx == 0)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,350 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/mma/sm100.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tcgen05.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
// Map a `(offset, lane_idx)` pair to a swizzled byte offset inside shared memory.
//
// The data is viewed as rows of 128 bytes, each split into chunks of `kSwizzleBase` bytes.
// `offset` is a chunk index and every lane advances by one swizzle atom row
// (`kSwizzleMode / kSwizzleBase` chunks). The chunk column is XOR-ed with the row index
// (modulo the number of chunks per atom row) before converting back to bytes.
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
CUTLASS_DEVICE
uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) {
    // Chunks per swizzle atom row, and chunks per 128-byte row
    constexpr uint32_t kChunksPerAtomRow = kSwizzleMode / kSwizzleBase;
    constexpr uint32_t kChunksPer128B = 128 / kSwizzleBase;

    uint32_t row_idx, col_idx;
    if constexpr (kChunksPerAtomRow == kChunksPer128B) {
        // Shortcut: an atom row is exactly one 128-byte row, so each lane owns one row
        // and `offset` is used directly as the column
        row_idx = offset / kChunksPer128B + lane_idx;
        col_idx = offset;
    } else {
        // General case: linearize the chunk index in the
        // `(BLOCK_N, kSwizzleMode / kSwizzleBase)` view, then reshape it into the
        // `(BLOCK_N * kSwizzleMode / kSwizzleBase / kChunksPer128B, kChunksPer128B)` view
        const uint32_t linear_chunk_idx = offset + lane_idx * kChunksPerAtomRow;
        row_idx = linear_chunk_idx / kChunksPer128B;
        col_idx = linear_chunk_idx % kChunksPer128B;
    }

    // Swizzle the column with the row, then convert chunks back to bytes
    col_idx ^= row_idx % kChunksPerAtomRow;
    return row_idx * 128 + col_idx * kSwizzleBase;
}
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumSplits,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMMAThreads, uint32_t kNumCastAndReduceThreads>
CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1)
sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Configs
constexpr uint32_t kNumCastStages = 2;
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
constexpr auto kMajorA = cute::UMMA::Major::K;
constexpr auto kMajorB = cute::UMMA::Major::K;
DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages");
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads");
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = ptx::get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Real tensor memory size and offsets
constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols<BLOCK_K * kNumCastStages + BLOCK_N>();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Data on shared memory (layout as ordered below)
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE +
kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); });
auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); });
auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4;
// Fill the tensor memory pointer
auto tmem_ptr_in_smem = reinterpret_cast<uint32_t*>(barrier_start_ptr + kNumStages * 4 + 1);
DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns");
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
full_cast_barriers[i]->init(kNumCastAndReduceThreads);
empty_barriers[i]->init(1);
empty_cast_barriers[i]->init(1);
}
tmem_full_barrier->init(1);
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
} else if (warp_idx == 2) {
// Allocate tensor memory
cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem);
}
__syncthreads();
constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
const uint32_t m_block_idx = block_idx / kNumSplits;
const uint32_t k_split_idx = block_idx % kNumSplits;
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
const uint32_t m_offset = shape_m * k_split_idx;
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Dispatch warps into different roles
if (warp_idx < kNumMMAThreads / 32) {
// TMA load warp
if (warp_idx == 0 and cute::elect_one_sync()) {
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto& stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
// Compute offsets
uint32_t m_idx = m_block_idx * BLOCK_M;
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
}
// MMA issue warp
if (warp_idx == 1) {
// Make instruction descriptor
constexpr uint32_t UMMA_M = BLOCK_M;
constexpr uint32_t UMMA_N = BLOCK_N;
constexpr uint32_t UMMA_K = 32 / sizeof(float);
constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float);
using umma_t = cute::SM100_MMA_TF32_TS<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
BLOCK_M, BLOCK_N, kMajorA, kMajorB>;
auto instr_desc = cute::UMMA::make_instr_desc<cutlass::tfloat32_t, cutlass::tfloat32_t, float,
UMMA_M, UMMA_N, kMajorA, kMajorB>();
const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc);
DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages");
auto b_desc = mma::sm100::make_umma_desc<kMajorB, BLOCK_N, BLOCK_SWIZZLED_BK, kSwizzleBMode>(smem_b[0], 0, 0);
const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u;
// Checks for MMA instructions
// NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits
DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or
(UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256),
"Invalid MMA instruction shape");
// Launch MMAs
// We can not unroll this part
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
const auto& cast_stage_idx = s % kNumCastStages;
full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1);
ptx::tcgen05_after_thread_sync();
// Issue UMMA
const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast<int>(stage_idx));
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) {
const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK;
const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK;
const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK;
b_desc.lo = mma::sm100::advance_umma_desc_lo<kMajorB, BLOCK_N, kSwizzleBMode, float>(b_desc_base_lo, offset, in_atom_idx);
umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc);
}
// Commit
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_cast_barriers[cast_stage_idx]));
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(empty_barriers[stage_idx]));
}
// Commit to epilogue threads
cutlass::arch::umma_arrive(reinterpret_cast<uint64_t*>(tmem_full_barrier));
}
// TMA checks
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float);
DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled");
DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling");
// Only support layout F (M = 64) and D (M = 128)
DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M");
// Wait UMMA arrival
tmem_full_barrier->wait(0);
ptx::tcgen05_after_thread_sync();
// Load from tensor memory into registers, and write shared memory with STSM
DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough");
// Store into shared memory
#pragma unroll
for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) {
// Source and destination memory address
uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup;
auto smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) + // Base pointer
warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset
get_swizzled_smem_offset<kSwizzleCDMode>(i, lane_idx); // In-atom offset
// Load from tensor memory, store into shared memory
uint32_t values[kNumElemsPerBankGroup];
DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type");
cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr,
values[0], values[1], values[2], values[3]);
cutlass::arch::fence_view_async_tmem_load();
if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16))
ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]);
if constexpr (BLOCK_M == 64)
__syncwarp();
}
// Synchronize all threads and issue TMA
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0);
if (warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kNumSplits == 1) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
} else {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
}
cute::tma_store_arrive();
}
// Deallocate tensor memory by warp 1
// NOTES: warp 0 is waiting TMA store
if (warp_idx == 1)
cute::TMEM::Allocator1Sm().free(0, kNumTmemCols);
} else {
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads");
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32;
// TODO: make even larger block K
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
// Launch reductions
float2 sum[2] = {float2{0, 0}, float2{0, 0}};
#pragma unroll kNumStages
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
// Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b)
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
const auto& smem_base_ptr = reinterpret_cast<uint8_t*>(smem_a[stage_idx]) + // Base pointer
sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset
// 4 lanes shared a bank group
uint32_t uint32_values[2][kNumLoads];
DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads");
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; i += 2) {
auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset<kSwizzleAMode>(i + lane_idx / 16, lane_idx % 16);
ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0],
uint32_values[0][i + 1], uint32_values[1][i + 1],
smem_ptr);
}
// Wait tensor memory empty
const auto& cast_stage_idx = s % kNumCastStages;
empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1);
// Cast, reduce and store into tensor memory
float2 fp32x2_values[2][kNumLoads];
const auto& upper_view = reinterpret_cast<uint32_t*>(&fp32x2_values[0]);
const auto& lower_view = reinterpret_cast<uint32_t*>(&fp32x2_values[1]);
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; ++ i) {
#pragma unroll
for (uint32_t u = 0; u < 2; ++ u) {
fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast<nv_bfloat162*>(&uint32_values[u][i]));
sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]);
}
// Store upper and lower part at the same time
const auto idx_0 = i * 2, idx_1 = i * 2 + 1;
cute::SM100_TMEM_STORE_16dp256b1x::copy(
upper_view[idx_0], upper_view[idx_1],
lower_view[idx_0], lower_view[idx_1],
cast_stage_idx * BLOCK_K + i * 8);
}
cutlass::arch::fence_view_async_tmem_store();
// Arrive for issuing MMAs
ptx::tcgen05_before_thread_sync();
full_cast_barriers[cast_stage_idx]->arrive();
}
// Intra-warp reduction and write back
#pragma unroll
for (uint32_t u = 0; u < 2; ++ u) {
const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y);
const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8;
if (lane_idx % 4 == 0 and m_idx < shape_m)
sqr_sum[m_offset + m_idx] = reduced_sum;
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_100f");
#endif
}
} // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,388 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <cute/arch/mma_sm100_desc.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {

// SM90 BF16 GEMM kernel: TMA loads feed a multi-stage shared-memory pipeline that is consumed
// by WGMMA, and the result tile is written back through shared memory with a TMA store.
//
// Warp roles (selected by `warp_idx`):
//   - warps `>= kNumMathThreads / 32`: TMA warp-group. One elected thread of warp
//     `kNumMathThreads / 32 + 2` issues every load; the other threads only shrink their registers.
//   - warps `<  kNumMathThreads / 32`: math warp-groups issuing WGMMA and the epilogue.
// Producer and consumers hand stages over through `full_barriers` (TMA -> math) and
// `empty_barriers` (math -> TMA); both sides advance the same `(stage_idx, phase)` sequence.
//
// Template parameters:
//   kMajorA / kMajorB      - major-ness of A and B; selects the TMA box orientation and descriptors.
//   SHAPE_M / N / K        - compile-time shapes; a non-zero value overrides the runtime argument.
//   kNumGroups             - forwarded to the block scheduler.
//   BLOCK_M / BLOCK_N      - output tile size per scheduled block.
//   BLOCK_K_               - K extent of one pipeline stage before stage merging.
//   kSwizzleA/B/DMode      - swizzle modes (bytes) for the A, B and D shared-memory tiles.
//   kNumStages_            - number of pipeline stages before stage merging.
//   kNumTMAThreads         - threads in the TMA warp-group (statically asserted to be >= 128).
//   kNumMathThreads        - threads in the math warp-groups.
//   kNumTMAMulticast       - TMA multicast width (statically asserted to be <= 2).
//   kIsTMAMulticastOnA     - whether the multicast is applied to A (otherwise to B).
//   kNumSMs                - forwarded to the block scheduler.
//   kGemmType              - GEMM flavor; drives scheduler indexing and the batched TMA paths.
//   kWithAccumulation      - use a TMA reduce-add instead of a plain TMA store (non-batched path).
//   cd_dtype_t             - element type of the output tile in shared memory.
//
// Arguments:
//   grouped_layout         - forwarded to the block scheduler.
//   shape_m / n / k        - runtime problem shape (see SHAPE_* above).
//   tensor_map_a / b / cd  - TMA descriptors for the A load, B load and C/D store.
template <cute::UMMA::Major kMajorA, cute::UMMA::Major kMajorB,
          uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
          uint32_t kNumGroups,
          uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K_,
          uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
          uint32_t kNumStages_,
          uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
          uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
          uint32_t kNumSMs,
          GemmType kGemmType, bool kWithAccumulation,
          typename cd_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bf16_gemm_impl(int* grouped_layout,
                    uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
                    const __grid_constant__ cute::TmaDescriptor tensor_map_a,
                    const __grid_constant__ cute::TmaDescriptor tensor_map_b,
                    const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
    // Enlarge `BLOCK_K` for some cases
    // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead
    constexpr uint32_t kDoMergeStages =
        kNumStages_ >= 10 and
        kGemmType == GemmType::Normal and
        kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and
        kNumMathThreads == 128;
    // Ensure there are at least `kNumMinStages` stages after merge
    constexpr uint32_t kNumMinStages = 5;
    constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1;
    // Effective per-stage K extent and stage count used by the rest of the kernel
    constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge;
    constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge;

    // Types
    using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N, kMajorA, kMajorB>::type;
    using Barrier = cutlass::arch::ClusterTransactionBarrier;
    DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");

    // Prefer the compile-time shape constants whenever they are provided (non-zero)
    shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
    shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
    shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;

    // Shared memory
    static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(cd_dtype_t)), 1024u);
    static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
    static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
    // NOTES: Make sure we have enough shared memory for WGMMA padding
    // A is stored as BF16 (see `SMEM_A_SIZE_PER_STAGE`), so the padded `WGMMA::M`-row extent must be
    // measured with the BF16 element size; using a 1-byte element size would check only half of it
    static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_bfloat16);
    DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");

    // Configs
    const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
    const uint32_t lane_idx = ptx::get_lane_idx();

    // Prefetch TMA descriptors at the very beginning
    if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
        cute::prefetch_tma_descriptor(&tensor_map_a);
        cute::prefetch_tma_descriptor(&tensor_map_b);
        cute::prefetch_tma_descriptor(&tensor_map_cd);
    }
    __syncwarp();

    // Align to 1024 bytes for swizzle-128B
    extern __shared__ __align__(1024) uint8_t smem_buffer[];
    DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0,
                     "Shared memory of A/B/D must be aligned to 1024 bytes");

    // D/A/B shared memory
    // Layout: [D tile][A stage 0 .. kNumStages-1][B stage 0 .. kNumStages-1][barriers]
    auto smem_d = reinterpret_cast<cd_dtype_t*>(smem_buffer);
    auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
    });
    auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
        return reinterpret_cast<cutlass::bfloat16_t*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
    });

    // Fill barriers: `kNumStages` full barriers followed by `kNumStages` empty barriers
    auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
    auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
    auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });

    // Initialize barriers
    if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
        #pragma unroll
        for (uint32_t i = 0; i < kNumStages; ++ i) {
            // The single TMA-issuing thread arrives on the full barrier
            full_barriers[i]->init(1);
            // One arrival per math warp, from each of the `kNumTMAMulticast` CTAs
            // (matches `empty_barrier_arrive` below)
            empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
        }

        // Make initialized barrier visible in async proxy
        cutlass::arch::fence_barrier_init();
    }

    // Synchronize all threads to make barrier visible in normal memory model
    (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();

    // Register reconfigurations
    constexpr uint32_t kNumTMARegisters = 48;
    constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224;

    // Wait for primary kernel completion
    cudaGridDependencySynchronize();

    // Block scheduler
    uint32_t m_block_idx, n_block_idx;
    auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);

    // Pipeline and TMA phases
    // NOTES: `stage_idx` and `phase` persist across scheduled blocks, so the producer and the
    // consumers walk through the same ring of stages in lock-step
    uint32_t stage_idx = 0, phase = 0;
    auto advance_pipeline = [&](uint32_t& k_block_idx) {
        ++ k_block_idx;

        // Flip phases only if reach the next first stage
        stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
        phase ^= stage_idx == 0;
    };

    if (warp_idx >= kNumMathThreads / 32) {
        // TMA warp-group for loading data
        cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();

        // NOTES: only one thread (or warp) will be used
        // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32`
        if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) {
            DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group");

            // Persistently schedule over blocks
            while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
                // Assign TMA multicast number into A and B
                // NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
                const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
                const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
                DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");

                const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
                for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
                    // Wait consumer release
                    empty_barriers[stage_idx]->wait(phase ^ 1);

                    // Global coordinates of the A/B boxes to load for this K block
                    constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
                    auto& full_barrier = *full_barriers[stage_idx];

                    const auto m_idx = scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
                    const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx);

                    DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major");
                    uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> (
                        shape_k, BLOCK_K, k_block_idx, m_block_idx);
                    uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> (
                        shape_k, BLOCK_K, k_block_idx, m_block_idx);

                    // Issue TMAs
                    // NOTES: the box dimensions and coordinates are swapped between K-major and MN-major operands
                    constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
                    const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
                    if constexpr (kMajorA == cute::UMMA::Major::K)
                        tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
                            &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx);
                    if constexpr (kMajorA == cute::UMMA::Major::MN)
                        tma::copy<BLOCK_M, BLOCK_K, kSwizzleAMode, cutlass::bfloat16_t, kIsBatchedMM>(
                            &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx);
                    if constexpr (kMajorB == cute::UMMA::Major::K)
                        tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
                            &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx);
                    if constexpr (kMajorB == cute::UMMA::Major::MN)
                        tma::copy<BLOCK_N, BLOCK_K, kSwizzleBMode, cutlass::bfloat16_t, kIsBatchedMM>(
                            &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx);

                    // Arrive once and announce the byte count the TMA transactions will deliver
                    full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
                }
            }

            // To safely deconstruct distributed shared barriers, we need another round of empty waits
            if constexpr (kNumTMAMulticast > 1) {
                for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
                    empty_barriers[stage_idx]->wait(phase ^ 1);
            }
        }
    } else {
        // Math warp-groups for WGMMA
        cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();

        // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
        const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);

        // Merged stages only happens in NT normal GEMM cases
        constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge;
        // Descriptors are built once for stage 0; only the low 32 bits are re-derived per stage and K step
        auto a_desc = mma::sm90::make_gmma_desc<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode>(smem_a[0], math_wg_idx * WGMMA::M, 0);
        auto b_desc = mma::sm90::make_gmma_desc<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode>(smem_b[0], 0, 0);
        const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
        const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);

        while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
            constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2;
            DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
            // Accumulators for every wave of the tile, zeroed for each scheduled block
            float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};

            // Pick threads whose WGMMA results are to be stored in shared memory
            DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`");
            constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M);
            const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32;

            // Empty barrier arrival
            auto empty_barrier_arrive = [&](uint32_t s) {
                if constexpr (kNumTMAMulticast == 1) {
                    lane_idx == 0 ? empty_barriers[s]->arrive() : void();
                } else {
                    // Lanes `< kNumTMAMulticast` each release one CTA; fall back to the local CTA
                    // when the scheduler reports that the peer is not alive
                    auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
                    lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
                }
            };

            // TODO: remove some useless computation for unaligned Ms
            const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
            for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
                // Descriptor low words for the current stage (descriptor addresses are in 16-byte units)
                const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
                const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);

                // Wait TMA arrivals
                full_barriers[stage_idx]->wait(phase);

                // Commit WGMMA instructions
                #pragma unroll
                for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
                    ptx::warpgroup_fence_operand(accum[i]);
                ptx::warpgroup_arrive();
                #pragma unroll
                for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                    auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
                    #pragma unroll
                    for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
                        // Locate the K step inside the (possibly merged) stage: atom index + in-atom offset
                        const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K;
                        a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorA, BLOCK_M, BLOCK_ATOM_K, kSwizzleAMode, nv_bfloat16>(
                            a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K);
                        b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo<kMajorB, BLOCK_N, BLOCK_ATOM_K, kSwizzleBMode, nv_bfloat16>(
                            b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K);
                        WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1);
                    }
                }
                ptx::warpgroup_commit_batch();
                #pragma unroll
                for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i)
                    ptx::warpgroup_fence_operand(accum[i]);
                ptx::warpgroup_wait<0>();

                // Notify barrier arrival
                empty_barrier_arrive(stage_idx);
            }

            // TMA checks
            // NOTE(review): the element size here is fixed to BF16 even though `cd_dtype_t` may differ;
            // it only matters when `kSwizzleDMode != 0` - confirm against the host-side configuration
            constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
            constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
            constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
            DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
            DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
                             "Unaligned TMA store or too many TMA store instructions");
            DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");

            // Skip WGMMA store for the unfilled parts
            if (not do_wgmma_store)
                continue;

            // Wait last TMA store to be finished
            // NOTES: only the threads that issued the previous stores wait, the named barrier then
            // holds back the other store threads before `smem_d` is overwritten
            if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
                cute::tma_store_wait<0>();
            cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);

            if constexpr (cute::is_same_v<cd_dtype_t, cutlass::bfloat16_t>) {
                // Write back to shared memory using STSM and issue TMA stores
                DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type");
                DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
                #pragma unroll
                for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                    auto m_offset = local_idx * WAVE_BLOCK_M;
                    auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
                    #pragma unroll
                    for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                        // Swizzle or padding into the correct address
                        uint8_t* smem_ptr = nullptr;
                        if constexpr (kSwizzleDMode > 0) {
                            // Calculate the swizzling atom offset and in-atom offset
                            constexpr uint32_t kNumBankGroupBytes = 16;
                            auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);

                            // Calculate the index of the bank group to be written in the atom
                            auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);

                            // Reshape the atom in another view and swizzle
                            //  - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
                            //  - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
                            constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
                            auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
                            auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
                            col ^= row % (kSwizzleDMode / 16);

                            // Add back into the base pointer
                            // NOTES: think twice before modifying this, as changes may affect the number of instructions
                            smem_ptr = reinterpret_cast<uint8_t*>(smem_d) +                // Base pointer
                                warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) +            // Warp offset
                                m_offset * kSwizzleDMode +                                 // Wave offset
                                atom_offset * BLOCK_M * kSwizzleDMode +                    // Swizzle atom offset (constants)
                                row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
                        } else {
                            // No swizzling
                            smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
                        }

                        // NOTES: only 16 lanes' addresses are used
                        ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
                            __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
                            __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
                            smem_ptr
                        );
                    }
                }
            } else {
                // Use `st.shared` if STSM is not available
                #pragma unroll
                for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
                    auto m_offset = local_idx * WAVE_BLOCK_M;
                    auto shifted_accum = accum + WGMMA::kNumAccum * local_idx;
                    // Each lane writes two rows (`lane_idx / 4` and `lane_idx / 4 + 8`), two columns at a time
                    auto smem_d_0 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2);
                    auto smem_d_1 = reinterpret_cast<float2*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2);
                    #pragma unroll
                    for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
                        ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]));
                        ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]));
                    }
                }
            }
            // Fence the shared-memory writes for the TMA store and wait for all store threads
            cute::tma_store_fence();
            cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0);

            // Use TMA store to write back to global memory
            const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx);
            DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
            // One thread per `TMA_D_BLOCK_N`-wide column slice of the output tile
            if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
                auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
                auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
                if constexpr (kGemmType == GemmType::Batched) {
                    cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr,
                                                  n_block_idx * BLOCK_N + in_block_n_offset,
                                                  m_idx, scheduler.current_group_idx);
                } else {
                    using cute_tma_t = cute::conditional_t<kWithAccumulation,
                                       cute::SM90_TMA_REDUCE_ADD_2D, cute::SM90_TMA_STORE_2D>;
                    cute_tma_t::copy(&tensor_map_cd, smem_ptr,
                                     n_block_idx * BLOCK_N + in_block_n_offset, m_idx);
                }
                cute::tma_store_arrive();
            }
            __syncwarp();
        }
    }
#else
    if (blockIdx.x == 0 and threadIdx.x == 0)
        DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}

} // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,183 @@
#pragma once
#include <cute/arch/cluster_sm90.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSplitFactor,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
float *d) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Types
using WGMMA = typename mma::sm90::BF16MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Shared memory
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16);
// Configs
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = ptx::get_lane_idx();
DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M");
DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads");
DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads");
// Prefetch TMA descriptors at the very beginning
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
// Fill shared memory pointers
extern __shared__ __align__(1024) uint8_t smem_buffer[];
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumMathThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 232;
// Block indices
const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N);
const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M);
const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks;
const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks;
const uint32_t n_block_idx = mn_block_idx % num_n_blocks;
const uint32_t m_block_idx = mn_block_idx / num_n_blocks;
const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// Persistently schedule over blocks
#pragma unroll
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1);
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K;
const uint32_t k_idx = sk_idx % SHAPE_K;
const uint32_t s_idx = sk_idx / SHAPE_K;
constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16);
tma::copy<BLOCK_K, BLOCK_M, kSwizzle>(
&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1);
tma::copy<BLOCK_K, BLOCK_N, kSwizzle>(
&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
float accum[WGMMA::kNumAccum] = {0};
// Launch MMAs
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrivals
const auto stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, 1);
}
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
empty_barriers[stage_idx]->arrive();
}
const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4;
const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
if (col + i * 8 >= SHAPE_N)
break;
if (row < SHAPE_M) {
atomicAdd(reinterpret_cast<float2*>(d + (row + 0) * SHAPE_N + col + i * 8),
make_float2(accum[i * 4 + 0], accum[i * 4 + 1]));
}
if (row + 8 < SHAPE_M) {
atomicAdd(reinterpret_cast<float2*>(d + (row + 8) * SHAPE_N + col + i * 8),
make_float2(accum[i * 4 + 2], accum[i * 4 + 3]));
}
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm

View File

@@ -0,0 +1,346 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/int_tuple.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/tma.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {
// SM90 FP8 (e4m3) GEMM kernel with per-128-channel scaling factors on both A and B,
// accumulating in FP32 and writing the result with a TMA reduce-add into `tensor_map_cd`.
//
// Template parameters:
//   SHAPE_M/N/K          compile-time shapes; a value of 0 means "use the runtime `shape_*` argument"
//   kNumGroups           forwarded to the block scheduler
//   BLOCK_M/N/K          tile sizes handled per scheduled block (`BLOCK_K` must be 128)
//   kSwizzleAMode/BMode  swizzle modes forwarded to the TMA copies of A and B
//   kNumStages           number of shared-memory pipeline stages
//   kNumTMAThreads       size of the producer warp-group (must be 128)
//   kNumMathThreads      number of consumer (WGMMA) threads (multiple of 128)
//   kNumTMAMulticast     TMA multicast width (at most 2); `kIsTMAMulticastOnA` selects A or B as the multicast operand
//   kNumSMs              forwarded to the block scheduler
//   kGemmType            `GemmType::Normal` or `GemmType::KGroupedContiguous`
//   cd_dtype_t           output type; must be `float`
//
// Runtime parameters:
//   gmem_a_ptr/gmem_b_ptr  base pointers of A/B; only used to patch tensor maps in the K-grouped path
//   grouped_layout         forwarded to the block scheduler
//   tensor_map_buffer      global scratch for patched tensor maps, two descriptors per block (indexed by `blockIdx.x * 2`)
//   shape_m/n/k            runtime shapes, overridden by non-zero `SHAPE_*`
//   tensor_map_*           TMA descriptors for A, B, the A/B scaling factors and the output
//
// Warp roles (warp indices at or above `kNumMathThreads / 32` form the TMA warp-group):
//   warp `kNumMathThreads / 32`      prefetches descriptors and issues all TMA loads
//   warp `kNumMathThreads / 32 + 1`  initializes the barriers (and shared tensor maps)
//   warps below `kNumMathThreads / 32`  run WGMMA, scale promotion and the epilogue
template <uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumGroups,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs,
GemmType kGemmType, typename cd_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr,
int* grouped_layout,
cute::TmaDescriptor* tensor_map_buffer,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a_base,
const __grid_constant__ cute::TmaDescriptor tensor_map_b_base,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfb,
const __grid_constant__ cute::TmaDescriptor tensor_map_cd) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads");
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(cute::is_same_v<cd_dtype_t, float>, "Invalid C/D data dtype");
DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type");
// Types
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size");
// Prefer the compile-time shape constants whenever they are provided (non-zero)
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory layout, in order:
// [tensor maps (K-grouped only)] [D tile] [A stages] [B stages] [SFA stages] [SFB stages] [full barriers] [empty barriers]
static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 2 : 0);
static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float);
// SFB stages are padded to 128 bytes; SFA stages are required to be a multiple of 128 bytes already
static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u);
DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment");
// Configs
// NOTES: `warp_idx` is broadcast from lane 0 so it is uniform across the warp
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = threadIdx.x % 32;
// Prefetch TMA descriptors at the very beginning
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a_base);
cute::prefetch_tma_descriptor(&tensor_map_b_base);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_sfb);
cute::prefetch_tma_descriptor(&tensor_map_cd);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
// NOTE(review): in the K-grouped path the D/A/B regions start `SMEM_TENSOR_MAP_SIZE` bytes into the buffer,
// and only `SMEM_D_SIZE` is asserted to be a multiple of 1024 below; confirm that this offset still
// satisfies the alignment required by the swizzle modes in use.
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Tensor maps on shared and global memory
// NOTES: the shared copies live at the start of `smem_buffer`; the global copies are this block's
// two-descriptor slot in `tensor_map_buffer`. Both are only touched in the K-grouped path.
auto smem_tensor_map_a = reinterpret_cast<cute::TmaDescriptor*>(smem_buffer);
auto smem_tensor_map_b = smem_tensor_map_a + 1;
auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2;
auto gmem_tensor_map_b = gmem_tensor_map_a + 1;
// Data on shared memory
auto smem_d = reinterpret_cast<float*>(smem_buffer + SMEM_TENSOR_MAP_SIZE);
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Scaling factors follow the A/B stages: all SFA stages first, then the (padded) SFB stages
constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE));
});
auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE));
});
// Barriers on shared memory: `kNumStages` full barriers followed by `kNumStages` empty barriers
constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE);
auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast<uint32_t>(sizeof(Barrier))));
});
auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<Barrier*>(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast<uint32_t>(sizeof(Barrier))));
});
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
// Load tensormap A/B to shared memory
if constexpr (kGemmType == GemmType::KGroupedContiguous) {
*smem_tensor_map_a = tensor_map_a_base;
*smem_tensor_map_b = tensor_map_b_base;
}
// Initialize barriers
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
// NOTES: full barriers expect the single producer arrival; empty barriers expect one arrival
// per math warp from each of the `kNumTMAMulticast` CTAs (see `empty_barrier_arrive` below)
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
// NOTES: with multicast enabled, the whole cluster is synchronized instead of only this CTA
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// Pipeline unroll control
// NOTES: the value feeds `#pragma unroll` on both k-loops; the K-grouped path uses 0
constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages);
// Register reconfigurations (more math registers are needed with unrolling)
constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24);
constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs, 128u>(shape_m, shape_n, shape_k, grouped_layout);
// TMA and MMA pipeline
// NOTES: `iter_idx` counts k-blocks across all scheduled tiles; producer and consumer each keep their own
// copy and advance it in lock-step, so both derive the same stage and phase for a given iteration
const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase
};
uint32_t iter_idx = 0;
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
// NOTES: `kNumGroups` serves as the "no group seen yet" sentinel
uint32_t last_group_idx = kNumGroups;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K);
const uint32_t m_idx = m_block_idx * BLOCK_M;
const uint32_t n_idx = n_block_idx * BLOCK_N;
// On a group change (K-grouped only), patch the A/B tensor maps for the new group's base address
// and K extent, then publish them to this block's slot in global memory
if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) {
last_group_idx = scheduler.current_group_idx;
// Directly update current tensor map
const uint64_t current_k_offset = scheduler.current_k_cumsum;
ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m);
ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n);
ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k);
ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k);
*(gmem_tensor_map_a) = *(smem_tensor_map_a);
*(gmem_tensor_map_b) = *(smem_tensor_map_b);
ptx::tensor_map_release_gpu();
// Immediately acquire current tensor map
ptx::tensor_map_acquire_gpu(gmem_tensor_map_a);
ptx::tensor_map_acquire_gpu(gmem_tensor_map_b);
}
#pragma unroll kNumPipelineUnrolls
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
// Wait consumer release
// NOTES: the producer waits on the opposite parity (`phase ^ 1`) of the one the consumer uses
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
empty_barriers[stage_idx]->wait(phase ^ 1);
// Issue TMA
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t k_idx = k_block_idx * BLOCK_K;
const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx;
// NOTES: the K-grouped path reads through the patched descriptors in global memory,
// the normal path uses the kernel-argument descriptors directly
const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base);
const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base);
tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a);
tma::copy<BLOCK_N, BLOCK_K, 0>(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b);
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b);
// NOTES: the expected byte count uses the unpadded SFB size, not `ALIGNED_SMEM_SFB_SIZE_PER_STAGE`
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE);
}
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
#pragma unroll
for (uint32_t s = 0; s < kNumStages; ++ s) {
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
empty_barriers[stage_idx]->wait(phase ^ 1);
}
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
// NOTES: each thread owns two rows of the block (`r_0` and `r_0 + 8`) and, per group of 8 columns,
// the column pair starting at `col_idx * 2`; these indices drive the scale reads and D stores below
const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4;
const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8;
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Accumulation for WGMMA or CUDA promotion
DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes");
const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k);
const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0);
const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K);
// NOTES: only `final_accum` is zero-initialized; `accum` holds the raw per-k-block WGMMA result
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0};
float2 scales_b[WGMMA::kNumAccum / 4];
// Empty barrier arrival
// NOTES: without multicast, lane 0 of every math warp arrives locally; with multicast, lanes
// `[0, kNumTMAMulticast)` each arrive on one CTA's barrier (or on the local CTA if the peer is not alive)
auto empty_barrier_arrive = [&](uint32_t s) {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[s]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void();
}
};
#pragma unroll kNumPipelineUnrolls
for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) {
// Wait TMA arrivals
CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase);
full_barriers[stage_idx]->wait(phase);
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0);
auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1);
// Read B scales
#pragma unroll
for (int i = 0; i < WGMMA::kNumAccum / 4; ++i)
scales_b[i] = ptx::ld_shared(reinterpret_cast<float2*>(smem_sfb[stage_idx] + i * 8 + col_idx * 2));
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
// NOTE(review): the last `wgmma` argument is `k`, presumably the scale-D/accumulate flag so that
// the first instruction (k == 0) overwrites the uninitialized `accum` — confirm in `mma/sm90.cuh`
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1);
auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Notify barrier arrival
// NOTES: the stage is released right after the WGMMA wait; the scales were already copied to registers
empty_barrier_arrive(stage_idx);
// Promote with scales
// NOTES: per group of 4 accumulators, elements 0/1 belong to row `r_0`, 2/3 to row `r_1`,
// and even/odd elements use the first/second B scale of the column pair
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
const float &scale_b_0 = scales_b[i].x;
const float &scale_b_1 = scales_b[i].y;
final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0];
final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1];
final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2];
final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3];
}
}
// Flush previous stores
// NOTES: the thread that issued the previous TMA store waits for it before `smem_d` is overwritten,
// then the 128 threads of this warp-group synchronize on a named barrier with ID `math_wg_idx`
if (warp_idx % 4 == 0 and cute::elect_one_sync())
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
// Store to D shared memory
const auto smem_d_0 = reinterpret_cast<float2*>(smem_d + r_0 * BLOCK_N + col_idx * 2);
const auto smem_d_1 = reinterpret_cast<float2*>(smem_d + r_1 * BLOCK_N + col_idx * 2);
// NOTES: the pointers are `float2*`, so `+ i * 4` advances by 8 floats (one accumulator column group)
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]});
ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]});
}
// Make the shared-memory writes visible to the TMA store, then wait for the whole warp-group
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, math_wg_idx);
// Use TMA store to write back to global memory
// NOTES: one elected thread in the first warp of each warp-group issues a reduce-add, so the tile is
// added onto the existing contents of the output rather than overwriting them
// NOTE(review): `smem_d_0` and `r_0` are per-thread values; this presumably relies on the elected
// thread being lane 0 (row_idx == col_idx == 0) — confirm
if (warp_idx % 4 == 0 and cute::elect_one_sync()) {
cute::SM90_TMA_REDUCE_ADD_2D::copy(
&tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N,
current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0);
cute::tma_store_arrive();
}
__syncwarp();
}
}
#else
// Compiled for an architecture below SM90: a single thread reports the failure
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,449 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/copy_sm90_tma.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/epilogue/transform.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/gemm.cuh>
namespace deep_gemm {
// Turns the runtime value `num_former_iters` into a compile-time constant.
//
// The candidates are `kNumFormerIters`, `kNumFormerIters + kGap`, ... up to and including `kEnd`.
// For the candidate equal to `num_former_iters`, `func` is invoked once with `cute::Int<candidate>{}`.
// If no candidate matches, `func` is never invoked.
template <uint32_t kNumFormerIters, uint32_t kGap, uint32_t kEnd, typename func_t>
CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) {
    if (num_former_iters != kNumFormerIters) {
        // Not this candidate: try the next one if it is still within range, otherwise give up
        if constexpr (kNumFormerIters + kGap <= kEnd)
            dispatch_num_former_iters<kNumFormerIters + kGap, kGap, kEnd>(num_former_iters, func);
        return;
    }
    // Matched: hand the value to the callee as a compile-time integer
    func(cute::Int<kNumFormerIters>{});
}
template <cute::UMMA::Major kMajorSFB,
uint32_t SHAPE_M, uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t kNumGroups,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kSwizzleAMode, uint32_t kSwizzleBMode, uint32_t kSwizzleDMode,
uint32_t kNumStages,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
uint32_t kNumTMAMulticast, bool kIsTMAMulticastOnA,
uint32_t kNumSMs, GemmType kGemmType,
typename epilogue_type_t>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void
sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout,
uint32_t shape_m, uint32_t shape_n, uint32_t shape_k,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
// Scaling checks
DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling");
DG_STATIC_ASSERT(
math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or
(math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block");
// Types
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_N>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size");
// Overwrite shape constants if the compiler gives
shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m;
shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n;
shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k;
// Shared memory
static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0);
static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast<uint32_t>(sizeof(__nv_bfloat16)), 1024u);
static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u);
const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K);
const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K);
const uint32_t smem_sfb_size = math::align<uint32_t>(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier));
// NOTES: Make sure we have enough shared memory for WGMMA padding
static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3);
DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA");
// Configs
const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K);
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const uint32_t lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors at the very beginning
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_sfa);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
__syncwarp();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// Data on shared memory
auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer);
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE);
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE);
});
constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE);
auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
});
auto smem_sfb = reinterpret_cast<float*>(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE);
// Fill barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(reinterpret_cast<uint8_t*>(smem_sfb) + smem_sfb_size);
auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; });
auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; });
// Initialize barriers
DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast");
if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) {
// NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster,
// even with TMA multicast disabled, we want to make the behavior aligned
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
full_barriers[i]->init(1);
empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
// Synchronize all threads to make barrier visible in normal memory model
(kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads();
// Register reconfigurations
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Block scheduler
uint32_t m_block_idx, n_block_idx;
auto scheduler = sched::Scheduler<kGemmType, BLOCK_M, BLOCK_N, kNumGroups, kNumTMAMulticast, kIsTMAMulticastOnA, kNumSMs>(shape_m, shape_n, shape_k, grouped_layout);
// Pipeline and TMA phases
uint32_t stage_idx = 0, phase = 0;
auto advance_pipeline = [&](uint32_t& k_block_idx) {
++ k_block_idx;
// Flip phases only if reach the next first stage
stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1;
phase ^= stage_idx == 0;
};
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// NOTES: only one thread (or warp) will be used
// We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32`
if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) {
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Assign TMA multicast number into A and B
// NOTES: there may be additional odd rows/columns or cases where multicast is not possible.
const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx);
const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1;
DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast");
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
// Wait consumer release
empty_barriers[stage_idx]->wait(phase ^ 1);
// Issue TMA A
constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched);
const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0);
constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked;
auto& full_barrier = *full_barriers[stage_idx];
const uint32_t k_idx = k_block_idx * BLOCK_K;
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_a, &full_barrier,
smem_a[stage_idx], k_idx, scheduler.get_global_idx<kWithGroupOffsetA>(shape_m, BLOCK_M, m_block_idx),
num_tma_multicast_a, batch_idx);
tma::copy<BLOCK_M, BLOCK_K, 0>(&tensor_map_sfa, &full_barrier,
smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx<kWithGroupOffsetA, sched::IndexType::SF_K>(shape_k_scales, 1, k_block_idx),
num_tma_multicast_a);
// Issue TMA B
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode, __nv_fp8_e4m3, kIsBatchedMM>(&tensor_map_b, &full_barrier,
smem_b[stage_idx], k_idx, scheduler.get_global_idx<true>(shape_n, BLOCK_N, n_block_idx, m_block_idx),
num_tma_multicast_b, batch_idx);
full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE);
}
}
// To safely deconstruct distributed shared barriers, we need another round of empty waits
if constexpr (kNumTMAMulticast > 1) {
for (uint32_t i = 0; i < kNumStages; advance_pipeline(i))
empty_barriers[stage_idx]->wait(phase ^ 1);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0);
const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8;
auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1);
auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1);
const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0);
const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0);
// Persistently schedule over blocks
while (scheduler.get_next_block(m_block_idx, n_block_idx)) {
// Decide the number of scales B to load
DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0);
uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters;
if constexpr (not kMustUseUniformedScaleB) {
num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8;
num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8;
}
uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2);
// Load B scales with math warp-groups
// NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks
if (threadIdx.x >= 32) {
auto previous_group_offset = scheduler.template get_global_idx<true, sched::IndexType::SF_K>(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx);
const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales;
const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1;
auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb;
#pragma unroll
for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32)
ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]);
}
cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0);
// Accumulation for WGMMA or CUDA promotion
constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2;
DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes");
float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0};
// Pick threads whose WGMMA results are to be stored in shared memory
DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`");
constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M);
const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32;
// Empty barrier arrival
auto empty_barrier_arrive = [&]() {
if constexpr (kNumTMAMulticast == 1) {
lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void();
} else {
auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster();
lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void();
}
};
// Skip useless computations
if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) {
// The compiler must know the dynamic variable `num_former_iters`'s real value
constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB;
constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8;
constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0;
// Dispatch `num_former_iters` and launch MMAs
dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) {
#pragma unroll 8
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16);
const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16);
// Read B scales
float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1;
// NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks
if constexpr (not kMustUseUniformedScaleB)
scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales);
// Wait TMA arrivals
full_barriers[stage_idx]->wait(phase);
// TODO: remove some useless computation for unaligned Ms
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
// Read A scales
// NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results
auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0;
auto scale_a_1 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0;
// Commit WGMMA instructions
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) {
a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16;
b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16;
WGMMA::wgmma(a_desc, b_desc, accum, k);
}
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Notify barrier arrival at the last warpgroup wave
if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1)
empty_barrier_arrive();
// Skip promotion for the unfilled parts
if (not do_wgmma_store)
continue;
// Promote with scales
// NOTES: making it as predicates is very important for performance, comparing to two loops
float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0;
float scale_0_1, scale_1_1;
if constexpr (not kMustUseUniformedScaleB)
scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant
const bool predicate = kMustUseUniformedScaleB or i < num_former_iters;
shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0];
shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1];
shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2];
shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3];
}
}
}
});
} else {
#pragma unroll
for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) {
full_barriers[stage_idx]->wait(phase);
empty_barrier_arrive();
}
}
// TMA checks
constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16);
constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes);
constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4;
DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom");
DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32,
"Unaligned TMA store or too many TMA store instructions");
DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N");
// Skip WGMMA store for the unfilled parts
if (not do_wgmma_store)
continue;
// Wait last TMA store to be finished
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
cute::tma_store_wait<0>();
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1);
// Write back to shared memory using STSM and issue TMA stores
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
#pragma unroll
for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) {
auto m_offset = local_idx * WAVE_BLOCK_M;
auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx;
#pragma unroll
for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) {
// Swizzle or padding into the correct address
uint8_t* smem_ptr = nullptr;
if constexpr (kSwizzleDMode > 0) {
// Calculate the swizzling atom offset and in-atom offset
constexpr uint32_t kNumBankGroupBytes = 16;
auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8);
// Calculate the index of the bank group to be written in the atom
auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes);
// Reshape the atom in another view and swizzle
// - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)`
// - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)`
constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8;
auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8);
auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8);
col ^= row % (kSwizzleDMode / 16);
// Add back into the base pointer
// NOTES: think twice before modifying this, as changes may affect the number of instructions
smem_ptr = reinterpret_cast<uint8_t*>(smem_d) + // Base pointer
warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset
m_offset * kSwizzleDMode + // Wave offset
atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants)
row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset
} else {
// No swizzling, just padding
smem_ptr = reinterpret_cast<uint8_t*>(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8);
}
// NOTES: only 16 lanes' addresses are used
ptx::SM90_U32x2_STSM_N<nv_bfloat162>::copy(
__float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}),
__float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}),
smem_ptr
);
}
}
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1);
// Use TMA store to write back to global memory
// TODO: compatible with FP32 output
constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked;
DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks");
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) {
auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N;
auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M;
auto n_idx = epilogue_type_t::apply_index_n<TMA_D_BLOCK_N>(n_block_idx * BLOCK_N + in_block_n_offset);
auto m_idx = scheduler.get_global_idx<kWithGroupOffsetD>(shape_m, BLOCK_M, m_block_idx);
if constexpr (kGemmType == GemmType::Batched) {
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr,
n_idx, m_idx, scheduler.current_group_idx);
} else {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx);
}
cute::tma_store_arrive();
}
__syncwarp();
}
}
#else
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
}; // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,330 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <cute/arch/mma_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
namespace deep_gemm {
// SM90 kernel computing FP8 MQA logits with a warp-specialized TMA producer / WGMMA consumer pipeline.
//
// For every (query row, KV row) pair handled by a block, the value written to `logits` is
//   (sum over heads of max(q_h . kv, 0) * weight_h) * kv_scale
// (see `transform`, `v_0` and `v_1` in the math branch below).
//
// Template parameters:
//   kNumHeads, kHeadDim  - head count and head dimension; they size the shared-memory tiles and the WGMMA shape.
//   kIsCompressedLogits  - if true, a value is stored only when its KV index lies in the row's
//                          `[seq_k_start, seq_k_end)` range, at a column relative to `seq_k_start`;
//                          if false, it is stored unconditionally at the absolute KV index.
//   BLOCK_Q, BLOCK_KV    - number of query rows / KV rows per tile.
//   kNumQStages, kNumKVStages - depth of the Q and KV shared-memory pipelines.
//   kNumSMs              - stride of the persistent scheduler (`block_q_idx += kNumSMs` per iteration).
//   kNumTMAThreads       - must be 128 (static assert); only the first warp of these issues TMA loads.
//   kNumMathThreads      - must be a multiple of 128 (static assert); these threads run WGMMA and the epilogue.
//   logits_dtype_t       - element type of `logits`; results are `static_cast` to it on store.
//
// Runtime parameters:
//   seq_len              - number of query rows; used to derive the number of Q blocks and to clamp row indices.
//   seq_len_kv           - upper clamp applied to the per-row KV start/end when scheduling.
//   max_seqlen_k         - not referenced anywhere in this kernel body.
//   stride_logits        - row stride of `logits`, in elements.
//   cu_seq_len_k_start / cu_seq_len_k_end - per-query-row KV start / end indices.
//   logits               - output buffer.
//   tensor_map_*         - TMA descriptors for Q, KV, the per-KV scales and the per-(row, head) weights.
template <uint32_t kNumHeads, uint32_t kHeadDim,
bool kIsCompressedLogits,
uint32_t BLOCK_Q, uint32_t BLOCK_KV,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t kNumSMs,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
typename logits_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv,
const uint32_t max_seqlen_k, const uint32_t stride_logits,
uint32_t* cu_seq_len_k_start,
uint32_t* cu_seq_len_k_end,
logits_dtype_t* logits,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
// TODO: consider TMA multicast
// For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]`
// Q should be loaded only once per block
const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q);
// Types
using WGMMA = typename mma::sm90::FP8MMASelector<BLOCK_Q * kNumHeads>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// Prefetch TMA descriptors
// NOTES: done by one elected lane of the first warp after the math threads
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
__syncwarp();
// Shared memory configs
// NOTES: weight may be unaligned
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Data on shared memory
// Layout in `smem_buffer`, in order: Q stages | KV stages | weight stages | KV scale stages | barriers
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i));
});
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer +
SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages +
SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
// TMA barriers
// NOTES: index `kNumKVStages` is one past the last KV scale stage, so the barriers start right after the scales
// Barrier order: full Q | empty Q | full KV | empty KV
auto barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); });
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); });
// Initialize barriers
// NOTES: "full" barriers expect a single arrival (the producer's `arrive_and_expect_tx`),
// "empty" barriers expect one arrival per math thread (every math thread calls `arrive()` below)
const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32;
if (is_tma_load_warp and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
}
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(kNumMathThreads);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
__syncthreads();
// Register reconfigurations
// NOTES: these are the counts passed to `warpgroup_reg_dealloc` (TMA) and `warpgroup_reg_alloc` (math) below
constexpr uint32_t kNumTMARegisters = 32;
constexpr uint32_t kNumMathRegisters = 112;
// Block scheduler
// NOTES: persistent scheduling, this block handles Q blocks `blockIdx.x, blockIdx.x + kNumSMs, ...`;
// `q_iter_idx` counts the Q blocks processed so far and drives the Q pipeline stage/phase
const auto sm_idx = blockIdx.x;
uint32_t block_q_idx = sm_idx, q_iter_idx = 0;
const auto get_next_block_q_idx = [&]() -> cute::tuple<uint32_t, uint32_t> {
return {block_q_idx + kNumSMs, q_iter_idx + 1};
};
// Per-row KV ranges of the current Q block, filled by `load_schedule` and reused by the compressed store
uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q];
// Returns `(Q stage, Q phase, kv_start, num_kv_blocks)` for the current `block_q_idx`
// The stage/phase are those of iteration `q_iter_idx + q_iter_offset`, while the KV range always
// belongs to the current block: the union of the per-row ranges, each clamped to `seq_len_kv`
const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple<uint32_t, uint32_t, uint32_t, uint32_t> {
uint32_t start = cute::numeric_limits<uint32_t>::max();
uint32_t end = cute::numeric_limits<uint32_t>::min();
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
// Rows past the end of the sequence reuse the last valid row's range
const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1);
seq_k_start[i] = cu_seq_len_k_start[q_idx];
seq_k_end[i] = cu_seq_len_k_end[q_idx];
start = min(start, min(seq_k_start[i], seq_len_kv));
end = max(end, min(seq_k_end[i], seq_len_kv));
}
// TMA alignment requirements for SF KV
// NOTES: rounds `start` down to a multiple of 4
start = start / 4 * 4;
// NOTE(review): `end - start` is unsigned, so this relies on `end >= start`, i.e. at least one row
// in the block having `seq_k_start <= seq_k_end` - confirm the caller guarantees it
return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage
((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase
start, math::ceil_div(end - start, BLOCK_KV)}; // Task info
};
// KV pipeline
// NOTES: `num_total_kv_blocks` is the number of KV blocks of all previous Q blocks; the producer and
// the math threads each advance their own copy by `num_kv_blocks` per Q block, keeping the
// stage/phase sequence identical on both sides
uint32_t num_total_kv_blocks = 0;
const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple<uint32_t, uint32_t> {
return {
(num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage
((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase
};
};
// Wait for primary kernel completion
cudaGridDependencySynchronize();
if (threadIdx.x >= kNumMathThreads) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
// Only the first warp remains
if (not is_tma_load_warp)
return;
// Prefetch
// NOTES: one Q stage holds both the Q tile and its weights; both copies signal the same full barrier,
// whose expected transaction size is the sum of the two stage sizes
const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) {
tma::copy<kHeadDim, BLOCK_Q * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads);
tma::copy<kNumHeads, BLOCK_Q, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
};
// The first Q block goes into stage 0 without waiting, as nothing has consumed it yet
if (cute::elect_one_sync() and block_q_idx < num_q_blocks)
issue_tma_q(0, block_q_idx);
// Only the first lane persistently schedules over blocks
if (cute::elect_one_sync()) {
while (block_q_idx < num_q_blocks) {
// NOTES: offset 1 selects the stage/phase of the *next* Q block (prefetch target),
// while `kv_start` and `num_kv_blocks` describe the current block
CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks);
// Wait Q consumer release
// NOTES: the parity is inverted relative to the consumers' full-barrier wait
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
// Issue TMA Q
if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks)
issue_tma_q(q_stage_idx, next_block_q_idx);
// Issue TMA KV
#pragma unroll
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
// Wait consumer release
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
// NOTES: the KV tile and its scales share one full barrier, same scheme as Q and weights
tma::copy<kHeadDim, BLOCK_KV, kHeadDim>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0);
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
num_total_kv_blocks += num_kv_blocks;
// Jump to the next block
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto& thread_idx = threadIdx.x % kNumMathThreads;
const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0);
const auto& warpgroup_idx = warp_idx / 4;
const auto& lane_idx = ptx::get_lane_idx();
// `weights[i]` caches the `kNumHeads / 4` weights of query row `i` that this lane multiplies in `transform`
float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4];
// Each warp covers 16 consecutive KV rows of the tile; within them a lane reads the scales of,
// and stores results for, rows `lane_idx / 4` and `lane_idx / 4 + 8`
const auto& warp_offset = warp_idx * 16;
const auto& v_0_offset = lane_idx / 4 + 0;
const auto& v_1_offset = lane_idx / 4 + 8;
while (block_q_idx < num_q_blocks) {
CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks);
// Wait TMA Q arrival
full_q_barriers[q_stage_idx]->wait(q_phase);
// Read weights
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
}
// Compute over KV blocks
#pragma unroll
for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) {
// Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Read per-KV scales
// NOTES: read before the stage is released through `empty_kv_barriers` below
float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset);
float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset);
// Issue WGMMA
// NOTES: `BLOCK_KV == kNumMathThreads / 2` means 64 KV rows per 128-thread math warp-group
DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size");
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
// Operand A is this warp-group's slice of the KV tile (offset by `warpgroup_idx * WGMMA::M` rows),
// operand B is the Q tile; `k` steps along the head dimension and is also passed to `wgmma`
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
auto desc_a = mma::sm90::make_smem_desc(
smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_b = mma::sm90::make_smem_desc(
smem_q[q_stage_idx] + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_wait<0>();
// Release KV empty
// NOTES: every math thread arrives, matching the `init(kNumMathThreads)` count
empty_kv_barriers[kv_stage_idx]->arrive();
// Reduce over the head dim and store
const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset;
// Each query row owns `kNumHeads / 2` consecutive accumulator registers per thread
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation");
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
#pragma unroll
for (uint32_t i = 0; i < BLOCK_Q; ++ i) {
auto shifted_accum = accum + i * kNumAccumPerReduce;
// ReLU on the raw accumulator, then scale by the cached per-head weight
const auto transform = [&](const uint32_t& j) {
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
};
// Intra-thread reduction
// NOTES: within each group of 4 accumulators, elements 0/1 feed `v_0` and elements 2/3 feed `v_1`
float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
#pragma unroll
for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
#pragma unroll
for (uint32_t k = 0; k < 4; k ++)
sum[k] += transform(j * 4 + k);
}
float v_0 = (sum[0] + sum[1]) * scale_kv_0;
float v_1 = (sum[2] + sum[3]) * scale_kv_1;
// Inter-thread reduction
// NOTES: XOR shuffles with offsets 1 and 2 sum over each group of 4 adjacent lanes,
// leaving all 4 lanes of a group with the same totals
#pragma unroll
for (uint32_t j = 0; j < 2; ++ j) {
const auto& offset = static_cast<int>(1u << j);
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
}
// Store into the global memory
// NOTES: all 4 lanes of a group issue the same store, as there is no lane predicate here
const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast<uint64_t>(stride_logits);
if constexpr (kIsCompressedLogits) {
// Keep only KV indices inside this row's `[seq_k_start, seq_k_end)`, stored relative to `seq_k_start`
if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i])
logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_0);
if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i])
logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast<logits_dtype_t>(v_1);
} else {
// NOTE(review): unconditional stores at the absolute KV index, for every row of the Q block
// (including rows beyond `seq_len - 1` in a partial last block) - presumably the
// host pads the `logits` allocation accordingly; confirm against the launcher
logits[q_offset + kv_offset + v_0_offset] = static_cast<logits_dtype_t>(v_0);
logits[q_offset + kv_offset + v_1_offset] = static_cast<logits_dtype_t>(v_1);
}
}
}
num_total_kv_blocks += num_kv_blocks;
// Release Q empty
empty_q_barriers[q_stage_idx]->arrive();
// Jump to the next block
CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx);
}
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,334 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/arch/copy_sm90_desc.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
#include <deep_gemm/scheduler/paged_mqa_logits.cuh>
namespace deep_gemm {
template <uint32_t kNextN, uint32_t kNumHeads,
uint32_t kHeadDim, uint32_t BLOCK_KV,
bool kIsContextLens2D, bool kIsVarlen,
uint32_t kNumQStages, uint32_t kNumKVStages,
uint32_t SPLIT_KV,
uint32_t kNumTMAThreads, uint32_t kNumMathThreads,
typename logits_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1)
void sm90_fp8_paged_mqa_logits(const uint32_t batch_size,
const uint32_t logits_stride, const uint32_t block_table_stride,
const uint32_t* context_lens, logits_dtype_t* logits,
const uint32_t* block_table, const uint32_t* indices,
const uint32_t* schedule_meta,
const __grid_constant__ cute::TmaDescriptor tensor_map_q,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv,
const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales,
const __grid_constant__ cute::TmaDescriptor tensor_map_weights) {
DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits");
// Types
using WGMMA = typename mma::sm90::FP8MMASelector<kNextN * kNumHeads>::type;
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// NOTES: use `__shfl_sync` to encourage NVCC to use unified registers
const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const auto warpgroup_idx = warp_idx / 4;
const auto lane_idx = ptx::get_lane_idx();
// Prefetch TMA descriptors
static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128;
DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads");
DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`");
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_q);
cute::prefetch_tma_descriptor(&tensor_map_kv);
cute::prefetch_tma_descriptor(&tensor_map_kv_scales);
cute::prefetch_tma_descriptor(&tensor_map_weights);
}
__syncwarp();
// Shared memory configs
static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8;
static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) +
math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3);
static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float);
static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment);
static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) +
math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment);
// Align to swizzling alignment bytes
extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[];
DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling");
// Q data and barriers on shared memory
auto smem_q = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i);
});
auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i);
});
auto q_barrier_ptr = reinterpret_cast<Barrier*>(smem_weights[kNumQStages]);
auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; });
auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); });
// Separate math warpgroups and tma load warps into KV groups
// Each math warpgroup corresponds to a tma load warp
const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0);
// Per group KV data and barriers on shared memory
const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx;
auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i);
});
auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i);
});
auto kv_barrier_ptr = reinterpret_cast<Barrier*>(smem_kv_scales[kNumKVStages]);
auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; });
auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; });
// Initialize barriers
if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) {
if (kv_group_idx == 0) {
#pragma unroll
for (uint32_t i = 0; i < kNumQStages; ++ i) {
full_q_barriers[i]->init(1);
empty_q_barriers[i]->init(kNumMathThreads);
}
}
if (kv_group_idx < kNumMathWarpGroups) {
#pragma unroll
for (uint32_t i = 0; i < kNumKVStages; ++ i) {
full_kv_barriers[i]->init(1);
empty_kv_barriers[i]->init(128);
}
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
__syncthreads();
// Register reconfigurations
constexpr uint32_t kNumTMARegisters = 64;
constexpr uint32_t kNumMathRegisters = 104;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Scheduler
auto scheduler = sched::PagedMQALogitsScheduler<kNextN, kIsContextLens2D, kIsVarlen, BLOCK_KV, kNumMathWarpGroups, 1>(
blockIdx.x, batch_size, context_lens, schedule_meta, indices);
DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV");
// Q and KV pipeline
const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase
};
const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple<uint32_t, uint32_t> {
return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase
};
uint32_t q_iter_idx = 0, kv_iter_idx = 0;
if (warp_idx >= kNumMathThreads / 32) {
// TMA warp-group for loading data
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
if (kv_group_idx >= kNumMathWarpGroups)
return;
const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) {
if (kv_group_idx == 0 and cute::elect_one_sync()) {
tma::copy<kHeadDim, kNextN * kNumHeads, kHeadDim>(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads);
tma::copy<kNextN * kNumHeads, 1, 0>(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN);
full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE);
}
};
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
uint32_t q_idx = batch_size, kv_idx, num_kv;
uint32_t next_q_idx, next_kv_idx, next_num_kv;
bool fetched_next_task;
// Prefetch the first Q
if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)))
issue_tma_q(0, next_q_idx), q_iter_idx = 1;
int kv_block_idx_ptr = 32;
uint32_t kv_block_idx_storage;
while (fetched_next_task) {
// Prefetch next Q when current Q changes
bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1));
q_idx = next_q_idx;
kv_idx = next_kv_idx;
num_kv = next_num_kv;
// Wait Q consumer release and issue TMA Q
if (prefetch_q) {
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1);
issue_tma_q(q_stage_idx, q_idx + 1);
}
// Read KV block index
// TODO: deal with `-1`?
if (kv_idx == 0 or kv_block_idx_ptr == 32) {
kv_block_idx_ptr = 0;
kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ?
block_table[q_idx * static_cast<uint64_t>(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0);
}
const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++);
// Wait KV consumer release
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1);
// Issue TMA KV
if (cute::elect_one_sync()) {
tma::copy<kHeadDim, BLOCK_KV, 0, __nv_fp8_e4m3, true>(&tensor_map_kv, full_kv_barriers[kv_stage_idx],
smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx);
tma::copy<BLOCK_KV, 1, 0>(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx],
smem_kv_scales[kv_stage_idx], 0, kv_block_idx);
full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE);
}
// Fetch next task
fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv);
}
} else {
// Math warp-groups for WGMMA
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4];
const auto sub_warp_offset = (warp_idx % 4) * 16;
const auto v_0_offset = lane_idx / 4 + 0;
const auto v_1_offset = lane_idx / 4 + 8;
// Initialize `q_idx` outside `[0, batch_size)` to indicate it was none
uint32_t q_idx = batch_size, kv_idx;
uint32_t next_q_idx, next_kv_idx, next_num_kv;
uint32_t q_stage_idx, q_phase;
while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) {
// Current Q changes
if (q_idx != next_q_idx) {
// Release Last Q empty
if (q_iter_idx > 0)
empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive();
// Wait TMA Q arrival
CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase);
full_q_barriers[q_stage_idx]->wait(q_phase);
// Read weights
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
#pragma unroll
for (uint32_t j = 0; j < kNumHeads / 4; ++ j)
weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2);
}
}
// Get current Q and KV index
q_idx = next_q_idx;
kv_idx = next_kv_idx;
// Calculate KV offset in advance
auto kv_offset = q_idx * kNextN * static_cast<uint64_t>(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset);
// Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]`
// Wait TMA KV arrival
CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase);
full_kv_barriers[kv_stage_idx]->wait(kv_phase);
// Issue WGMMA
DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size");
DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim");
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
#pragma unroll
for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) {
auto desc_a = mma::sm90::make_smem_desc(
smem_kv[kv_stage_idx] + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
auto desc_b = mma::sm90::make_smem_desc(
smem_q[q_stage_idx] + k * WGMMA::K,
mma::sm90::to_swizzle_cute_type<kHeadDim>(), 0, kHeadDim * 8);
WGMMA::wgmma(desc_a, desc_b, accum, k);
}
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
// Read per-KV scales
float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset);
float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset);
// Wait WGMMA
ptx::warpgroup_wait<0>();
// Release KV empty
empty_kv_barriers[kv_stage_idx]->arrive();
// Reduce over the head dim and store
static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2;
DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation");
DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation");
DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head");
#pragma unroll
for (uint32_t i = 0; i < kNextN; ++ i) {
auto shifted_accum = accum + i * kNumAccumPerReduce;
const auto transform = [&](const uint32_t& j) {
return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)];
};
// Intra-thread reduction
float sum[4] = {transform(0), transform(1), transform(2), transform(3)};
#pragma unroll
for (uint32_t j = 1; j < kNumHeads / 8; ++ j) {
#pragma unroll
for (uint32_t k = 0; k < 4; k ++)
sum[k] += transform(j * 4 + k);
}
float v_0 = (sum[0] + sum[1]) * scale_kv_0;
float v_1 = (sum[2] + sum[3]) * scale_kv_1;
// Inter-thread reduction
#pragma unroll
for (uint32_t j = 0; j < 2; ++ j) {
const auto offset = static_cast<int>(1u << j);
v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset);
v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset);
}
// Store into the global memory
// NOTES: we have redundant writes here, consider more carefully
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_0_offset] = static_cast<logits_dtype_t>(v_0);
logits[kv_offset + i * static_cast<uint64_t>(logits_stride) + v_1_offset] = static_cast<logits_dtype_t>(v_1);
}
}
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,294 @@
#pragma once
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunknown-attributes"
#include <cutlass/arch/barrier.h>
#include <cutlass/arch/reg_reconfig.h>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/common/tma_copy.cuh>
#include <deep_gemm/common/types.cuh>
#include <deep_gemm/mma/sm90.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
#include <deep_gemm/ptx/wgmma.cuh>
namespace deep_gemm {
// Map a logical bank-group `offset` of the row owned by `lane_idx` to its position inside a swizzled row.
// A bank group is `kSwizzleBase` bytes; a swizzled row holds `kSwizzleMode / kSwizzleBase` of them and
// 128 bytes hold `128 / kSwizzleBase` of them. The returned index is always in
// `[0, kSwizzleMode / kSwizzleBase)`.
template <uint32_t kSwizzleMode, uint32_t kSwizzleBase = 16>
CUTLASS_DEVICE
uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) {
    constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase;
    constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase;

    uint32_t row_idx, col_idx;
    if constexpr (kGroupsInSwizzleRange == kNumBankGroups) {
        // One swizzled row spans exactly 128 bytes: the lane index is the row, no linearization needed
        row_idx = offset / kNumBankGroups + lane_idx;
        col_idx = offset;
    } else {
        // Narrower swizzle: linearize first, then split into a 128-byte row and the column within it
        const uint32_t linear_idx = offset + lane_idx * kGroupsInSwizzleRange;
        row_idx = linear_idx / kNumBankGroups;
        col_idx = linear_idx % kNumBankGroups;
    }

    // XOR the column with the row (modulo the swizzle range) and fold back into the swizzle range
    col_idx ^= row_idx % kGroupsInSwizzleRange;
    return (row_idx * kNumBankGroups + col_idx) % kGroupsInSwizzleRange;
}
// Split-K GEMM kernel (TF32 WGMMA) that additionally emits per-row sums of squares of A.
//
// Work decomposition:
//   - `blockIdx.x / kNumSplits` selects the M block, `blockIdx.x % kNumSplits` selects the K split;
//   - warps `[0, kNumMathThreads / 32)` read the BF16 A tile from shared memory, widen it to FP32 registers,
//     accumulate `a * a` per row and issue WGMMA with the FP32 B tile as the shared-memory operand;
//   - one elected lane of warp `kNumMathThreads / 32` issues the TMA loads of A and B.
//
// Template parameters:
//   SHAPE_N                      not referenced in the kernel body
//   SHAPE_K                      total K extent, used to derive the number of K blocks
//   BLOCK_M / BLOCK_N / BLOCK_K  tile sizes (statically asserted: BLOCK_M == 64, BLOCK_K == 64)
//   kNumSplits                   number of K splits; splits are written to separate output slices
//   kSwizzleCDMode               byte width of one output row in shared memory (`== BLOCK_N * sizeof(float)`)
//   kNumStages                   depth of the A/B shared memory pipeline
//   kNumMathThreads              number of math threads (statically asserted to be 128)
//   kNumTMAThreads               number of threads reserved for the TMA warp (only used in the launch bounds)
//
// Runtime parameters:
//   shape_m                      number of valid rows; bounds the `sqr_sum` writes and strides the per-split slices
//   tensor_map_a/b/d             TMA descriptors for the A loads, B loads and D stores
//   sqr_sum                      output; split `s` writes row `r` to `sqr_sum[s * shape_m + r]`
template <uint32_t SHAPE_N, uint32_t SHAPE_K,
uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K,
uint32_t kNumSplits,
uint32_t kSwizzleCDMode,
uint32_t kNumStages,
uint32_t kNumMathThreads, uint32_t kNumTMAThreads>
CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1)
sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m,
const __grid_constant__ cute::TmaDescriptor tensor_map_a,
const __grid_constant__ cute::TmaDescriptor tensor_map_b,
const __grid_constant__ cute::TmaDescriptor tensor_map_d,
float* sqr_sum) {
#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__)
using Barrier = cutlass::arch::ClusterTransactionBarrier;
// kSwizzleAMode and kSwizzleBMode must be 128 for now
constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128);
constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128);
DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K");
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode");
DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N");
// NOTES: the empty barriers below are initialized with a literal arrival count of 128, which relies on this
DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads");
// Utils
const auto warp_idx = cutlass::canonical_warp_idx_sync();
const auto lane_idx = ptx::get_lane_idx();
// Align to 1024 bytes for swizzle-128B
extern __shared__ __align__(1024) uint8_t smem_buffer[];
// Shared memory sizes
constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode;
constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16);
constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float);
// NOTES: the condition checks the size of the D region (which precedes A/B), although the message mentions A/B
DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes");
// A single elected lane of warp 0 prefetches all three TMA descriptors
if (warp_idx == 0 and cute::elect_one_sync()) {
cute::prefetch_tma_descriptor(&tensor_map_a);
cute::prefetch_tma_descriptor(&tensor_map_b);
cute::prefetch_tma_descriptor(&tensor_map_d);
}
// Data on shared memory (layout as ordered below): D, then all A stages, then all B stages, then barriers
// Fill D/A/B pointers
auto smem_cd = reinterpret_cast<float*>(smem_buffer);
auto smem_a = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<nv_bfloat16*>(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE));
});
auto smem_b = utils::PatternVisitor([&](const uint32_t& i) {
return reinterpret_cast<float*>(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE));
});
// Fill barriers: `kNumStages` full barriers followed by `kNumStages` empty barriers
auto barrier_start_ptr = reinterpret_cast<Barrier*>(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE));
auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); });
auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); });
// Initialize barriers
if (warp_idx == 1 and cute::elect_one_sync()) {
#pragma unroll
for (uint32_t i = 0; i < kNumStages; ++ i) {
// Full: one arrival (the TMA-issuing lane); empty: one arrival per math thread
full_barriers[i]->init(1);
empty_barriers[i]->init(128);
}
// Make initialized barrier visible in async proxy
cutlass::arch::fence_barrier_init();
}
__syncthreads();
// Split the K blocks over `kNumSplits`; the first `kRemainKBlocks` splits take one extra block
constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K);
constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits;
constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits;
const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0);
const uint32_t m_block_idx = block_idx / kNumSplits;
const uint32_t k_split_idx = block_idx % kNumSplits;
const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K;
// Offset of this split's slice inside `sqr_sum`
const uint32_t m_offset = shape_m * k_split_idx;
// Number of K blocks (pipeline iterations) handled by this split
const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks);
constexpr uint32_t kNumTMARegisters = 40;
constexpr uint32_t kNumMathRegisters = 256;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// TMA load warp
if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) {
cutlass::arch::warpgroup_reg_dealloc<kNumTMARegisters>();
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait consumer release
const auto stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
// Compute offsets
uint32_t m_idx = m_block_idx * BLOCK_M;
uint32_t k_idx = k_offset + s * BLOCK_K;
// Issue TMAs
tma::copy<BLOCK_K, BLOCK_M, kSwizzleAMode>(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx);
tma::copy<BLOCK_K, BLOCK_N, kSwizzleBMode>(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0);
// Arrive at full barriers
constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE;
full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes);
}
// Drain: keep waiting on the empty barriers for `kNumStages` more iterations before this lane exits
for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) {
const auto stage_idx = s % kNumStages;
empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1);
}
} else if (warp_idx < kNumMathThreads / 32) {
// Math warps
cutlass::arch::warpgroup_reg_alloc<kNumMathRegisters>();
DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M");
DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K");
constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4;
// NOTES: `WGMMA_M` and `WGMMA_K` are not referenced below, the code uses `WGMMA::M` and `WGMMA::K`
constexpr uint32_t WGMMA_M = 64;
constexpr uint32_t WGMMA_N = BLOCK_N;
constexpr uint32_t WGMMA_K = 8;
using WGMMA = typename mma::sm90::TF32MMASelector<WGMMA_N, true>::type;
float accum[WGMMA::kNumAccum] = {0};
// Each shared memory read of A below addresses one 16-byte bank group (8 BF16 elements)
constexpr uint32_t kNumBankGroupBytes = 16;
constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16);
constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup;
// Per-thread partial sums of squares: `_0` for row `row`, `_1` for row `row + 8`
float sqr_sum_acc_0 = 0;
float sqr_sum_acc_1 = 0;
#pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2
for (uint32_t s = 0; s < num_total_stages; ++ s) {
// Wait TMA arrival
const auto& stage_idx = s % kNumStages;
full_barriers[stage_idx]->wait((s / kNumStages) & 1);
constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128;
constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K;
// FP32 copy of this thread's A fragments for the whole K block
float a[kNumRegPerWgmma * kNumWgmmaPerBlockK];
// Assume swizzle A mode is 128
DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode");
// Load BF16 A fragment from shared memory into registers, and transpose to FP32
uint32_t row = warp_idx * 16 + lane_idx / 4;
#pragma unroll
for (uint32_t i = 0; i < kNumLoads; ++ i) {
// Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a
// The bank group index is XOR-ed with the row to undo the swizzle applied by the TMA load
uint32_t bank_group_idx = (row ^ i) % 8;
nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup;
uint32_t elem_offset = lane_idx % 4;
// Interleave as (upper, lower) pairs so that each pair converts into one `float2`
nv_bfloat16 a_bf16[kNumRegPerWgmma];
a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset];
a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4];
a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset];
a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4];
auto a_bf16x2_ptr = reinterpret_cast<nv_bfloat162*>(a_bf16);
auto a_float2_ptr = reinterpret_cast<float2*>(a);
float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]);
float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]);
a_float2_ptr[i * 2 + 0] = a_float2_0;
a_float2_ptr[i * 2 + 1] = a_float2_1;
// `.x` components come from the upper row, `.y` components from the lower row
sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x;
sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y;
}
// Wait for the WGMMAs of the previous iteration, then release the stage they were reading
ptx::warpgroup_wait<0>();
if (s > 0)
empty_barriers[(s - 1) % kNumStages]->arrive();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
ptx::warpgroup_arrive();
constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float);
constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K;
DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K");
// Outer loop: 128-byte swizzle ranges of B along K; inner loop: WGMMA K steps inside one range
#pragma unroll
for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) {
#pragma unroll
for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) {
auto b_desc = mma::sm90::make_smem_desc(
smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1);
WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1);
}
}
ptx::warpgroup_commit_batch();
#pragma unroll
for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i)
ptx::warpgroup_fence_operand(accum[i]);
}
// Combine the partial sums of squares (presumably across the 4 lanes that share a row, TODO confirm
// against `math::warp_reduce_sum`), then let one lane per 4-lane group write both rows
const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0);
const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1);
const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4);
if (lane_idx % 4 == 0) {
if (m_idx < shape_m)
sqr_sum[m_offset + m_idx] = reduced_sum_0;
if (m_idx + 8 < shape_m)
sqr_sum[m_offset + m_idx + 8] = reduced_sum_1;
}
// Wait for the last WGMMA batch and release the final stage
// NOTE(review): `num_total_stages - 1` wraps around if a split has zero K blocks, confirm that
// launch configurations always guarantee `kNumKBlocks >= kNumSplits`
ptx::warpgroup_wait<0>();
empty_barriers[(num_total_stages-1) % kNumStages]->arrive();
// Write accum to shared memory
// Every 2 threads (one pair) will write to the same bank group (16 bytes).
// Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d
uint32_t is_odd_pair = lane_idx / 2 % 2;
// Four threads per group; write the data to the same row.
uint32_t row_idx = lane_idx / 4;
// Even/odd index pairs write to the same column, we need to reorder idx:
// group even pair indices consecutively, and likewise for odd ones.
uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx;
auto shifted_smem_ptr = reinterpret_cast<uint8_t*>(smem_cd) +
(warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows
lane_idx % 2 * 8; // One thread of a pair writes 8 bytes
#pragma unroll
for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) {
// Get the swizzled bank group index (16 bytes per group)
uint32_t bank_group_idx = get_swizzled_bank_group_idx<kSwizzleCDMode>(i + is_odd_pair, reordered_pair_idx);
auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group
// 0/1 write to the same row, 2/3 write to another row
auto values = reinterpret_cast<uint32_t*>(accum + i * 2);
ptx::st_shared(smem_ptr, values[0], values[1]);
ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]);
}
// Make the shared memory writes visible to the TMA store, then sync the 128 math threads on named barrier 1
cute::tma_store_fence();
cutlass::arch::NamedBarrier::sync(128, 1);
// Issue TMA stores
if (warp_idx == 0 and cute::elect_one_sync()) {
if constexpr (kNumSplits == 1) {
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M);
} else {
// With split-K, the split index is passed as the third store coordinate
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx);
}
cute::tma_store_arrive();
}
}
#else
// Compiled for an architecture below SM90: trap on the first thread instead of running
if (blockIdx.x == 0 and threadIdx.x == 0)
DG_DEVICE_ASSERT(false and "This kernel only support sm_90a");
#endif
}
} // namespace deep_gemm
#pragma clang diagnostic pop

View File

@@ -0,0 +1,74 @@
#pragma once
#include <cutlass/arch/barrier.h>
#include <cute/arch/cluster_sm90.hpp>
#include <deep_gemm/common/cute_tie.cuh>
#include <deep_gemm/common/math.cuh>
namespace deep_gemm {
// Fill every logits entry outside the valid KV range of its row with `-inf`.
//
// For row `i`, the valid range is `[ks, ke)` with
//   `ks = cu_seq_len_k_start[i / kNextN]` (or 0 if `cu_seq_len_k_start` is null) and
//   `ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1`.
// Rows are split evenly over the thread blocks and then over the warps of each block.
// The aligned parts are written with bulk shared-to-global copies from a `-inf`-filled
// shared memory buffer; the unaligned edges next to `ks` / `ke` are written with plain stores.
//
// Parameters:
//   seq_len              number of logits rows to process
//   seq_len_kv           upper bound of the `BLOCK_KV`-stepped column loop
//   stride_logits        row stride of `logits` in elements; also clamps the end of each column block
//   cu_seq_len_k_start   per-`(i / kNextN)` range starts, may be null
//   cu_seq_len_k_end     per-`(i / kNextN)` range ends
//   logits               buffer written in place
template <uint32_t kNextN, uint32_t BLOCK_KV, uint32_t kNumWarps, typename logits_dtype_t>
CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1)
void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits,
const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) {
const uint32_t num_sms = gridDim.x;
const uint32_t sm_idx = blockIdx.x;
const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
// Number of elements in 16 bytes; range boundaries are rounded to this granularity for the bulk copies
constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t);
const logits_dtype_t neg_inf = -cute::numeric_limits<logits_dtype_t>::infinity();
// Allocate filled `-inf` shared memory
extern __shared__ __align__(1024) logits_dtype_t smem_buffer[];
#pragma unroll
for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32)
smem_buffer[i] = neg_inf;
// Fence and block-sync so that the filled buffer can be used as the source of the bulk copies below
cute::tma_store_fence();
__syncthreads();
// Assign sequence to each warp
// `assign_task` splits `total` items into `num` near-equal contiguous parts and returns `{start, length}`
// of part `idx`; the first `total % num` parts get one extra item
const auto assign_task = [&](const uint32_t& num, const uint32_t& idx,
const uint32_t& start, const uint32_t& total) -> cute::tuple<uint32_t, uint32_t> {
const auto per = total / num, rem = total % num;
return {start + idx * per + cute::min(idx, rem), per + (idx < rem)};
};
CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len);
CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len);
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Aligned part: one elected lane per warp issues the bulk copies
if (cute::elect_one_sync()) {
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
// Round `ks` down and `ke` up to the 16-byte granularity
const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) {
const auto right = cute::min(left + BLOCK_KV, static_cast<uint32_t>(stride_logits));
if (right <= ks or ke <= left) {
// The block `[left, right)` lies entirely outside `[ks, ke)`: fill all of it
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t));
} else {
// The block overlaps `[ks, ke)`: fill only the aligned parts before `aligned_ks` and after `aligned_ke`
if (left < aligned_ks)
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t));
if (aligned_ke < right)
cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t));
}
}
}
}
__syncwarp();
// Unaligned edges: `[aligned_ks, ks)` and `[ke, aligned_ke)` are written with plain stores
// NOTES: the loops below are not partitioned by lane, so every lane of the warp writes the same values
for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) {
const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN];
const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1;
const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment;
for (uint32_t j = aligned_ks; j < ks; ++ j)
logits[i * stride_logits + j] = neg_inf;
for (uint32_t j = ke; j < aligned_ke; ++ j)
logits[i * stride_logits + j] = neg_inf;
}
}
}

View File

@@ -0,0 +1,189 @@
#pragma once
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/utils.cuh>
#include <deep_gemm/ptx/ld_st.cuh>
#include <deep_gemm/ptx/utils.cuh>
namespace deep_gemm {
// Transpose FP32 values from `[mn, SF_K]` (K contiguous) into `[SF_K, tma_aligned_mn]` (MN contiguous),
// where `tma_aligned_mn` is `mn` rounded up to a multiple of 16 bytes worth of floats.
//   - `blockIdx.x` selects a tile of `BLOCK_MN` rows, `blockIdx.y` selects an independent `[mn, SF_K]` matrix;
//   - rows are staged in shared memory with a row pitch of `PADDED_SF_K` (`SF_K` bumped to the next odd
//     value when it is even), then read back column by column.
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K,
          uint32_t PADDED_SF_K = SF_K + (1 - (SF_K % 2))>
CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) {
    // Vector type used to read one input row (or a part of it) at once
    using in_vec_t = typename utils::Vectorized<sizeof(float) * SF_K>::vec_t;
    static constexpr uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float);
    static constexpr uint32_t SF_VEC_K = SF_K / kNumElemsPerVec;

    extern __shared__ float smem_buffer[];

    // Tile extent (the last tile may be partial) and the padded output row length
    constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(float));
    const auto num_rows_in_block = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto tma_aligned_mn = math::align<uint32_t>(mn, kNumTMAAlignedElems);

    // Move to the matrix selected by `blockIdx.y`, then to the tile selected by `blockIdx.x`
    sf += static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
    out += static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * SF_K;
    const auto tile_vec_ptr = reinterpret_cast<const in_vec_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));

    // Wait for primary kernel completion
    cudaGridDependencySynchronize();

    // Stage the tile into shared memory with vectorized global reads
    const uint32_t num_vecs = num_rows_in_block * SF_VEC_K;
    for (uint32_t vec_idx = threadIdx.x; vec_idx < num_vecs; vec_idx += kNumThreads) {
        auto loaded = tile_vec_ptr[vec_idx];
        const auto elems = reinterpret_cast<float*>(&loaded);
        const uint32_t smem_row = vec_idx / SF_VEC_K;
        const uint32_t smem_col = (vec_idx % SF_VEC_K) * kNumElemsPerVec;
        #pragma unroll
        for (uint32_t e = 0; e < kNumElemsPerVec; ++ e)
            smem_buffer[smem_row * PADDED_SF_K + smem_col + e] = elems[e];
    }
    __syncthreads();

    // Write back transposed: consecutive threads write consecutive MN indices of the same K index
    #pragma unroll
    for (uint32_t elem_idx = threadIdx.x; elem_idx < num_rows_in_block * SF_K; elem_idx += kNumThreads) {
        const auto k_idx = elem_idx / num_rows_in_block;
        const auto local_mn_idx = elem_idx % num_rows_in_block;
        const auto global_mn_idx = blockIdx.x * BLOCK_MN + local_mn_idx;
        out[k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + local_mn_idx * PADDED_SF_K + k_idx);
    }
}
// NOTES: the two kernels below always pack the K dimension
// Transpose FP32 scaling factors from `[mn, SF_K]` into MN-contiguous rows and pack every 4 consecutive
// K entries into one `uint32_t` of 8-bit exponents (entry `j` of a pack lands in byte `j`).
//   - `blockIdx.x` selects a tile of `BLOCK_MN` rows, `blockIdx.y` selects an independent matrix;
//   - the raw bits are only shifted, never masked, so the result is a clean 4 x 8-bit pack only when the
//     sign and mantissa bits of every input are zero;
//   - K entries beyond `SF_K` inside the last pack are treated as zero.
template <uint32_t kNumThreads, uint32_t BLOCK_MN, uint32_t SF_K>
CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) {
    extern __shared__ uint32_t smem_buffer[];

    // Number of packed K entries, tile extent and the padded output row length
    constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u);
    constexpr auto kNumTMAAlignedElems = static_cast<uint32_t>(16 / sizeof(int));
    const auto num_rows_in_block = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
    const auto tma_aligned_mn = math::align<uint64_t>(mn, kNumTMAAlignedElems);

    // Move to the matrix selected by `blockIdx.y`
    sf += static_cast<uint64_t>(blockIdx.y) * mn * SF_K;
    out += static_cast<uint64_t>(blockIdx.y) * tma_aligned_mn * kNumPackedSFK;

    // Wait for primary kernel completion
    cudaGridDependencySynchronize();

    // Stage the raw FP32 bits of the tile into shared memory, 16 bytes at a time
    DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size");
    const auto tile_bits = reinterpret_cast<uint32_t*>(sf + static_cast<uint64_t>(blockIdx.x) * (BLOCK_MN * SF_K));
    const auto num_values = num_rows_in_block * SF_K;
    const auto num_uint4 = num_values / 4;
    #pragma unroll
    for (uint32_t vec_idx = threadIdx.x; vec_idx < num_uint4; vec_idx += kNumThreads) {
        const uint4 quad = reinterpret_cast<const uint4*>(tile_bits)[vec_idx];
        ptx::st_shared(reinterpret_cast<uint4*>(smem_buffer) + vec_idx, quad.x, quad.y, quad.z, quad.w);
    }

    // The tail (fewer than 4 values) is copied one element per thread
    const auto tail_idx = num_uint4 * 4 + threadIdx.x;
    if (tail_idx < num_values)
        ptx::st_shared(smem_buffer + tail_idx, tile_bits[tail_idx]);
    __syncthreads();

    // Pack 4 K entries per output word; consecutive threads handle consecutive MN indices
    #pragma unroll
    for (uint32_t out_idx = threadIdx.x; out_idx < (kNumPackedSFK * BLOCK_MN); out_idx += kNumThreads) {
        const auto pack_idx = out_idx / BLOCK_MN;
        const auto local_mn_idx = out_idx % BLOCK_MN;

        // Gather the 4 source words, out-of-range K entries contribute zero
        uint32_t bits[4];
        #pragma unroll
        for (uint32_t j = 0; j < 4; ++ j) {
            const auto k_idx = pack_idx * 4 + j;
            bits[j] = 0;
            if (k_idx < SF_K)
                bits[j] = ptx::ld_shared(smem_buffer + local_mn_idx * SF_K + k_idx);
        }

        // Move the exponent field (bits 23..30) of entry `j` into byte `j`
        const uint32_t packed = (bits[0] >> 23u) | (bits[1] >> 15u) | (bits[2] >> 7u) | (bits[3] << 1u);

        // Rows beyond `mn` (partial last tile) are not written
        const auto global_mn_idx = blockIdx.x * BLOCK_MN + local_mn_idx;
        if (global_mn_idx < mn)
            out[pack_idx * tma_aligned_mn + global_mn_idx] = packed;
    }
}
// Pack MN-contiguous FP32 scaling factors along K: every 4 consecutive K rows become one row of `uint32_t`,
// with row `j` of a pack placed in byte `j` of each word.
//
// Launch mapping: `blockIdx.x` selects a tile of `BLOCK_MN` columns, `blockIdx.y` selects a tile of
// `BLOCK_PACKED_SF_K` packed rows, and each warp of the block handles one packed row.
//
// Parameters:
//   sf            input, row `k` starts at `sf + k * mn`
//   out           output, packed row `p` starts at `out + p * mn`
//   ks            per-group sizes, only read when `kNumGroups > 1`; each is divided by `gran_k`
//   mn            length of one row in elements
//   sf_k          number of input rows; recomputed from `ks` when `kNumGroups > 1`
//   packed_sf_k   number of packed output rows
//   gran_k        divisor applied to each `ks` entry
//
// NOTES: the raw bits are only shifted, never masked, so the result is a clean 4 x 8-bit pack only when
// the sign and mantissa bits of every input are zero
template <uint32_t kNumGroups, uint32_t kNumThreads,
uint32_t BLOCK_MN, uint32_t BLOCK_PACKED_SF_K, bool kTransposed = true>
CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks,
const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k,
const uint32_t gran_k) {
// Always packing the K dimension
// NOTES: should also assert `mn % 4 == 0` at launch
DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)");
DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes");
DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes");
// Shapes and strides
// Tile extents, the last tile along either axis may be partial
const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN);
const auto in_block_mn_uint4 = in_block_mn / 4;
const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K);
// Shift into the right block along MN
sf += blockIdx.x * BLOCK_MN;
out += blockIdx.x * BLOCK_MN;
// Each warp is responsible for a packed row
const auto warp_idx = threadIdx.x / 32;
const auto lane_idx = ptx::get_lane_idx();
const auto packed_sf_k_idx = static_cast<uint64_t>(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx;
// Warps beyond the last packed row of a partial tile have nothing to do
if (warp_idx >= in_block_packed_sf_k)
return;
// Wait for primary kernel completion
cudaGridDependencySynchronize();
// Make an offset on the input
// `input_offset` counts the padding rows that precede this packed row in the output but do not exist in the input
uint32_t input_offset = 0;
if constexpr (kNumGroups > 1) {
// Load each group's size
// Lane `l` holds groups `4 * l .. 4 * l + 3`, so 32 lanes cover at most 128 groups
DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups");
uint32_t group_ks[4];
#pragma unroll
for (uint32_t i = 0; i < 4; ++ i) {
const auto group_idx = lane_idx * 4 + i;
group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0;
}
__syncwarp();
// Make the offset
// Walk the groups until the one containing this packed row; afterwards `sf_k` is the number of input
// rows up to and including that group
sf_k = 0;
uint32_t sum_packed_sf_k = 0;
#pragma unroll
for (uint32_t i = 0; i < kNumGroups; ++ i) {
// Group `i` is stored in slot `i % 4` of lane `i / 4`
const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4);
sf_k += sf_k_in_group;
sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u);
// `packed_sf_k_idx` only depends on the block and warp indices, so the whole warp breaks together
if (packed_sf_k_idx < sum_packed_sf_k)
break;
// Each fully passed group is padded up to a multiple of 4 rows in the packed output
if (const auto remainder = sf_k_in_group % 4; remainder > 0)
input_offset += 4 - remainder;
}
}
// Each lane handles 4 consecutive MN elements (one `uint4`) per iteration
for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) {
// Load
uint4 values[4];
#pragma unroll
for (uint32_t j = 0; j < 4; ++ j) {
// Rows at or beyond `sf_k` contribute zero
values[j] = make_uint4(0, 0, 0, 0);
if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k)
values[j] = reinterpret_cast<const uint4*>(sf + sf_k_idx * mn)[mn_idx];
}
// Pack and store
// Move the exponent field (bits 23..30) of row `j` into byte `j` of each word
uint4 packed;
packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u);
packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u);
packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u);
packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u);
reinterpret_cast<uint4*>(out + packed_sf_k_idx * mn)[mn_idx] = packed;
}
}
} // namespace deep_gemm

View File

@@ -0,0 +1,260 @@
#pragma once
#include <cute/numeric/math.hpp>
#include <deep_gemm/common/math.cuh>
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::layout {
// Candidate BLOCK_M values and constants derived from them; keep all five in sync when editing the list
static constexpr int kNumCandidateBlockMs = 7;
static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192};
// Largest and smallest entries of `kCandidateBlockM`
static constexpr int kMaxCandidateBlockM = 192;
static constexpr int kMinCandidateBlockM = 8;
// Least common multiple of all entries of `kCandidateBlockM` (128 * 3)
static constexpr int kLCMCandidateBlockM = 384;
// Capacity (in tokens) of the token pool shared by all local experts.
// Sum of two terms, rounded up to a multiple of `kLCMCandidateBlockM` so that the capacity is divisible
// by every candidate BLOCK_M:
//   - routed tokens: `num_ranks * num_max_tokens_per_rank` tokens, each reaching at most
//     `min(num_topk, num_experts_per_rank)` local experts;
//   - alignment padding: up to `kMaxCandidateBlockM - 1` tokens per local expert.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk,
                                                        T num_experts_per_rank) {
    const auto max_local_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank);
    const auto max_routed_tokens = num_ranks * num_max_tokens_per_rank * max_local_experts_per_token;
    const auto max_padding_tokens = num_experts_per_rank * (static_cast<T>(kMaxCandidateBlockM) - 1);
    return math::constexpr_align(max_routed_tokens + max_padding_tokens, static_cast<T>(kLCMCandidateBlockM));
}
// Capacity (in tokens) of the scaling-factor region shared by all experts:
// the pool holds `num_max_pool_tokens / block_m` blocks, and each block occupies `block_m` rounded up
// to a multiple of 128 tokens.
template <typename T>
CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) {
    const auto num_pool_blocks = num_max_pool_tokens / block_m;
    const auto num_sf_tokens_per_block = math::constexpr_align(block_m, static_cast<T>(128));
    return num_pool_blocks * num_sf_tokens_per_block;
}
// Per-token source metadata for combine write-back.
// Three `uint32_t` fields (12 bytes); `Workspace` reserves one entry per pool token.
struct TokenSrcMetadata {
// Presumably the rank the token came from; verify against the dispatch code
uint32_t rank_idx;
// Presumably the token index on that source rank; verify against the dispatch code
uint32_t token_idx;
// Presumably the top-k slot of the token that selected the local expert; verify against the dispatch code
uint32_t topk_idx;
};
// Layout of the workspace that starts at `base`. Regions are placed back to back in this order:
//   1. barrier signals                (`kNumBarrierSignalBytes` bytes)
//   2. expert send counts             (`num_experts` x `uint64_t`)
//   3. expert recv counts             (`num_ranks * num_experts_per_rank` x `uint64_t`)
//   4. expert recv count sums         (`num_experts_per_rank` x `uint64_t`)
//   5. L1 arrival counts              (`num_max_pool_blocks` x `uint32_t`, padded to an even count)
//   6. L2 block arrival masks         (`num_max_pool_blocks` x `uint64_t`)
//   7. dispatch source token-topk     (`num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert` x `uint32_t`)
//   8. combine source metadata        (`num_max_pool_tokens` x `TokenSrcMetadata`)
// `get_num_bytes()` and the pointer getters must stay consistent with this order.
struct Workspace {
    void* base;
    uint32_t num_ranks, num_experts;
    uint32_t num_experts_per_rank;
    uint32_t num_max_tokens_per_rank;
    uint32_t num_max_recv_tokens_per_expert;
    // Pool capacity: all local experts share a contiguous token pool
    uint32_t num_max_pool_tokens;
    uint32_t num_max_pool_blocks;

    // For both grid barrier and NVLink barrier
    static constexpr uint64_t kNumBarrierSignalBytes = 32;

    // Derive all capacities from the communication shape.
    // `num_experts` is divided by `num_ranks` with integer division, so callers are expected to pass
    // a multiple of `num_ranks`.
    CUTLASS_HOST_DEVICE
    Workspace(void* base,
              const uint32_t& num_ranks,
              const uint32_t& num_experts,
              const uint32_t& num_max_tokens_per_rank,
              const uint32_t& num_topk):
        base(base),
        num_ranks(num_ranks), num_experts(num_experts),
        num_max_tokens_per_rank(num_max_tokens_per_rank) {
        num_experts_per_rank = num_experts / num_ranks;
        num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank;
        num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank);
        // Block count for the smallest candidate BLOCK_M, i.e. the largest possible number of blocks
        num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM;
    }

    // Total workspace size in bytes, rounded up to a multiple of 16
    CUTLASS_HOST_DEVICE
    uint64_t get_num_bytes() const {
        uint64_t num_bytes = 0;
        // Barrier
        num_bytes += kNumBarrierSignalBytes;
        // Expert send/recv count
        num_bytes += num_experts * sizeof(uint64_t) * 2;
        // Expert recv count sum
        num_bytes += num_experts_per_rank * sizeof(uint64_t);
        // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask)
        num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t);
        // L2 block arrival mask
        num_bytes += num_max_pool_blocks * sizeof(uint64_t);
        // Dispatch pulling source token-topk
        // NOTES: widen before multiplying, the product of the three 32-bit factors may exceed `UINT32_MAX`
        num_bytes += static_cast<uint64_t>(num_experts_per_rank) * num_ranks * num_max_recv_tokens_per_expert * sizeof(int);
        // Combine push source indices
        num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata);
        // Align to TMA descriptor requirements
        num_bytes = math::align<uint64_t>(num_bytes, 16);
        return num_bytes;
    }

    // First byte after the workspace
    CUTLASS_HOST_DEVICE
    void* get_end_ptr() const {
        return math::advance_ptr(base, get_num_bytes());
    }

    // Grid sync counters: `kNumBarrierSignalBytes` layout
    // [ 0..15]: 4 x `uint32_t` grid sync counters
    // [16..19]: `uint32_t` NVLink barrier counter
    // [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1)
    // [28..31]: unused
    static constexpr uint32_t kNumMaxGridSyncCounters = 4;

    template <uint32_t kIndex = 0>
    CUTLASS_DEVICE
    uint32_t* get_grid_sync_count_ptr() const {
        DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds");
        return static_cast<uint32_t*>(base) + kIndex;
    }

    CUTLASS_DEVICE
    uint32_t* get_nvl_barrier_counter_ptr() const {
        return static_cast<uint32_t*>(base) + kNumMaxGridSyncCounters;
    }

    CUTLASS_DEVICE
    int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const {
        // NOTES: the signal is signed, as we may minus
        return math::advance_ptr<int>(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int));
    }

    CUTLASS_DEVICE
    uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const {
        return math::advance_ptr<uint64_t>(base, kNumBarrierSignalBytes) + expert_idx;
    }

    // Recv counts follow the `num_experts` send counts and are indexed as `[rank_idx][expert_idx]`
    CUTLASS_DEVICE
    uint64_t* get_expert_recv_count_ptr(
        const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const {
        return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx;
    }

    // Recv count sums follow the send and recv counts (`2 * num_experts` entries)
    CUTLASS_DEVICE
    uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const {
        return get_expert_send_count_ptr(num_experts * 2) + expert_idx;
    }

    CUTLASS_DEVICE
    uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const {
        const auto region_start = get_expert_recv_count_sum_ptr(num_experts_per_rank);
        return reinterpret_cast<uint32_t*>(region_start) + pool_block_idx;
    }

    CUTLASS_DEVICE
    uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const {
        // Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned
        const auto region_start = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u));
        return reinterpret_cast<uint64_t*>(region_start) + pool_block_idx;
    }

    // For dispatch pulling, indexed as `[expert_idx][rank_idx][token_idx]`
    CUTLASS_DEVICE
    uint32_t* get_src_token_topk_idx_ptr(
        const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const {
        const auto region_start = get_l2_arrival_mask_ptr(num_max_pool_blocks);
        // NOTES: compute the element offset in 64 bits, matching the region size in `get_num_bytes`
        const auto elem_offset =
            (static_cast<uint64_t>(expert_idx) * num_ranks + rank_idx) * num_max_recv_tokens_per_expert + token_idx;
        return reinterpret_cast<uint32_t*>(region_start) + elem_offset;
    }

    // For combine usages
    CUTLASS_DEVICE
    TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const {
        const auto region_start = reinterpret_cast<TokenSrcMetadata*>(get_src_token_topk_idx_ptr(num_experts_per_rank));
        return region_start + pool_token_idx;
    }
};
// Description of one fixed-size data item: its byte size, whether that size must be a multiple of
// 16 bytes, and an optional base pointer.
struct Data {
// Size of one item in bytes
uint32_t num_bytes;
// When true, the constructor asserts that `num_bytes` is a multiple of 16
bool require_tma_alignment;
// Start of the item, may be null for a pure layout description
void* base;
// Build a layout (and optionally bind it to `base`); asserts `num_bytes % 16 == 0` unless
// `require_tma_alignment` is false
CUTLASS_HOST_DEVICE
constexpr explicit Data(
const uint32_t& num_bytes,
const bool& require_tma_alignment = true,
void* base = nullptr) :
num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) {
DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment);
}
// Item size in bytes, converted to `dtype_t` (e.g. `uint64_t` to avoid 32-bit overflow in later products)
template <typename dtype_t = uint32_t>
CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const {
return static_cast<dtype_t>(num_bytes);
}
// Base pointer reinterpreted as `dtype_t*`
template <typename dtype_t = void>
CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
return static_cast<dtype_t*>(base);
}
// Rebind the item to another address
CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) {
base = ptr;
}
};
// A contiguous region made of `num_ranks` consecutive per-rank sections, each holding
// `num_max_tokens_per_rank` slots whose size is given by `data_layout`.
struct Buffer {
    Data data_layout;
    uint32_t num_ranks;
    uint32_t num_max_tokens_per_rank;
    void* base;

    // Stores the slot layout, the rank/token capacities and the (optional) start address
    CUTLASS_HOST_DEVICE
    Buffer(const Data& data_layout,
           const uint32_t& num_ranks,
           const uint32_t& max_num_tokens_per_rank,
           void* base = nullptr)
        : data_layout(data_layout),
          num_ranks(num_ranks),
          num_max_tokens_per_rank(max_num_tokens_per_rank),
          base(base) {}

    // Size in bytes of one rank's section (computed in 64 bits)
    CUTLASS_HOST_DEVICE
    uint64_t get_num_bytes_per_rank() const {
        const uint64_t num_bytes_per_token = data_layout.get_num_bytes<uint64_t>();
        return num_max_tokens_per_rank * num_bytes_per_token;
    }

    // Size in bytes of the whole region, i.e. all rank sections together
    CUTLASS_HOST_DEVICE
    uint64_t get_num_bytes() const {
        const uint64_t num_bytes_per_rank = get_num_bytes_per_rank();
        return num_bytes_per_rank * num_ranks;
    }

    // Start address viewed as `dtype_t*` (`void*` by default)
    template <typename dtype_t = void>
    CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const {
        return static_cast<dtype_t*>(this->base);
    }

    // Address `get_num_bytes()` bytes past the start of the region
    CUTLASS_HOST_DEVICE
    void* get_end_ptr() const {
        const uint64_t total_num_bytes = get_num_bytes();
        return math::advance_ptr(base, total_num_bytes);
    }

    // Single-rank view of the section belonging to `rank_idx`:
    // same slot layout and token capacity, `num_ranks` fixed to 1
    CUTLASS_HOST_DEVICE
    Buffer get_rank_buffer(const uint32_t& rank_idx) const {
        const uint64_t section_offset = get_num_bytes_per_rank() * rank_idx;
        return Buffer(
            data_layout,
            1, num_max_tokens_per_rank,
            math::advance_ptr(base, section_offset));
    }

    // Descriptor of the slot at `token_idx`, counted in slots from `base`.
    // Unless `global` is true, this asserts that the buffer is a single-rank view.
    // NOTE(review): this host/device function uses `DG_DEVICE_ASSERT`, whereas `Data`'s
    // constructor uses `DG_UNIFIED_ASSERT` - confirm this is intended.
    CUTLASS_HOST_DEVICE
    Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const {
        DG_DEVICE_ASSERT(num_ranks == 1 or global);
        const uint64_t slot_offset = data_layout.get_num_bytes<uint64_t>() * token_idx;
        return Data(
            data_layout.num_bytes,
            data_layout.require_tma_alignment,
            math::advance_ptr(base, slot_offset));
    }
};
} // namespace deep_gemm::layout

View File

@@ -0,0 +1,41 @@
#pragma once
#include <deep_gemm/common/exception.cuh>
namespace deep_gemm::layout {
// Upper bound on the number of ranks: it sizes `SymBuffer::offsets` and is the
// default value of `SymBuffer`'s `kNumRanks` template parameter
constexpr static uint32_t kNumMaxRanks = 72;
// Holds the local address (`base`) plus, for every rank slot, the signed distance from the
// local address to that rank's address, so a local pointer can be translated to another rank.
// Presumably the addresses are the per-rank bases of a symmetric allocation - verify against callers.
template <uint32_t kNumRanks = kNumMaxRanks>
struct SymBuffer {
    int64_t base;
    // Always `kNumMaxRanks` entries, independent of `kNumRanks`
    int64_t offsets[kNumMaxRanks];
    uint32_t rank_idx;

    DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks");

    SymBuffer() = default;

    // Builds the table from `c`, an indexable container with `size()` holding one address per rank.
    // `base` becomes `c[rank_idx]`; `offsets[i]` becomes `c[i] - base` for `i < c.size()`
    // and 0 for every remaining slot.
    template <typename Container>
    explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) {
        const auto num_entries = static_cast<uint32_t>(c.size());
        base = c[rank_idx];
        for (uint32_t peer = 0; peer < kNumMaxRanks; ++ peer) {
            if (peer < num_entries)
                offsets[peer] = c[peer] - base;
            else
                offsets[peer] = 0;
        }
    }

#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__)
    // The stored local address reinterpreted as `ptr_t` (`void*` by default)
    template <typename ptr_t = void*>
    CUTLASS_DEVICE ptr_t get_base_ptr() const {
        return reinterpret_cast<ptr_t>(base);
    }

    // Translates `ptr` by adding `offsets[dst_rank_idx]` to its address and
    // returns the result with the same pointer type
    template <typename ptr_t>
    CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const {
        int64_t translated_addr = reinterpret_cast<int64_t>(ptr) + offsets[dst_rank_idx];
        return *reinterpret_cast<ptr_t*>(&translated_addr);
    }
#endif
};
} // namespace deep_gemm::layout

Some files were not shown because too many files have changed in this diff Show More