diff --git a/third_party/DeepGEMM b/third_party/DeepGEMM deleted file mode 160000 index 714dd1a4..00000000 --- a/third_party/DeepGEMM +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 714dd1a4a980f7937a74343d19a8eba4fe321480 diff --git a/third_party/DeepGEMM/.github/workflows/_build.yml b/third_party/DeepGEMM/.github/workflows/_build.yml new file mode 100644 index 00000000..cff80136 --- /dev/null +++ b/third_party/DeepGEMM/.github/workflows/_build.yml @@ -0,0 +1,227 @@ +name: ~Build wheel template + +on: + workflow_call: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "The C++11 ABI to use for the build" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + use-local-version: + description: "Use local version" + required: false + type: boolean + default: false + +defaults: + run: + shell: bash -x -e -u -o pipefail {0} + +jobs: + build-wheel: + runs-on: ${{ inputs.runs-on }} + name: Build wheel (${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}) + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ inputs.release-version }} + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ inputs.python-version }} + + - name: Set CUDA and PyTorch versions + run: | + echo "MATRIX_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + echo "MATRIX_TORCH_VERSION=$(echo ${{ inputs.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV + echo "WHEEL_CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1'})" >> $GITHUB_ENV + echo "MATRIX_PYTHON_VERSION=$(echo ${{ inputs.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV + + - name: Free up disk space + if: ${{ runner.os == 'Linux' }} + # https://github.com/easimon/maximize-build-space/blob/master/action.yml + # https://github.com/easimon/maximize-build-space/tree/test-report + run: | + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf /opt/hostedtoolcache/CodeQL + + - name: Set up swap space + if: runner.os == 'Linux' + uses: pierotofy/set-swap-space@v1.0 + with: + swap-size-gb: 10 + + - name: Install CUDA ${{ inputs.cuda-version }} + if: ${{ inputs.cuda-version != 'cpu' }} + uses: Jimver/cuda-toolkit@v0.2.26 + id: cuda-toolkit + with: + cuda: ${{ inputs.cuda-version }} + linux-local-args: '["--toolkit"]' + # default method is "local", and we're hitting some error with caching for CUDA 11.8 and 12.1 + # method: ${{ (inputs.cuda-version == '11.8.0' || inputs.cuda-version == '12.1.0') && 'network' || 'local' }} + method: "network" + + - name: Install additional CUDA libraries + run: | + CUDA_VERSION=$(echo ${{ inputs.cuda-version }} | awk -F \. {'print $1 "-" $2'}) + sudo apt-get update + sudo apt-get install -y libcusparse-$CUDA_VERSION libcusolver-$CUDA_VERSION + sudo apt-get clean + + - name: Install PyTorch ${{ inputs.torch-version }}+cu${{ inputs.cuda-version }} + run: | + pip install --upgrade pip + # With python 3.13 and torch 2.5.1, unless we update typing-extensions, we get error + # AttributeError: attribute '__default__' of 'typing.ParamSpec' objects is not writable + pip install typing-extensions==4.12.2 + # We want to figure out the CUDA version to download pytorch + # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # This code is ugly, maybe there's a better way to do this. + export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + minv = {'2.4': 118, '2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.4': 124, '2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129}[env['MATRIX_TORCH_VERSION']]; \ + print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \ + ) + if [[ ${{ inputs.torch-version }} == *"dev"* ]]; then + # pip install --no-cache-dir --pre torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + # Can't use --no-deps because we need cudnn etc. + # Hard-coding this version of pytorch-triton for torch 2.6.0.dev20241001 + pip install jinja2 + pip install https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2Bcf34004b8a-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + pip install --no-cache-dir --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ inputs.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + else + pip install --no-cache-dir torch==${{ inputs.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} + fi + nvcc --version + python --version + python -c "import torch; print('PyTorch:', torch.__version__)" + python -c "import torch; print('CUDA:', torch.version.cuda)" + python -c "from torch.utils import cpp_extension; print (cpp_extension.CUDA_HOME)" + + - name: Restore build cache + uses: actions/cache/restore@v4 + with: + path: build.tar + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + restore-keys: | + build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}- + + - name: Unpack build cache + run: | + echo ::group::Adjust timestamps + sudo find / -exec touch -t 197001010000 {} + || true + echo ::endgroup:: + + if [ -f build.tar ]; then + find . -mindepth 1 -maxdepth 1 ! -name 'build.tar' -exec rm -rf {} + + tar -xpvf build.tar -C . + else + echo "No build.tar found, skipping" + fi + + ls -al ./ + ls -al build/ || true + ls -al csrc/ || true + + - name: Build wheel + id: build_wheel + run: | + # We want setuptools >= 49.6.0 otherwise we can't compile the extension if system CUDA version is 11.7 and pytorch cuda version is 11.6 + # https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/utils/cpp_extension.py#L810 + # However this still fails so I'm using a newer version of setuptools + pip install setuptools==75.8.0 + pip install ninja packaging wheel + export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH + export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + # Limit MAX_JOBS otherwise the github runner goes OOM + # nvcc 11.8 can compile with 2 jobs, but nvcc 12.3 goes OOM + + export MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "129" ] && echo 1 || echo 2) + export NVCC_THREADS=2 + export TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 9.0+PTX" + export DG_USE_LOCAL_VERSION=${{ inputs.use-local-version && '1' || '0' }} + + # 5h timeout since GH allows max 6h and we want some buffer + EXIT_CODE=0 + timeout 5h python setup.py bdist_wheel --dist-dir=dist || EXIT_CODE=$? + + if [ $EXIT_CODE -eq 0 ]; then + tmpname=cu${WHEEL_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ inputs.cxx11_abi }} + wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") + ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} + echo "wheel_name=${wheel_name}" >> $GITHUB_ENV + fi + + # Store exit code in GitHub env for later steps + echo "build_exit_code=$EXIT_CODE" | tee -a "$GITHUB_OUTPUT" + + # Do not fail the job if timeout killed the build + exit $EXIT_CODE + + - name: Log build logs after timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + run: | + ls -al ./ + tar -cvf build.tar . --atime-preserve=replace + + - name: Save build cache timeout + if: always() && steps.build_wheel.outputs.build_exit_code == 124 + uses: actions/cache/save@v4 + with: + key: build-${{ inputs.release-version }}-${{ inputs.python-version }}-${{ inputs.cuda-version }}-${{ inputs.torch-version }}-${{ inputs.cxx11_abi }}-${{ github.run_number }}-${{ github.run_attempt }} + path: build.tar + + - name: Log Built Wheels + run: | + ls dist + + - name: Get Release with tag + id: get_current_release + uses: joutvhu/get-release@v1 + with: + tag_name: ${{ inputs.release-version }} + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Upload Release Asset + id: upload_release_asset + if: inputs.upload-to-release + uses: actions/upload-release-asset@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ steps.get_current_release.outputs.upload_url }} + asset_path: ./dist/${{env.wheel_name}} + asset_name: ${{env.wheel_name}} + asset_content_type: application/* diff --git a/third_party/DeepGEMM/.github/workflows/build.yml b/third_party/DeepGEMM/.github/workflows/build.yml new file mode 100644 index 00000000..ee250aa4 --- /dev/null +++ b/third_party/DeepGEMM/.github/workflows/build.yml @@ -0,0 +1,53 @@ +name: Build wheels + +on: + workflow_dispatch: + inputs: + runs-on: + description: "The runner to use for the build" + required: true + type: string + default: ubuntu-22.04 + python-version: + description: "The Python version to use for the build" + required: true + type: string + cuda-version: + description: "The CUDA version to use for the build" + required: true + type: string + torch-version: + description: "The PyTorch version to use for the build" + required: true + type: string + cxx11_abi: + description: "Enable torch flag C++11 ABI (TRUE/FALSE)" + required: true + type: string + upload-to-release: + description: "Upload wheel to this release" + required: false + type: boolean + default: false + release-version: + description: "Upload wheel to this release" + required: false + type: string + use-local-version: + description: "Use local version" + required: false + type: boolean + default: false + +jobs: + build-wheels: + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ inputs.runs-on }} + python-version: ${{ inputs.python-version }} + cuda-version: ${{ inputs.cuda-version }} + torch-version: ${{ inputs.torch-version }} + cxx11_abi: ${{ inputs.cxx11_abi }} + upload-to-release: ${{ inputs.upload-to-release }} + release-version: ${{ inputs.release-version }} + use-local-version: ${{ inputs.use-local-version }} diff --git a/third_party/DeepGEMM/.github/workflows/publish.yml b/third_party/DeepGEMM/.github/workflows/publish.yml new file mode 100644 index 00000000..a7b3e6b8 --- /dev/null +++ b/third_party/DeepGEMM/.github/workflows/publish.yml @@ -0,0 +1,95 @@ +# This workflow will: +# - Create a new Github release +# - Build wheels for supported architectures +# - Deploy the wheels to the Github release +# - Release the static code to PyPi +# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries + +name: Build wheels and deploy + +on: + create: + tags: + - v* + +jobs: + setup_release: + name: Create Release + runs-on: ubuntu-latest + outputs: + release-version: ${{ steps.extract_branch.outputs.branch }} + steps: + - name: Get the tag version + id: extract_branch + run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/} + shell: bash + - name: Create Release + id: create_release + uses: actions/create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ steps.extract_branch.outputs.branch }} + release_name: ${{ steps.extract_branch.outputs.branch }} + + build_wheels: + name: Build Wheel + needs: setup_release + strategy: + fail-fast: false + matrix: + # Using ubuntu-22.04 instead of 24.04 for more compatibility (glibc). Ideally we'd use the + # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. + os: [ubuntu-22.04] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + torch-version: ["2.4.0", "2.5.1", "2.6.0", "2.7.1", "2.8.0"] + cuda-version: ["12.9.1"] + # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. + # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. + # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) + # when building without C++11 ABI and using it on nvcr images. + cxx11_abi: ["FALSE", "TRUE"] + exclude: + # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix + # Pytorch < 2.5 does not support Python 3.13 + - torch-version: "2.4.0" + python-version: "3.13" + uses: ./.github/workflows/_build.yml + with: + runs-on: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + cuda-version: ${{ matrix.cuda-version }} + torch-version: ${{ matrix.torch-version }} + cxx11_abi: ${{ matrix.cxx11_abi }} + release-version: ${{ needs.setup_release.outputs.release-version }} + upload-to-release: true + use-local-version: false + + publish_package: + name: Publish package + needs: [build_wheels] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + pip install ninja packaging wheel twine + # Install latest setuptools with support for pypi metadata 2.2 (improved compat w/ uv) + pip install setuptools==75.8.0 + # We don't want to download anything CUDA-related here + pip install torch --index-url https://download.pytorch.org/whl/cpu + - name: Build core package + env: + DG_USE_LOCAL_VERSION: "0" + DG_SKIP_CUDA_BUILD: "1" + run: | + python setup.py sdist --dist-dir=dist + - name: Deploy + env: + TWINE_USERNAME: "__token__" + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} + run: | + python -m twine upload dist/* diff --git a/third_party/DeepGEMM/.gitignore b/third_party/DeepGEMM/.gitignore new file mode 100644 index 00000000..d0cdf6ca --- /dev/null +++ b/third_party/DeepGEMM/.gitignore @@ -0,0 +1,24 @@ +cmake-build-* +.idea +.DS_Store +build +dist +*.egg-info +*.pyc + +# Third-party links created by `setup.py develop` +deep_gemm/include/cute +deep_gemm/include/cutlass + +# VS Code settings +/.vscode + +# clangd settings +/.clang* +/.cache + +# Generated stub files +stubs/ + +# Symlinks to compiled extensions +deep_gemm/*.so \ No newline at end of file diff --git a/third_party/DeepGEMM/.gitmodules b/third_party/DeepGEMM/.gitmodules new file mode 100644 index 00000000..332be639 --- /dev/null +++ b/third_party/DeepGEMM/.gitmodules @@ -0,0 +1,6 @@ +[submodule "third-party/cutlass"] + path = third-party/cutlass + url = https://github.com/NVIDIA/cutlass.git +[submodule "third-party/fmt"] + path = third-party/fmt + url = https://github.com/fmtlib/fmt.git diff --git a/third_party/DeepGEMM/CMakeLists.txt b/third_party/DeepGEMM/CMakeLists.txt new file mode 100644 index 00000000..bbf625d3 --- /dev/null +++ b/third_party/DeepGEMM/CMakeLists.txt @@ -0,0 +1,32 @@ +# NOTES: current just for CMake-based IDE (e.g. CLion) indexing, the real compilation is done via JIT +cmake_minimum_required(VERSION 3.10) +project(deep_gemm LANGUAGES CXX CUDA) +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -Wno-psabi -Wno-deprecated-declarations") +set(CUDA_SEPARABLE_COMPILATION ON) +list(APPEND CUDA_NVCC_FLAGS "-DENABLE_FAST_DEBUG") +list(APPEND CUDA_NVCC_FLAGS "-O3") +list(APPEND CUDA_NVCC_FLAGS "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage") + +set(USE_SYSTEM_NVTX on) +set(CUDA_ARCH_LIST "9.0" CACHE STRING "List of CUDA architectures to compile") +set(TORCH_CUDA_ARCH_LIST "${CUDA_ARCH_LIST}") + +find_package(CUDAToolkit REQUIRED) +find_package(pybind11 REQUIRED) +find_package(Torch REQUIRED) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CUDA_STANDARD 20) + +include_directories(deep_gemm/include third-party/cutlass/include third-party/cutlass/tools/util/include third-party/fmt/include) +include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include/cccl ${TORCH_INCLUDE_DIRS} ${PYTHON_INCLUDE_DIRS}) +link_directories(${TORCH_INSTALL_PREFIX}/lib ${CUDA_TOOLKIT_ROOT_DIR}/lib64 ${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) + +# The main Python API entrance +pybind11_add_module(_C csrc/python_api.cpp) +target_link_libraries(_C PRIVATE ${TORCH_LIBRARIES} torch_python) + +# Enable kernel code indexing with CMake-based IDEs +cuda_add_library(deep_gemm_indexing_cuda STATIC csrc/indexing/main.cu) diff --git a/third_party/DeepGEMM/LICENSE b/third_party/DeepGEMM/LICENSE new file mode 100644 index 00000000..5c48bdc9 --- /dev/null +++ b/third_party/DeepGEMM/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 DeepSeek + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/third_party/DeepGEMM/README.md b/third_party/DeepGEMM/README.md new file mode 100644 index 00000000..6ef705ff --- /dev/null +++ b/third_party/DeepGEMM/README.md @@ -0,0 +1,207 @@ +# DeepGEMM + +DeepGEMM is a unified, high-performance tensor core kernel library that brings together the key computation primitives of modern large language models — GEMMs (FP8, FP4, BF16), fused MoE with overlapped communication (Mega MoE), MQA scoring for the lightning indexer, HyperConnection (HC), and more — into a single, cohesive CUDA codebase. All kernels are compiled at runtime via a lightweight Just-In-Time (JIT) module, requiring no CUDA compilation during installation. + +DeepGEMM leverages some concepts from [CUTLASS](https://github.com/nvidia/cutlass) and [CuTe](https://github.com/NVIDIA/cutlass/tree/main/include/cute), but avoids heavy reliance on their templates or algebras. The library is designed for simplicity, with only a limited number of core kernel functions, making it a clean and accessible resource for learning NVIDIA GPU kernel optimization techniques. + +Despite its lightweight design, DeepGEMM's performance matches or exceeds expert-tuned libraries across various matrix shapes. + +## News + +- 2026.04.16: Mega MoE, FP8xFP4 GEMM, FP4 Indexer, PDL, faster JIT compilation and more. + - Please see [#304](https://github.com/deepseek-ai/DeepGEMM/pull/304) for more details. + - For Mega MoE benchmarks, refer to [#316](https://github.com/deepseek-ai/DeepGEMM/pull/316). +- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2. + - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details. +- 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module. + - NVRTC and post-compilation SASS optimization are all disabled. + - NVRTC will be supported later. + - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported. + - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details. +- 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details. +- 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). +- 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. + +## Quick start + +### Requirements + +- NVIDIA SM90 or SM100 architecture GPU +- Python 3.8 or higher +- Compilers with C++20 support +- CUDA Toolkit: + - CUDA 12.3 or higher for SM90 + - **We highly recommend 12.9 or higher for the best performance** + - CUDA 12.9 or higher for SM100 +- PyTorch 2.1 or higher +- CUTLASS 4.0 or higher (could be cloned by Git submodule) +- `{fmt}` library (could be cloned by Git submodule) + +### Development + +```bash +# Submodule must be cloned +git clone --recursive git@github.com:deepseek-ai/DeepGEMM.git +cd DeepGEMM + +# Link some essential includes and build the CPP JIT module +cat develop.sh +./develop.sh +``` + +### Installation + +```bash +cat install.sh +./install.sh +``` + +Then, import `deep_gemm` in your Python project, and enjoy! + +## Interfaces + +#### Notices + +This library provides optimized GEMM kernels for NVIDIA GPUs with a naming convention: `D = C + A @ B`. The input shape layout is NT (non-transposed A, transposed B). While the SM90 implementation supports only the NT memory layout (row-major, col-major), the SM100 implementation supports all memory layouts (NT, TN, NN, TT). For example, `fp8_gemm_nt` will do a `D = C + A @ B.T` + +For both architectures, the LHS scaling factor is required to have a TMA-aligned and transposed layout. And the data format for the scaling factor of SM90 and SM100 is different: + +- SM90 requires scaling factors in FP32 format. +- SM100 requires scaling factors in packed [UE8M0](https://docs.nvidia.com/cuda/parallel-thread-execution/#alternate-floating-point-data-formats) format, which packs 4 UE8M0 into a single `torch.int`. + +Please note that operations like input transposition or FP8 casting must be handled separately by the user, please implement or fuse them into prior kernels independently. While the library provides some simple PyTorch utility functions, these may result in slower performance, but our primary focus is on optimizing the GEMM kernels themselves. + +#### Normal dense GEMMs (non-grouped) + +To perform a basic non-grouped FP8 GEMM, call the `fp8_gemm_{nt, nn, tn, tt}` function. For more details, please refer to the function documentation. + +#### Grouped GEMMs (contiguous layout) + +Unlike traditional grouped GEMMs in CUTLASS, DeepGEMM groups only the M-axis, while N and K must remain fixed. This design is tailored for scenarios where experts in an MoE model share the same shape. For training forward passes or inference prefilling, where each expert may process a varying number of tokens, we concatenate these tokens into a single tensor, referred to as the "contiguous" layout. Note that each expert segment must be aligned to the GEMM M block size (`get_mk_alignment_for_contiguous_layout()`). For more information, please refer to the `m_grouped_fp8_gemm_{nt, nn}_contiguous` function documentation. + +We also provide a K-axis-grouped API for MoE weight backward (with M and N must remain fixed), please refer to `k_grouped_fp8_gemm_tn_contiguous` for more information. + +#### Grouped GEMMs (masked layout) + +During the inference decoding phase, when CUDA graph is enabled and the CPU is unaware of the number of tokens each expert receives, we support masked grouped GEMMs. By providing a mask tensor, the kernel computes only the valid portions. + +Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. + +#### V3.2 MQA kernels for the indexer + +The kernel family has two versions, non-paged (for prefilling) and paged (for decoding). +Take the non-paged version `fp8_mqa_logits` as an example. It has 6 inputs: + +- `q`, E4M3 tensor with shape `[seq_len, num_heads, head_dim]` +- `kv`, E4M3 tensor (shaped as `[seq_len_kv, head_dim]`) with float SF (shaped as `[seq_len_kv]`) +- `weights`, float tensor with shape `[seq_len, num_heads]` +- `cu_seq_len_k_start` and `cu_seq_len_k_end`, int tensor with shape `[seq_len]` +- `clean_logits`, whether to clean the unfilled logits into `-inf` + +The output tensor is shaped as `[seq_len, seq_len_kv]`, indicating token-to-token logits. +For each token `i` in `q`, it will iterate all tokens `j` from `[cu_seq_len_k_start[i], cu_seq_len_k_end[i])`, +and calculate the logit `out[i, j]` as: + +```python +kv_j = kv[0][j, :] * kv[1][j].unsqueeze(1) # [head_dim] +out_ij = q[i, :, :] @ kv_j # [num_heads] +out_ij = out_ij.relu() * weights[i, :] # [num_heads] +out_ij = out_ij.sum() # Scalar +``` + +For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`. + +#### Mega MoE + +Mega MoE fuses and overlaps EP dispatch, linear 1 (FP8xFP4), SwiGLU, linear 2 (FP8xFP4), and EP combine into a single mega-kernel, overlapping NVLink communication and tensor core computation. It requires multi-process launch with symmetric memory. Usage: + +```python +# Allocate symmetric memory buffer +# NOTES: requires PyTorch >= 2.9 +buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, num_max_tokens_per_rank, num_topk, hidden, intermediate_hidden +) + +# Transform weights (FP4 with UE8M0 SF) into the required layout +transformed_l1, transformed_l2 = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights) + +# Copy inputs into the buffer before each call +# You may fuse these into previous kernels +buffer.x[:num_tokens].copy_(x_fp8) +buffer.x_sf[:num_tokens].copy_(x_sf) +buffer.topk_idx[:num_tokens].copy_(topk_idx) +buffer.topk_weights[:num_tokens].copy_(topk_weights) + +# Run the fused mega MoE kernel +y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') +deep_gemm.fp8_fp4_mega_moe(y, transformed_l1, transformed_l2, buffer) +``` + +For the full example with multi-process setup and benchmarking, please refer to `tests/test_mega_moe.py`. + +#### Utilities + +The library provides some utility functions besides the above kernels: + +- `deep_gemm.set_num_sms` / `get_num_sms`: set/get the maximum SM count to use +- `deep_gemm.set_tc_util` / `get_tc_util`: set/get an approximated tensor core utilization ratio +- `deep_gemm.set_pdl` / `get_pdl`: enable/disable Programmatic Dependent Launch (PDL) +- `deep_gemm.set_mk_alignment_for_contiguous_layout` / `get_mk_alignment_for_contiguous_layout`: set/get the group-level M/K alignment for contiguous layout +- `deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout`: get the theoretical minimum M/K alignment +- `deep_gemm.set_ignore_compile_dims`: configure dimensions to ignore during JIT compilation +- `deep_gemm.set_block_size_multiple_of`: constrain block sizes to be multiples of a given value +- `deep_gemm.transform_sf_into_required_layout`: transform scaling factors into the required layout +- `deep_gemm.get_tma_aligned_size`: get the required TMA alignment size +- `deep_gemm.get_mn_major_tma_aligned_tensor`: get a MN-major TMA-aligned tensor +- `deep_gemm.get_mn_major_tma_aligned_packed_ue8m0_tensor`: get a MN-major TMA-aligned tensor (with packing FP32 into UE8M0) +- `deep_gemm.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor`: K-grouped GEMM packing kernel + +The library also provides some environment variables, which may be useful: + +- General + - `DG_JIT_DEBUG`: `0` or `1`, print JIT debugging information, `0` by default + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default +- JIT cache + - `DG_JIT_CACHE_DIR`: string, cache directory for compiled kernels, `$HOME/.deep_gemm` by default +- Compiler selection + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC (faster compilation, may have lower performance for some cases), `0` by default + - `DG_JIT_NVCC_COMPILER`: string, NVCC compiler path; defaults to `torch.utils.cpp_extension.CUDA_HOME` + - `DG_JIT_CPP_STANDARD`: integer, C++ standard version, `20` by default +- Compiler output + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print compilation commands, `0` by default + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS output, `0` by default + - `DG_JIT_PTXAS_CHECK`: `0` or `1`, assert no local memory usage in compiled kernels, `0` by default + - `DG_JIT_PRINT_LOAD_TIME`: `0` or `1`, print kernel load time, `0` by default +- Debug and profiling + - `DG_JIT_WITH_LINEINFO`: `0` or `1`, embed source line info for profiling tools, `0` by default + - `DG_JIT_DUMP_ASM`: `0` or `1`, dump both PTX and SASS, `0` by default + - `DG_JIT_DUMP_PTX`: `0` or `1`, dump PTX output, `0` by default + - `DG_JIT_DUMP_SASS`: `0` or `1`, dump SASS output, `0` by default + - `DG_COMM_KERNEL_DEBUG`: `0` or `1`, zero symmetric buffer before each Mega MoE call for debugging, `0` by default + - `DG_USE_NVIDIA_TOOLS`: `0` or `1`, skip internal profiling when running under external NVIDIA tools, `0` by default +- Build options + - `DG_SKIP_CUDA_BUILD`: `0` or `1`, skip CUDA extension build during installation, `0` by default + - `DG_FORCE_BUILD`: `0` or `1`, force local build instead of downloading pre-built wheels, `0` by default + - `DG_JIT_USE_RUNTIME_API`: `0` or `1`, use CUDA Runtime API for kernel loading (requires CUDA runtime >= 12.8), `0` by default + +For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation. + +## Acknowledgement + +DeepGEMM is inspired by the [CUTLASS](https://github.com/nvidia/cutlass) project. Thanks and respect to the developers! + +## License + +This code repository is released under [the MIT License](LICENSE). + +## Citation + +```bibtex +@misc{deepgemm2025, + title={DeepGEMM: clean and efficient BLAS kernel library on GPU}, + author={Chenggang Zhao and Zhean Xu and Liang Zhao and Jiashi Li and Chenhao Xu and Anyi Xu and Shengyu Liu and Kexing Zhou and Kuai Yu}, + year={2025}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/deepseek-ai/DeepGEMM}}, +} +``` diff --git a/third_party/DeepGEMM/build.sh b/third_party/DeepGEMM/build.sh new file mode 100755 index 00000000..abdfc406 --- /dev/null +++ b/third_party/DeepGEMM/build.sh @@ -0,0 +1,12 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build files, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel + +# Open users' original directory +cd "$original_dir" diff --git a/third_party/DeepGEMM/csrc/apis/attention.hpp b/third_party/DeepGEMM/csrc/apis/attention.hpp new file mode 100644 index 00000000..505b0c09 --- /dev/null +++ b/third_party/DeepGEMM/csrc/apis/attention.hpp @@ -0,0 +1,453 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp" +#include "../jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp" +#include "../jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp" +#include "../jit_kernels/impls/smxx_clean_logits.hpp" +#endif + +#include "layout.hpp" + +namespace deep_gemm::attention { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void fp8_gemm_nt_skip_head_mid(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::tuple& head_splits, + std::optional> recipe, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto [m , k ] = get_shape<2>(a.first); + const auto [n , k_] = get_shape<2>(b.first); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Check head splits and N + const auto [left, mid, right] = head_splits; + DG_HOST_ASSERT(n % (left + right) == 0 and n_ == n + n / (left + right) * mid); + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, disable_ue8m0_cast); + DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128); + + // Dispatch into different implements + const auto arch_major = device_runtime->get_arch_major(); + const auto epilogue_type = fmt::format("epilogue::transform::EpilogueHeadSplits<{}, {}, {}>", left, mid, right); + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat and std::get<1>(recipe.value()) != 1) { + const auto major_sfb = get_major_type_ab(sfb); + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, major_a, major_b, major_sfb, compiled_dims, epilogue_type); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + // NOTES: Only granularity 128 and FP8 are exposed in the API + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, std::nullopt, d, m, n, k, + 128, 128, major_a, major_b, compiled_dims, epilogue_type); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static torch::Tensor fp8_fp4_mqa_logits(const std::tuple>& q, + const std::tuple& kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const bool& clean_logits, + const int& max_seqlen_k, + const at::ScalarType& logits_dtype) { + const auto [q_fp, q_sf] = q; + const auto [kv_fp, kv_sf] = kv; + const bool is_fp4 = q_sf.has_value(); + int seq_len, seq_len_kv, num_heads, head_dim; + + if (is_fp4) { + // Check FP4 Q + std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp); + head_dim *= 2; + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4); + + // Check SF Q + auto [_seq_len, _num_heads] = get_shape<2>(q_sf.value()); + DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads); + DG_HOST_ASSERT(q_sf.value().is_contiguous()); + DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32); + + // Check FP4 KV + int _head_dim; + std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp); + _head_dim *= 2; + DG_HOST_ASSERT(head_dim == _head_dim); + DG_HOST_ASSERT(kv_fp.is_contiguous()); + DG_HOST_ASSERT(kv_fp.scalar_type() == kPackedFP4); + + // Check SF KV + auto [_seq_len_kv] = get_shape<1>(kv_sf); + DG_HOST_ASSERT(seq_len_kv == _seq_len_kv); + DG_HOST_ASSERT(kv_sf.is_contiguous()); + DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kInt32); + } else { + // Check FP8 Q + std::tie(seq_len, num_heads, head_dim) = get_shape<3>(q_fp); + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn); + + // Check FP4 KV + int _head_dim; + std::tie(seq_len_kv, _head_dim) = get_shape<2>(kv_fp); + DG_HOST_ASSERT(head_dim == _head_dim); + DG_HOST_ASSERT(kv_fp.is_contiguous()); + DG_HOST_ASSERT(kv_fp.scalar_type() == torch::kFloat8_e4m3fn); + + // Check SF KV + auto [_seq_len_kv] = get_shape<1>(kv_sf); + DG_HOST_ASSERT(seq_len_kv == _seq_len_kv); + DG_HOST_ASSERT(kv_sf.is_contiguous()); + DG_HOST_ASSERT(kv_sf.scalar_type() == torch::kFloat); + } + + // Check weights + auto [_seq_len, _num_heads] = get_shape<2>(weights); + DG_HOST_ASSERT(seq_len == _seq_len and num_heads == _num_heads); + DG_HOST_ASSERT(weights.stride(1) == 1); + DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); + + // Check cu_seq_len_k_start + DG_HOST_ASSERT(cu_seq_len_k_start.size(0) == seq_len); + DG_HOST_ASSERT(cu_seq_len_k_start.is_contiguous()); + DG_HOST_ASSERT(cu_seq_len_k_start.scalar_type() == torch::kInt); + + // Check cu_seq_len_k_end + DG_HOST_ASSERT(cu_seq_len_k_end.size(0) == seq_len); + DG_HOST_ASSERT(cu_seq_len_k_end.is_contiguous()); + DG_HOST_ASSERT(cu_seq_len_k_end.scalar_type() == torch::kInt); + + // Allocate output + constexpr int block_qh = 128; + constexpr int block_kv = 256; + const int block_q = block_qh / num_heads; + DG_HOST_ASSERT(block_qh % num_heads == 0); + + torch::Tensor logits; + int aligned_seq_len = align(seq_len, block_q), stride_logits; + if (max_seqlen_k == 0) { + // Logits stride must be 16-byte aligned + stride_logits = align(seq_len_kv + block_kv, 8); + logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype)); + logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, seq_len_kv)}); + } else { + stride_logits = align(max_seqlen_k, block_kv); + logits = torch::empty({aligned_seq_len, stride_logits}, q_fp.options().dtype(logits_dtype)); + logits = logits.index({torch::indexing::Slice(0, seq_len), torch::indexing::Slice(0, max_seqlen_k)}); + DG_HOST_ASSERT(not clean_logits); + } + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (is_fp4 and arch_major == 10) { + sm100_fp4_mqa_logits(q_fp, q_sf.value(), kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype, + seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv); + } else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) { + smxx_fp8_mqa_logits(q_fp, kv_fp, kv_sf, weights, cu_seq_len_k_start, cu_seq_len_k_end, logits, logits_dtype, + seq_len, seq_len_kv, max_seqlen_k, stride_logits, num_heads, head_dim, block_q, block_kv); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Clean unfilled logits + if (clean_logits) + smxx_clean_logits(logits, cu_seq_len_k_start, cu_seq_len_k_end, 1, seq_len, seq_len_kv, stride_logits); + return logits; +} + +static torch::Tensor get_paged_mqa_logits_metadata(const torch::Tensor& context_lens, int block_kv, int num_sms, const std::optional& indices) { + // NOTES: Only 2D context lens is supported for now + DG_HOST_ASSERT(context_lens.dim() == 2); + const bool is_context_lens_2d = true; + const int batch_size = context_lens.size(0); + const int next_n = context_lens.size(1); + const bool is_varlen = indices.has_value(); + DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); + DG_HOST_ASSERT(context_lens.is_contiguous()); + + // Create metadata tensor + auto schedule_metadata = torch::empty({num_sms + 1, 2}, context_lens.options()); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (is_varlen) { + const auto& indices_tensor = indices.value(); + DG_HOST_ASSERT(arch_major == 10 and next_n == 1 and (block_kv == 64 or block_kv == 32)); + DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size); + DG_HOST_ASSERT(indices_tensor.is_contiguous()); + DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt); + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, true, indices_tensor.data_ptr()); + } else if (arch_major == 9 or arch_major == 10) { + DG_HOST_ASSERT(block_kv == 64 or (arch_major == 10 and block_kv == 32)); + smxx_paged_mqa_logits_metadata(context_lens, schedule_metadata, batch_size, next_n, block_kv, num_sms, is_context_lens_2d, false, nullptr); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + return schedule_metadata; +} + +static torch::Tensor fp8_fp4_paged_mqa_logits(const std::tuple>& q, + const torch::Tensor& fused_kv_cache, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& max_context_len, + const bool& clean_logits, + const at::ScalarType& logits_dtype, + const std::optional& indices) { + const auto [q_fp, q_sf] = q; + const bool is_fp4 = q_sf.has_value(); + + torch::Tensor kv_cache, kv_cache_sf; + int batch_size, next_n, num_heads, head_dim; + int num_kv_blocks, block_kv; + int kv_cache_stride_bytes; + int block_table_stride = block_table.stride(0); + int num_sms = device_runtime->get_num_sms(); + + if (is_fp4) { + // Check FP4 Q + std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp); + head_dim *= 2; + DG_HOST_ASSERT(next_n >= 1); + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == kPackedFP4); + + // Check SF Q + auto [_batch_size, _next_n, _num_heads] = get_shape<3>(q_sf.value()); + DG_HOST_ASSERT(batch_size == _batch_size and next_n == _next_n and num_heads == _num_heads); + DG_HOST_ASSERT(q_sf.value().is_contiguous()); + DG_HOST_ASSERT(q_sf.value().scalar_type() == torch::kInt32); + + // Check fused KV cache + int num_heads_kv, fp4_with_sf_bytes; + std::tie(num_kv_blocks, block_kv, num_heads_kv, fp4_with_sf_bytes) = get_shape<4>(fused_kv_cache); + DG_HOST_ASSERT(block_kv == 32 or block_kv == 64); + DG_HOST_ASSERT(num_heads_kv == 1 and fp4_with_sf_bytes == head_dim / 2 + static_cast(sizeof(int))); + DG_HOST_ASSERT(fused_kv_cache.stride(1) == fp4_with_sf_bytes and fused_kv_cache.stride(3) == 1); + DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte); + + // Derive FP4 values and SF tensor + kv_cache_stride_bytes = fused_kv_cache.stride(0); + DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(int) == 0); + kv_cache = torch::from_blob( + fused_kv_cache.data_ptr(), + {num_kv_blocks, block_kv, head_dim / 2}, + {kv_cache_stride_bytes, head_dim / 2, 1}, + torch::TensorOptions().dtype(kPackedFP4) + ); + kv_cache_sf = torch::from_blob( + fused_kv_cache.data_ptr() + block_kv * head_dim / 2, + {num_kv_blocks, block_kv}, + {kv_cache_stride_bytes / static_cast(sizeof(int)), 1}, + torch::TensorOptions().dtype(torch::kInt32) + ); + } else { + // Check FP8 Q + std::tie(batch_size, next_n, num_heads, head_dim) = get_shape<4>(q_fp); + DG_HOST_ASSERT(next_n >= 1); + DG_HOST_ASSERT(num_heads == 32 or num_heads == 64); + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + DG_HOST_ASSERT(q_fp.is_contiguous()); + DG_HOST_ASSERT(q_fp.scalar_type() == torch::kFloat8_e4m3fn); + + // Check fused KV cache + int num_heads_kv, head_dim_with_sf; + std::tie(num_kv_blocks, block_kv, num_heads_kv, head_dim_with_sf) = get_shape<4>(fused_kv_cache); + DG_HOST_ASSERT(block_kv == 32 or block_kv == 64); + DG_HOST_ASSERT(num_heads_kv == 1 and head_dim_with_sf == head_dim + static_cast(sizeof(float))); + DG_HOST_ASSERT(fused_kv_cache.stride(1) == head_dim_with_sf and fused_kv_cache.stride(3) == 1); + DG_HOST_ASSERT(fused_kv_cache.scalar_type() == torch::kByte); + + // Derive FP8 values and SF tensor + kv_cache_stride_bytes = fused_kv_cache.stride(0); + DG_HOST_ASSERT(kv_cache_stride_bytes % sizeof(float) == 0); + kv_cache = torch::from_blob( + fused_kv_cache.data_ptr(), + {num_kv_blocks, block_kv, head_dim}, + {kv_cache_stride_bytes, head_dim, 1}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn) + ); + kv_cache_sf = torch::from_blob( + fused_kv_cache.data_ptr() + block_kv * head_dim, + {num_kv_blocks, block_kv}, + {kv_cache_stride_bytes / static_cast(sizeof(float)), 1}, + torch::TensorOptions().dtype(torch::kFloat32) + ); + + // Weights must be contiguous for FP8 + DG_HOST_ASSERT(weights.is_contiguous()); + } + + // Check weights + auto [_batch_size_next_n, _num_heads] = get_shape<2>(weights); + DG_HOST_ASSERT(_batch_size_next_n == batch_size * next_n and _num_heads == num_heads); + DG_HOST_ASSERT(weights.stride(1) == 1); + DG_HOST_ASSERT(weights.scalar_type() == torch::kFloat); + + // Check block table + auto [_batch_size, _max_block_len] = get_shape<2>(block_table); + DG_HOST_ASSERT(_batch_size == batch_size); + DG_HOST_ASSERT(block_table.stride(1) == 1); + DG_HOST_ASSERT(block_table.scalar_type() == torch::kInt); + + // Check indices + const bool is_varlen = indices.has_value(); + const auto arch_major = device_runtime->get_arch_major(); + const auto indices_tensor = indices.value_or(torch::Tensor()); + if (is_varlen) { + DG_HOST_ASSERT(arch_major == 10 and next_n == 1); + DG_HOST_ASSERT(indices_tensor.dim() == 1 and indices_tensor.size(0) == batch_size); + DG_HOST_ASSERT(indices_tensor.is_contiguous()); + DG_HOST_ASSERT(indices_tensor.scalar_type() == torch::kInt); + } + + // Check schedule metadata + auto [_schedule_meta_size, _meta_info_size] = get_shape<2>(schedule_meta); + DG_HOST_ASSERT(_schedule_meta_size == num_sms + 1 and _meta_info_size == 2); + DG_HOST_ASSERT(schedule_meta.is_contiguous()); + DG_HOST_ASSERT(schedule_meta.scalar_type() == torch::kInt); + + // Check context lengths + // NOTES: Only 2D context lens is supported for now + DG_HOST_ASSERT(context_lens.dim() == 2); + const bool is_context_lens_2d = true; + const auto [__batch_size, _next_n] = get_shape<2>(context_lens); + DG_HOST_ASSERT(batch_size == __batch_size and next_n == _next_n); + DG_HOST_ASSERT(context_lens.is_contiguous()); + DG_HOST_ASSERT(context_lens.scalar_type() == torch::kInt); + + // Allocate output + constexpr int split_kv = 256; + const auto aligned_max_context_len = align(max_context_len, split_kv); + auto logits = torch::empty({batch_size * next_n, aligned_max_context_len}, q_fp.options().dtype(logits_dtype)); + logits = logits.slice(-1, 0, max_context_len); + DG_HOST_ASSERT(logits_dtype == torch::kFloat32 or logits_dtype == torch::kBFloat16); + + // Dispatch implementation + if (is_fp4 and arch_major == 10) { + sm100_fp4_paged_mqa_logits(q_fp, q_sf.value(), kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta, + logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, + is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv); + } else if (not is_fp4 and (arch_major == 9 or arch_major == 10)) { + smxx_fp8_paged_mqa_logits(q_fp, kv_cache, kv_cache_sf, weights, context_lens, logits, block_table, indices_tensor, schedule_meta, + logits_dtype, batch_size, next_n, num_heads, head_dim, num_kv_blocks, block_kv, is_context_lens_2d, + is_varlen, aligned_max_context_len, block_table_stride, num_sms, split_kv); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Clean unfilled logits + if (clean_logits) { + DG_HOST_ASSERT(not is_context_lens_2d); + smxx_clean_logits(logits, std::nullopt, context_lens, next_n, batch_size * next_n, max_context_len, aligned_max_context_len); + } + return logits; +} + + +// Legacy API wrappers +static torch::Tensor fp8_mqa_logits(const torch::Tensor& q, + const std::tuple& kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const bool& clean_logits, + const int& max_seqlen_k) { + return fp8_fp4_mqa_logits(std::make_tuple(q, std::nullopt), kv, weights, + cu_seq_len_k_start, cu_seq_len_k_end, + clean_logits, max_seqlen_k, torch::kFloat); +} + +static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& fused_kv_cache, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& max_context_len, + const bool& clean_logits, + const std::optional& indices) { + return fp8_fp4_paged_mqa_logits(std::make_tuple(q, std::nullopt), fused_kv_cache, weights, + context_lens, block_table, schedule_meta, + max_context_len, clean_logits, torch::kFloat, indices); +} +#endif + +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"), + py::arg("recipe") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_mqa_logits", &fp8_fp4_mqa_logits, + py::arg("q"), py::arg("kv"), py::arg("weights"), + py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"), + py::arg("clean_logits") = true, + py::arg("max_seqlen_k") = 0, + py::arg("logits_dtype") = torch::kFloat32); + m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata, + py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"), + py::arg("indices") = std::nullopt); + m.def("fp8_fp4_paged_mqa_logits", &fp8_fp4_paged_mqa_logits, + py::arg("q"), py::arg("kv_cache"), py::arg("weights"), + py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), + py::arg("max_context_len"), + py::arg("clean_logits") = false, + py::arg("logits_dtype") = torch::kFloat32, + py::arg("indices") = std::nullopt); + // Legacy API + m.def("fp8_mqa_logits", &fp8_mqa_logits, + py::arg("q"), py::arg("kv"), py::arg("weights"), + py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"), + py::arg("clean_logits") = true, + py::arg("max_seqlen_k") = 0); + m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits, + py::arg("q"), py::arg("kv_cache"), py::arg("weights"), + py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"), + py::arg("max_context_len"), py::arg("clean_logits") = false, + py::arg("indices") = std::nullopt); +#endif +} + +} // namespace deep_gemm::attention diff --git a/third_party/DeepGEMM/csrc/apis/einsum.hpp b/third_party/DeepGEMM/csrc/apis/einsum.hpp new file mode 100644 index 00000000..f82331ca --- /dev/null +++ b/third_party/DeepGEMM/csrc/apis/einsum.hpp @@ -0,0 +1,231 @@ +#pragma once + +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/layout.hpp" +#include "../utils/compatibility.hpp" +#include "gemm.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_bmk_bnk_mn.hpp" +#include "../jit_kernels/impls/sm100_bmk_bnk_mn.hpp" +#include "../jit_kernels/impls/sm90_bf16_gemm.hpp" +#include "../jit_kernels/impls/sm100_bf16_gemm.hpp" +#include "../jit_kernels/impls/smxx_cublaslt.hpp" +#endif + +namespace deep_gemm::einsum { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void bmk_bnk_mn(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, + const std::optional& c) { + // Currently FP32 only support the accumulated expression + if (d.scalar_type() == torch::kFloat) { + DG_HOST_ASSERT(c->data_ptr() == d.data_ptr() and c->sizes() == d.sizes() and c->strides() == d.strides()); + } else { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(not c.has_value()); + + const auto workspace = torch::empty_like(d, d.options().dtype(torch::kFloat32)); + DG_CUDA_RUNTIME_CHECK(cudaMemsetAsync(workspace.data_ptr(), 0, workspace.nbytes(), + c10::cuda::getCurrentCUDAStream())); + bmk_bnk_mn(a, b, workspace, workspace); + + // This line has an implicit FP32-to-BF16 casting + d.copy_(workspace); + return; + } + + DG_HOST_ASSERT(a.is_contiguous()); + DG_HOST_ASSERT(b.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + + const auto [s , m, k ] = get_shape<3>(a); + const auto [s_, n, k_] = get_shape<3>(b); + DG_HOST_ASSERT(s == s_ and k == k_); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bmn_bnk_mn_gemm(a, b, d, s, m, n, k); + } else if (arch_major == 10) { + sm100_bmn_bnk_mn_gemm(a, b, d, s, m, n, k); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bhr_hdr_bhd(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { + const auto [b , h , r ] = get_shape<3>(A); + const auto [h_, d , r_] = get_shape<3>(B); + const auto [b_, h__, d_] = get_shape<3>(D); + DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); + + DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); + DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); + DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (use_cublaslt) { + cublaslt_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else if (arch_major == 9) { + sm90_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else if (arch_major == 10) { + sm100_bf16_bhr_hdr_bhd(A, B, D, b, h, r, d); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bhd_hdr_bhr(const torch::Tensor& A, const torch::Tensor& B, const torch::Tensor& D, const bool& use_cublaslt) { + const auto [b , h , d ] = get_shape<3>(A); + const auto [h_, d_ , r ] = get_shape<3>(B); + const auto [b_, h__, r_] = get_shape<3>(D); + DG_HOST_ASSERT(b == b_ and h == h_ and r == r_ and d == d_ and h == h__); + + DG_HOST_ASSERT(A.scalar_type() == torch::kBFloat16 and A.stride(2) == 1); + DG_HOST_ASSERT(B.scalar_type() == torch::kBFloat16 and B.stride(2) == 1); + DG_HOST_ASSERT(D.scalar_type() == torch::kBFloat16 and D.stride(2) == 1); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (use_cublaslt) { + cublaslt_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else if (arch_major == 9) { + sm90_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else if (arch_major == 10) { + sm100_bf16_bhd_hdr_bhr(A, B, D, b, h, r, d); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void einsum(const std::string& expr, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const bool& use_cublaslt) { + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + if (c.has_value()) { + DG_HOST_ASSERT(c->scalar_type() == torch::kFloat); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + } + + // Some hardcoded Einstein sum kernels + // TODO: support any expression + // TODO: canonicalize expression + if (expr == "bmk,bnk->mn") { + DG_HOST_ASSERT(not use_cublaslt); + bmk_bnk_mn(a, b, d, c); + } else if (expr == "bhr,hdr->bhd") { + DG_HOST_ASSERT(not c.has_value()); + bhr_hdr_bhd(a, b, d, use_cublaslt); + } else if (expr == "bhd,hdr->bhr") { + DG_HOST_ASSERT(not c.has_value()); + bhd_hdr_bhr(a, b, d, use_cublaslt); + } else { + DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); + } +} + +static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + const std::string& compiled_dims) { + // Shape must be `[B, M, K] @ [B, N, K].T` + const auto major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; + const auto major_b = b.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; + DG_HOST_ASSERT(a.stride(-1) == 1 or a.stride(-2) == 1); + DG_HOST_ASSERT(b.stride(-1) == 1 or b.stride(-2) == 1); + DG_HOST_ASSERT(d.stride(-1) == 1); + + // Type and shape checks + const auto [batch_size , m , k ] = get_shape<3>(a); + const auto [batch_size_ , n , k_] = get_shape<3>(b); + const auto [batch_size__, m_, n_] = get_shape<3>(d); + DG_HOST_ASSERT(batch_size == batch_size_ and batch_size == batch_size_); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(a.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat8_e4m3fn); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Early return for trivial cases + if (batch_size == 0 or gemm::early_return(m, n, k, d, c)) + return; + + // Transform scaling factors + const auto [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, gran_k_a, gran_k_b, major_a, major_b, compiled_dims); + } else { + const auto major_sfb = get_major_type_ab(sfb); + DG_HOST_ASSERT(gran_k_a == 128 and gran_k_b == 128); + sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims); + } +} + +static void fp8_einsum(const std::string& expr, + const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::tuple& recipe) { + // Some hardcoded Einstein sum kernels + const auto arch_major = device_runtime->get_arch_major(); + if (expr == "bhr,hdr->bhd") { + // Permute dims to satisfy the order of (batch_size, m, n, k) + // (batch_size, m, n, k): (h, b, d, r) + const auto perm_a = a.first.permute({1, 0, 2}); + const auto perm_sfa = a.second.permute({1, 0, 2}); + const auto perm_d = d.permute({1, 0, 2}); + const auto perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; + fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk"); + } else if (expr == "bhd,hdr->bhr" and arch_major == 10) { + // (batch_size, m, n, k): (h, b, r, d) + const auto perm_a = a.first.permute({1, 0, 2}); + const auto perm_sfa = a.second.permute({1, 0, 2}); + const auto perm_b = b.first.permute({0, 2, 1}); + const auto perm_sfb = b.second.permute({0, 2, 1}); + const auto perm_d = d.permute({1, 0, 2}); + const auto perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; + fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk"); + } else if (expr == "bhd,bhr->hdr" and arch_major == 10) { + // (batch_size, m, n, k): (h, d, r, b) + const auto perm_a = a.first.permute({1, 2, 0}); + const auto perm_sfa = a.second.permute({1, 2, 0}); + const auto perm_b = b.first.permute({1, 2, 0}); + const auto perm_sfb = b.second.permute({1, 2, 0}); + fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, d, c, recipe, "mn"); + } else { + DG_HOST_UNREACHABLE(fmt::format("Unsupported einsum expression: {}", expr)); + } +} +#endif + +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("einsum", &einsum, + py::arg("expr"), py::arg("a"), py::arg("b"), + py::arg("d"), py::arg("c") = std::nullopt, + py::arg("use_cublaslt") = false); + m.def("fp8_einsum", &fp8_einsum, + py::arg("expr"), py::arg("a"), py::arg("b"), + py::arg("d"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 128, 128)); +#endif +} + +} // namespace deep_gemm::einsum diff --git a/third_party/DeepGEMM/csrc/apis/gemm.hpp b/third_party/DeepGEMM/csrc/apis/gemm.hpp new file mode 100644 index 00000000..42622df7 --- /dev/null +++ b/third_party/DeepGEMM/csrc/apis/gemm.hpp @@ -0,0 +1,715 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp" +#include "../jit_kernels/impls/sm90_bf16_gemm.hpp" +#include "../jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp" +#include "../jit_kernels/impls/sm100_bf16_gemm.hpp" +#endif + +#include "../jit_kernels/impls/smxx_cublaslt.hpp" + +#include "layout.hpp" + +namespace deep_gemm::gemm { + +static bool early_return(const int& m, const int &n, const int& k, + const torch::Tensor& d, const std::optional& c) { + // Do nothing if the problem is empty + if (m == 0 or n == 0) + return true; + + // Checks + const bool is_cd_same = c.has_value() and c->data_ptr() == d.data_ptr(); + if (is_cd_same) + DG_HOST_ASSERT(c->sizes() == d.sizes() and c->strides() == d.strides()); + if (c.has_value()) { + check_major_type_cd(c.value()); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(c.value().scalar_type() == torch::kFloat); + } + + // No accumulation + if (k == 0) { + if (not is_cd_same) + c.has_value() ? d.copy_(c.value()) : d.zero_(); + return true; + } + + // With accumulation, do copy before GEMM (assuming the GEMM kernel does not support different C/D) + if (c.has_value() and not is_cd_same) + d.copy_(c.value()); + return false; +} + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + +static void fp8_fp4_gemm_nt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[M, K] @ [N, K].T` + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); + if (fp8_requires_k_major()) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + } + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Early return for trivial cases + if (early_return(m, n, k, d, c)) + return; + + // Transform SFA and SFB into compute-required layout + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast); + + // Dispatch into different implements + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + const int gran_n = recipe.has_value() ? std::get<1>(recipe.value()) : std::get<0>(recipe_b.value()); + if (gran_n == 1) { + sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + const auto major_sfb = get_major_type_ab(sfb); + sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); + } + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void fp8_fp4_gemm_nn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_fp4_gemm_tn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, + {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); +} + +static void fp8_fp4_gemm_tt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); +} + +static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + if (fp8_requires_k_major()) + DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); + + // Type and shape checks + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Transform SFA and SFB into compute-required layout + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, num_groups, disable_ue8m0_cast); + + // Dispatch implementation + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + const auto major_sfb = get_major_type_ab(sfb); + sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, major_sfb, + compiled_dims, use_psum_layout, expected_m_for_psum_layout); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b, + compiled_dims, use_psum_layout, expected_m_for_psum_layout); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast, + const bool& use_psum_layout) { + m_grouped_fp8_fp4_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, + d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt); +} + +static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& expected_m, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto major_a = get_major_type_ab(a.first); + const auto major_b = get_major_type_ab(b.first); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [num_groups__, m_, n_] = get_shape<3>(d); + const auto num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Transform scaling factors + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, num_groups, num_groups, disable_ue8m0_cast); + + // Dispatch implementation + if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { + const auto major_sfb = get_major_type_ab(sfb); + sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims); + } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { + sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); + } +} + +static void k_grouped_fp8_gemm_tn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1); + + const int gran_k = std::get<2>(recipe); + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); + + // Shape checks + const auto [num_groups, m, n] = get_shape<3>(d); + const auto [sum_k_ , m_] = get_shape<2>(a.first); + const auto [sum_k__, n_] = get_shape<2>(b.first); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); + + // Early return for trivial cases + if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c)) + return; + + // Transform SF with padding + const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 10) { + sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, gran_k, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void k_grouped_fp8_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::tuple& recipe, + const std::string& compiled_dims) { + // Must be 1D1D kernel + DG_HOST_ASSERT(recipe == std::make_tuple(1, 1, 128)); + + // Shape checks + const auto [num_groups, m, n] = get_shape<3>(d); + const auto sum_mk = a.first.numel(); + const auto sum_nk = b.first.numel(); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(sum_mk == static_cast(sum_k) * m); + DG_HOST_ASSERT(sum_nk == static_cast(sum_k) * n); + + // Contiguity checks + DG_HOST_ASSERT(a.first.is_contiguous()); + DG_HOST_ASSERT(b.first.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); + + // Early return for trivial cases + if (early_return(m, n, accumulate(ks.begin(), ks.end(), 0), d, c)) + return; + + // Transform SF with padding + const auto sfa = layout::transform_k_grouped_sf_into_required_layout(a.second, ks, ks_tensor, recipe); + const auto sfb = layout::transform_k_grouped_sf_into_required_layout(b.second, ks, ks_tensor, recipe); + + // Allocate tensormap buffer + // `4` means the double buffering for both A and B operands (2 * 2) + const auto num_sms = device_runtime->get_num_sms(); + const auto tensor_map_buffer = torch::empty({num_sms * 4 * static_cast(sizeof(CUtensorMap))}, + a.first.options().dtype(torch::kByte)); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer, + cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} +#endif + +#if DG_TENSORMAP_COMPATIBLE +static void bf16_gemm_nt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + // Shape must be `[M, K] @ [N, K].T` + const auto major_a = get_major_type_ab(a); + const auto major_b = get_major_type_ab(b); + + // C/D must be N-major + check_major_type_cd(d); + + // Type and shape checks + const auto [m , k ] = get_shape<2>(a); + const auto [n , k_] = get_shape<2>(b); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); + + // Early return for trivial cases + if (early_return(m, n, k, d, c)) + return; + + // Dispatch into different implements + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_bf16_gemm(a, b, c, d, m, n, k, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void bf16_gemm_nn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a, b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tn(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c, compiled_dims); +} + +static void bf16_gemm_tt(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::optional& c, + const std::string& compiled_dims) { + bf16_gemm_nt(a.transpose(0, 1), b, d, c, compiled_dims); +} + +static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + // Shape must be `[M, K] @ [G, N, K].mT` + const auto major_a = get_major_type_ab(a); + const auto major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); + + // Type and shape checks + const auto [m, k] = get_shape<2>(a); + const auto [num_groups, n, k_] = get_shape<3>(b); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } + + // D must be N-major + check_major_type_cd(d); + + // Do nothing if empty + if (m == 0) + return; + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, compiled_dims, + use_psum_layout, expected_m_for_psum_layout); + } else if (arch_major == 10) { + sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, compiled_dims, + use_psum_layout, expected_m_for_psum_layout); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout) { + m_grouped_bf16_gemm_nt_contiguous(a, b.transpose(1, 2), + d, grouped_layout, compiled_dims, use_psum_layout, std::nullopt); +} + +static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const torch::Tensor& masked_m, + const int& expected_m, const std::string& compiled_dims) { + // Shape must be `[G, M, K] @ [G, N, K].mT` + const auto major_a = get_major_type_ab(a); + const auto major_b = get_major_type_ab(b); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(masked_m.is_contiguous()); + + // Type and shape checks + const auto [num_groups, m, k] = get_shape<3>(a); + const auto [num_groups_, n, k_] = get_shape<3>(b); + const auto [num_groups__, m_, n_] = get_shape<3>(d); + const auto num_groups___ = static_cast(masked_m.numel()); + DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); + + // D must be N-major + check_major_type_cd(d); + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_m_grouped_gemm_masked(a, b, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else if (arch_major == 10) { + sm100_m_grouped_bf16_gemm_masked(a, b, d, masked_m, + num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::optional& c, + const std::string& compiled_dims) { + // Shape checks + const auto [num_groups, m, n] = get_shape<3>(d); + const auto [sum_k_ , m_] = get_shape<2>(a); + const auto [sum_k__, n_] = get_shape<2>(b); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); + + // Contiguity checks + DG_HOST_ASSERT(a.is_contiguous()); + DG_HOST_ASSERT(b.is_contiguous()); + DG_HOST_ASSERT(d.is_contiguous()); + DG_HOST_ASSERT(c.has_value() and c.value().is_contiguous()); + + // Early return for trivial cases + if (early_return(m, n, std::accumulate(ks.begin(), ks.end(), 0), d, c)) + return; + + // Dispatch implementation + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else if (arch_major == 10) { + sm100_bf16_k_grouped_gemm(a, b, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} +#endif + +static void cublaslt_gemm_nt(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + // Shape must be `[M, K] @ [N, K].T` + const auto major_a = get_major_type_ab(a); + const auto major_b = get_major_type_ab(b); + + // Type and shape checks + const auto [m , k ] = get_shape<2>(a); + const auto [n , k_] = get_shape<2>(b); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); + + // Early return for trivial cases + if (early_return(m, n, k, d, c)) + return; + + cublaslt_gemm(a, b, d, m, n, k, major_a, major_b, c.has_value()); +} + +static void cublaslt_gemm_nn(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a, b.transpose(0, 1), d, c); +} + +static void cublaslt_gemm_tn(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a.transpose(0, 1), b.transpose(0, 1), d, c); +} + +static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& d, const std::optional& c) { + cublaslt_gemm_nt(a.transpose(0, 1), b, d, c); +} + +static void register_apis(pybind11::module_& m) { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + // FP8 FP4 GEMMs + m.def("fp8_fp4_gemm_nt", &fp8_fp4_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_gemm_nn", &fp8_fp4_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_gemm_tn", &fp8_fp4_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("fp8_fp4_gemm_tt", &fp8_fp4_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "mn", + py::arg("disable_ue8m0_cast") = false); + m.def("m_grouped_fp8_fp4_gemm_nt_contiguous", &m_grouped_fp8_fp4_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false, + py::arg("use_psum_layout") = false, + py::arg("expected_m_for_psum_layout") = std::nullopt); + m.def("m_grouped_fp8_fp4_gemm_nn_contiguous", &m_grouped_fp8_fp4_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false, + py::arg("use_psum_layout") = false); + m.def("m_grouped_fp8_fp4_gemm_nt_masked", &m_grouped_fp8_fp4_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); + m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); + m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("recipe") = std::make_tuple(1, 1, 128), + py::arg("compiled_dims") = "mn"); + + // FP8 GEMM alias names + m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt"); + m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn"); + m.attr("fp8_gemm_tn") = m.attr("fp8_fp4_gemm_tn"); + m.attr("fp8_gemm_tt") = m.attr("fp8_fp4_gemm_tt"); + m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous"); + m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous"); + m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked"); +#endif + +#if DG_TENSORMAP_COMPATIBLE + // BF16 GEMMs + m.def("bf16_gemm_nt", &bf16_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_nn", &bf16_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "nk"); + m.def("bf16_gemm_tn", &bf16_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("bf16_gemm_tt", &bf16_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), + py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); + m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", + py::arg("use_psum_layout") = false, + py::arg("expected_m_for_psum_layout") = std::nullopt); + m.def("m_grouped_bf16_gemm_nn_contiguous", &m_grouped_bf16_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", + py::arg("use_psum_layout") = false); + m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), + py::arg("expected_m"), py::arg("compiled_dims") = "nk"); + m.def("k_grouped_bf16_gemm_tn_contiguous", &k_grouped_bf16_gemm_tn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), + py::arg("ks_tensor"), py::arg("c") = std::nullopt, + py::arg("compiled_dims") = "mn"); +#endif + + // cuBLASLt GEMMs + m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); + m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt); +} + +} // namespace deep_gemm::gemm diff --git a/third_party/DeepGEMM/csrc/apis/hyperconnection.hpp b/third_party/DeepGEMM/csrc/apis/hyperconnection.hpp new file mode 100644 index 00000000..1a13984d --- /dev/null +++ b/third_party/DeepGEMM/csrc/apis/hyperconnection.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp" +#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp" +#endif + +namespace deep_gemm::hyperconnection { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const std::optional& num_splits) { + // A and B must be K-major, D must be N-major + DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K); + check_major_type_cd(d); + + // S must be contiguous + DG_HOST_ASSERT(sqr_sum.is_contiguous()); + + // Type and shape checks + const auto [m, k ] = get_shape<2>(a); + const auto [n, k_] = get_shape<2>(b); + if (num_splits.has_value()) { + const auto [num_splits_, m_, n_] = get_shape<3>(d); + const auto [num_splits__, m__] = get_shape<2>(sqr_sum); + DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } else { + const auto [m_, n_] = get_shape<2>(d); + const auto [m__] = get_shape<1>(sqr_sum); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat); + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Dispatch into different implements + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1); + } else if (arch_major == 10) { + sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +#endif + +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"), + py::arg("num_splits") = std::nullopt); +#endif +} + +} // namespace deep_gemm::hyperconnection diff --git a/third_party/DeepGEMM/csrc/apis/layout.hpp b/third_party/DeepGEMM/csrc/apis/layout.hpp new file mode 100644 index 00000000..b404241a --- /dev/null +++ b/third_party/DeepGEMM/csrc/apis/layout.hpp @@ -0,0 +1,143 @@ +#pragma once + +#include "../jit_kernels/heuristics/runtime.hpp" +#include "../utils/layout.hpp" +#include "../utils/compatibility.hpp" + +#if DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/smxx_layout.hpp" +#endif + +namespace deep_gemm::layout { + +#if DG_TENSORMAP_COMPATIBLE +static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const std::variant, + std::tuple>& recipe, + const std::optional& num_groups, + const std::optional& is_sfa, + const bool& disable_ue8m0_cast) { + const auto arch_major = device_runtime->get_arch_major(); + + // Get granularity MN/K from recipe + int gran_mn, gran_k; + if (auto p = std::get_if>(&recipe)) { + DG_HOST_ASSERT(is_sfa.has_value()); + gran_mn = is_sfa.value() ? std::get<0>(*p) : std::get<1>(*p); + gran_k = std::get<2>(*p); + } else if (auto p = std::get_if>(&recipe)) { + DG_HOST_ASSERT(not is_sfa.has_value()); + std::tie(gran_mn, gran_k) = *p; + } else { + DG_HOST_UNREACHABLE("Invalid recipe"); + } + + // Pre-transform checks + check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups); + + // (FP32, 1, 128) on SM90: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return get_mn_major_tma_aligned_tensor(sf); + + // (FP32, 128, 128) on SM90: no need to transform, check SFB requirements + if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); + + // (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) { + DG_HOST_ASSERT(not disable_ue8m0_cast); + const auto broadcasted = gran_mn == 1 ? sf : + sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn)); + return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); + } + + // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10) + return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); + + DG_HOST_UNREACHABLE("Unknown SF transformation"); +} + +static std::tuple transform_sf_pair_into_required_layout( + const torch::Tensor& sfa, const torch::Tensor& sfb, + const int& m, const int& n, const int& k, + std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::optional& num_groups_a, + const std::optional& num_groups_b, + const bool& disable_ue8m0_cast = false) { + // Use default recipe, if none is specified + if (not recipe_a.has_value() and not recipe.has_value()) + recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type()); + + // Must be either 'recipe' or the 'recipe_a' + 'recipe_b' pair. + DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value()); + DG_HOST_ASSERT(recipe_a.has_value() != recipe.has_value()); + + // Transform SFA and SFB layout + const auto transformed_sfa = recipe.has_value() ? transform_sf_into_required_layout(sfa, m, k, recipe.value(), num_groups_a, true, disable_ue8m0_cast) + : transform_sf_into_required_layout(sfa, m, k, recipe_a.value(), num_groups_a, std::nullopt, disable_ue8m0_cast); + const auto transformed_sfb = recipe.has_value() ? transform_sf_into_required_layout(sfb, n, k, recipe.value(), num_groups_b, false, disable_ue8m0_cast) + : transform_sf_into_required_layout(sfb, n, k, recipe_b.value(), num_groups_b, std::nullopt, disable_ue8m0_cast); + const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value()); + const int gran_k_b = recipe_b.has_value() ? std::get<1>(recipe_b.value()) : std::get<2>(recipe.value()); + return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b); +} + +static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, + const std::vector& ks, + const torch::Tensor& ks_tensor, + const std::tuple& recipe) { + DG_HOST_ASSERT(sf.dim() == 2); + DG_HOST_ASSERT(std::get<0>(recipe) == 1 and std::get<1>(recipe) == 1); + + const int gran_k = std::get<2>(recipe); + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); + + const auto arch_major = device_runtime->get_arch_major(); + + // FP32 on SM90 + if (sf.scalar_type() == torch::kFloat and arch_major == 9) + return get_mn_major_tma_aligned_tensor(sf); + + // FP32 on SM100 + if (sf.scalar_type() == torch::kFloat and arch_major == 10) + return get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks, gran_k); + + // INT on SM100 + if (sf.scalar_type() == torch::kInt and arch_major == 10) + DG_HOST_UNREACHABLE("Unimplemented"); + + DG_HOST_UNREACHABLE("Unknown cases"); +} + +#endif + +static void register_apis(pybind11::module_& m) { +#if DG_TENSORMAP_COMPATIBLE + m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, + py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"), + py::arg("num_groups") = std::nullopt, + py::arg("is_sfa") = std::nullopt, + py::arg("disable_ue8m0_cast") = false); + + m.def("get_tma_aligned_size", &get_tma_aligned_size); + m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor); + m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor); + m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); +#endif + + m.def("set_mk_alignment_for_contiguous_layout", [&](const int& new_value) { + heuristics_runtime->set_mk_alignment_for_contiguous_layout(new_value); + }); + m.def("get_mk_alignment_for_contiguous_layout", [&]() { + return heuristics_runtime->get_mk_alignment_for_contiguous_layout(); + }); + m.def("get_theoretical_mk_alignment_for_contiguous_layout", [&](const std::optional& expected_m) { + return heuristics_runtime->get_theoretical_mk_alignment_for_contiguous_layout(expected_m); + }, py::arg("expected_m") = std::nullopt); +} + +} // namespace deep_gemm::layout diff --git a/third_party/DeepGEMM/csrc/apis/mega.hpp b/third_party/DeepGEMM/csrc/apis/mega.hpp new file mode 100644 index 00000000..efc3a780 --- /dev/null +++ b/third_party/DeepGEMM/csrc/apis/mega.hpp @@ -0,0 +1,235 @@ +#pragma once + +#include +#include + +#if DG_TENSORMAP_COMPATIBLE +#include "../jit/compiler.hpp" +#endif +#include "../jit/device_runtime.hpp" +#include "../jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp" + +namespace deep_gemm::mega { + +static int get_token_alignment_for_mega_moe() { + return layout::kLCMCandidateBlockM; +} + +static std::tuple(const torch::Tensor&)>> +get_symm_buffer_size_for_mega_moe( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const bool& use_fp8_dispatch, const std::string& activation) { + DG_HOST_ASSERT(num_experts % num_ranks == 0); + + // Workspace bytes + const auto workspace = layout::Workspace(nullptr, num_ranks, num_experts, num_max_tokens_per_rank, num_topk); + + // Layouts + const auto fp8_token_layout = layout::Data(hidden); + const auto bf16_token_layout = layout::Data(hidden * 2); + const auto fp8_intermediate_token_layout = layout::Data(intermediate_hidden); + const auto fp8_sf_layout = layout::Data(hidden / 32); + const auto fp8_intermediate_sf_layout = layout::Data(intermediate_hidden / 32); + const auto input_topk_idx_layout = layout::Data(num_topk * sizeof(int64_t), false); + const auto input_topk_weights_layout = layout::Data(num_topk * sizeof(float), false); + const auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Input buffers + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, num_max_tokens_per_rank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, num_max_tokens_per_rank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, num_max_tokens_per_rank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, num_max_tokens_per_rank, + input_topk_idx_buffer.get_end_ptr()); + + // Buffer configs + const auto num_max_pool_tokens = static_cast(workspace.num_max_pool_tokens); + int num_max_padded_sf_pool_tokens = 0; + for (int block_m: layout::kCandidateBlockM) { + num_max_padded_sf_pool_tokens = std::max( + num_max_padded_sf_pool_tokens, + layout::get_num_padded_sf_pool_tokens(num_max_pool_tokens, block_m) + ); + } + + // L1 input buffer + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, num_max_pool_tokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, num_max_padded_sf_pool_tokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, num_max_pool_tokens, + l1_sf_buffer.get_end_ptr()); + + // L2 input buffer + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, num_max_pool_tokens, + l1_topk_weights_buffer.get_end_ptr()); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, num_max_padded_sf_pool_tokens, + l2_token_buffer.get_end_ptr()); + + // Combine input buffer: BF16 tokens for cross-rank combine + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, num_topk, num_max_tokens_per_rank, + l2_sf_buffer.get_end_ptr()); + + // Check SF buffer requirements + DG_HOST_ASSERT(hidden % 128 == 0 and intermediate_hidden % 128 == 0); + DG_HOST_ASSERT(num_max_padded_sf_pool_tokens % 4 == 0); + + // Slice function: creates `(x, x_sf, topk_weights, topk_idx, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf)` tensor views from the raw buffer + // NOTES: `x_sf` is K-major, while `l1_acts_sf` and `l2_acts_sf` are M-major + auto slice_input_buffers = [=](const torch::Tensor& buffer) { + auto x = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_token_buffer.base)), + {num_max_tokens_per_rank, hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto x_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_sf_buffer.base)), + {num_max_tokens_per_rank, hidden / 128}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto topk_idx = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_idx_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kInt64).device(buffer.device())); + auto topk_weights = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(input_topk_weights_buffer.base)), + {num_max_tokens_per_rank, num_topk}, + torch::TensorOptions().dtype(torch::kFloat32).device(buffer.device())); + auto l1_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_token_buffer.base)), + {num_max_pool_tokens, hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto l1_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l1_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + auto l2_acts = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_token_buffer.base)), + {num_max_pool_tokens, intermediate_hidden}, + torch::TensorOptions().dtype(torch::kFloat8_e4m3fn).device(buffer.device())); + auto l2_acts_sf = torch::from_blob( + math::advance_ptr(buffer.data_ptr(), reinterpret_cast(l2_sf_buffer.base)), + {num_max_padded_sf_pool_tokens, intermediate_hidden / 128}, + {1, num_max_padded_sf_pool_tokens}, + torch::TensorOptions().dtype(torch::kInt).device(buffer.device())); + return std::make_tuple(x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf); + }; + return {reinterpret_cast(combine_token_buffer.get_end_ptr()), slice_input_buffers}; +} + +static void fp8_fp4_mega_moe( + const torch::Tensor& y, + const std::tuple& l1_weights_tuple, + const std::tuple& l2_weights_tuple, + const std::optional& cumulative_local_expert_recv_stats, + const torch::Tensor& sym_buffer, + const std::vector& sym_buffer_ptrs, const int& rank_idx, + const int& num_max_tokens_per_rank, + const int& num_experts, const int& num_topk, + const std::tuple& recipe, + const std::string& activation, + const std::optional& activation_clamp_opt, + const bool& fast_math +) { + const auto [l1_weights, l1_weights_sf] = l1_weights_tuple; + const auto [l2_weights, l2_weights_sf] = l2_weights_tuple; + + // Config checks + const auto num_tokens = static_cast(y.size(0)); + const auto [rm, rn, rk] = recipe; + DG_HOST_ASSERT(rm == 1 and rn == 1 and rk == 32); + DG_HOST_ASSERT(activation == "swiglu"); + + // Activation checks + const auto activation_clamp = + activation_clamp_opt.value_or(std::numeric_limits::infinity()); + DG_HOST_ASSERT(activation_clamp >= 0); + + // Tensor checks + DG_HOST_ASSERT(get_major_type_ab(l1_weights) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(l2_weights) == cute::UMMA::Major::K); + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_experts_per_rank, intermediate_hidden_2, hidden] = + check_grouped_ab_fp8_fp4(l1_weights, cute::UMMA::Major::K, arch_major); + const auto [num_experts_per_rank_, hidden_, intermediate_hidden] = + check_grouped_ab_fp8_fp4(l2_weights, cute::UMMA::Major::K, arch_major); + DG_HOST_ASSERT(num_tokens <= num_max_tokens_per_rank); + DG_HOST_ASSERT(num_experts_per_rank == num_experts_per_rank_); + DG_HOST_ASSERT(hidden == hidden_); + DG_HOST_ASSERT(intermediate_hidden_2 == 2 * intermediate_hidden); + DG_HOST_ASSERT(l1_weights.is_contiguous() and l2_weights.is_contiguous()); + + // Check weight SF layout for UE8M0 packing, MN-major, and TMA alignment + constexpr int kGranMN = 1, kGranK = 32; + check_sf_layout(l1_weights_sf, intermediate_hidden * 2, hidden, kGranMN, kGranK, + num_experts_per_rank, true, false, torch::kInt); + check_sf_layout(l2_weights_sf, hidden, intermediate_hidden, kGranMN, kGranK, + num_experts_per_rank, true, false, torch::kInt); + + // Check stats counter + if (cumulative_local_expert_recv_stats.has_value()) { + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->scalar_type() == torch::kInt); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->numel() == num_experts_per_rank); + DG_HOST_ASSERT(cumulative_local_expert_recv_stats->is_contiguous()); + } + + // Check buffer bytes + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts_ = num_experts_per_rank * num_ranks; + const auto [num_required_bytes, slice] = get_symm_buffer_size_for_mega_moe( + num_ranks, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + true, activation); + DG_HOST_ASSERT(sym_buffer.nbytes() >= static_cast(num_required_bytes)); + DG_HOST_ASSERT(num_experts == num_experts_); + + // Already registered tensors + const auto [x, x_sf, topk_idx, topk_weights, l1_acts, l1_acts_sf, l2_acts, l2_acts_sf] = slice(sym_buffer); + + // Dispatch into different architectures + if (arch_major == 10) { + sm100_fp8_fp4_mega_moe(y, + l1_acts, l1_acts_sf, + l2_acts, l2_acts_sf, + l1_weights, l2_weights, + l1_weights_sf, l2_weights_sf, + cumulative_local_expert_recv_stats, + sym_buffer_ptrs, + rank_idx, num_max_tokens_per_rank, + num_experts_per_rank, + num_tokens, num_topk, + hidden, intermediate_hidden, + activation_clamp, fast_math); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } + + // Zero the entire symmetric buffer for debug mode + // NOTES: caller must re-copy inputs into the buffer before each kernel call + if (get_env("DG_COMM_KERNEL_DEBUG")) + sym_buffer.zero_(); +} + +static void register_apis(pybind11::module_& m) { +#if DG_TENSORMAP_COMPATIBLE + m.def("get_token_alignment_for_mega_moe", &get_token_alignment_for_mega_moe); + m.def("get_symm_buffer_size_for_mega_moe", &get_symm_buffer_size_for_mega_moe); + m.def("fp8_fp4_mega_moe", &fp8_fp4_mega_moe); +#endif +} + +} // namespace deep_gemm::mega diff --git a/third_party/DeepGEMM/csrc/apis/runtime.hpp b/third_party/DeepGEMM/csrc/apis/runtime.hpp new file mode 100644 index 00000000..58fef941 --- /dev/null +++ b/third_party/DeepGEMM/csrc/apis/runtime.hpp @@ -0,0 +1,51 @@ +#pragma once + +#if DG_TENSORMAP_COMPATIBLE +#include "../jit/compiler.hpp" +#endif +#include "../jit/device_runtime.hpp" +#include "../jit_kernels/heuristics/runtime.hpp" + +namespace deep_gemm::runtime { + +static void register_apis(pybind11::module_& m) { + m.def("set_num_sms", [&](const int& new_num_sms) { + device_runtime->set_num_sms(new_num_sms); + }); + m.def("get_num_sms", [&]() { + return device_runtime->get_num_sms(); + }); + m.def("set_tc_util", [&](const int& new_tc_util) { + device_runtime->set_tc_util(new_tc_util); + }); + m.def("get_tc_util", [&]() { + return device_runtime->get_tc_util(); + }); + m.def("set_pdl", [&](const bool& new_enable_pdl) { + device_runtime->set_pdl(new_enable_pdl); + }); + m.def("get_pdl", [&]() { + return device_runtime->get_pdl(); + }); + m.def("set_ignore_compile_dims", [&](const bool& new_value) { + heuristics_runtime->set_ignore_compile_dims(new_value); + }); + m.def("set_block_size_multiple_of", [&](const std::variant>& new_value) { + if (std::holds_alternative(new_value)) { + auto x = std::get(new_value); + heuristics_runtime->set_block_size_multiple_of(x, x); + } else { + auto [x, y] = std::get>(new_value); + heuristics_runtime->set_block_size_multiple_of(x, y); + } + }); + m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) { +#if DG_TENSORMAP_COMPATIBLE + Compiler::prepare_init(library_root_path, cuda_home_path_by_python); + KernelRuntime::prepare_init(cuda_home_path_by_python); + IncludeParser::prepare_init(library_root_path); +#endif + }); +} + +} // namespace deep_gemm::runtime diff --git a/third_party/DeepGEMM/csrc/indexing/main.cu b/third_party/DeepGEMM/csrc/indexing/main.cu new file mode 100644 index 00000000..a42b66f9 --- /dev/null +++ b/third_party/DeepGEMM/csrc/indexing/main.cu @@ -0,0 +1,35 @@ +// GEMM kernels +#include +#include +#include +#include +#include + +// Attention kernels +#include +#include +#include +#include +#include +#include + +// Einsum kernels +#include +#include + +// Hyperconnection kernels +#include +#include + +// Layout kernels +#include +#include + +// Mega kernels +#include + +using namespace deep_gemm; + +int main() { + return 0; +} diff --git a/third_party/DeepGEMM/csrc/jit/cache.hpp b/third_party/DeepGEMM/csrc/jit/cache.hpp new file mode 100644 index 00000000..ddc763d0 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit/cache.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +#include "kernel_runtime.hpp" + +namespace deep_gemm { + +class KernelRuntimeCache { + std::unordered_map> cache; + +public: + // TODO: consider cache capacity + KernelRuntimeCache() = default; + + std::shared_ptr get(const std::filesystem::path& dir_path) { + // Hit the runtime cache + if (const auto iterator = cache.find(dir_path); iterator != cache.end()) + return iterator->second; + + if (KernelRuntime::check_validity(dir_path)) + return cache[dir_path] = std::make_shared(dir_path); + return nullptr; + } +}; + +static auto kernel_runtime_cache = std::make_shared(); + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit/compiler.hpp b/third_party/DeepGEMM/csrc/jit/compiler.hpp new file mode 100644 index 00000000..7d85a5f5 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit/compiler.hpp @@ -0,0 +1,362 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/hash.hpp" +#include "../utils/lazy_init.hpp" +#include "../utils/system.hpp" +#include "cache.hpp" +#include "device_runtime.hpp" +#include "include_parser.hpp" + +namespace deep_gemm { + +class Compiler { +public: + static std::filesystem::path library_root_path; + static std::filesystem::path library_include_path; + static std::filesystem::path cuda_home; + static std::filesystem::path cuobjdump_path; + + static void prepare_init(const std::string& library_root_path, + const std::string& cuda_home_path_by_python) { + Compiler::library_root_path = library_root_path; + Compiler::library_include_path = Compiler::library_root_path / "include"; + Compiler::cuda_home = cuda_home_path_by_python; + Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump"; + } + + std::string signature, flags; + std::filesystem::path cache_dir_path; + + Compiler() { + // Check `prepare_init` + DG_HOST_ASSERT(not library_root_path.empty()); + DG_HOST_ASSERT(not library_include_path.empty()); + DG_HOST_ASSERT(not cuda_home.empty()); + DG_HOST_ASSERT(not cuobjdump_path.empty()); + + // Cache settings + cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; + if (const auto env_cache_dir_path = get_env("DG_JIT_CACHE_DIR"); not env_cache_dir_path.empty()) + cache_dir_path = env_cache_dir_path; + + // The compiler flags applied to all derived compilers + signature = "unknown-compiler"; + flags = fmt::format("-std=c++{} --diag-suppress=39,161,174,177,186,940 " + "--ptxas-options=--register-usage-level=10", + get_env("DG_JIT_CPP_STANDARD", 20)); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0) or get_env("DG_JIT_PTXAS_CHECK", 0)) + flags += " --ptxas-options=--verbose,--warn-on-local-memory-usage"; + if (get_env("DG_JIT_WITH_LINEINFO", 0)) + flags += " -Xcompiler -rdynamic -lineinfo"; + } + + virtual ~Compiler() = default; + + std::filesystem::path make_tmp_dir() const { + return make_dirs(cache_dir_path / "tmp"); + } + + static void fsync_path(const std::filesystem::path& path) { + const auto fd = ::open(path.c_str(), O_RDONLY); + if (fd >= 0) { + ::fsync(fd); + ::close(fd); + } + } + + // Recursively fsync a directory: files and subdirectories first (bottom-up), then the directory itself + // NOTES: ensures data and directory entries are visible on other nodes in distributed filesystems + static void fsync_dir(const std::filesystem::path& dir_path) { // NOLINT(*-no-recursion) + for (const auto& entry: std::filesystem::directory_iterator(dir_path)) { + if (entry.is_directory()) + fsync_dir(entry.path()); + else if (entry.is_regular_file()) + fsync_path(entry.path()); + } + fsync_path(dir_path); + } + + static void put(const std::filesystem::path& path, const std::string& data) { + std::ofstream out(path, std::ios::binary); + DG_HOST_ASSERT(out.write(data.data(), data.size())); + out.close(); + + // NOTES: fsync to ensure the data is visible to other processes (e.g., NVCC) + // on distributed filesystems, where `close()` alone does not guarantee persistence + fsync_path(path); + } + + std::shared_ptr build(const std::string& name, const std::string& code) const { + const auto kernel_signature = fmt::format("{}$${}$${}$${}", name, signature, flags, code); + const auto dir_path = cache_dir_path / "cache" / fmt::format("kernel.{}.{}", name, get_hex_digest(kernel_signature)); + + // Hit the runtime cache + if (const auto runtime = kernel_runtime_cache->get(dir_path); runtime != nullptr) + return runtime; + + // Compile into a temporary directory, then atomically rename the whole directory + // NOTES: renaming a directory is atomic on both local and distributed filesystems, + // avoiding the stale inode issue that occurs when renaming individual files + const auto tmp_dir_path = make_tmp_dir() / get_uuid(); + make_dirs(tmp_dir_path); + + // Compile into the temporary directory + const auto tmp_cubin_path = tmp_dir_path / "kernel.cubin"; + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_PTX")) { + const auto tmp_ptx_path = tmp_dir_path / "kernel.ptx"; + compile(code, tmp_dir_path, tmp_cubin_path, tmp_ptx_path); + } else { + compile(code, tmp_dir_path, tmp_cubin_path); + } + + // Disassemble if needed + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_SASS")) { + const auto tmp_sass_path = tmp_dir_path / "kernel.sass"; + disassemble(tmp_cubin_path, tmp_sass_path); + } + + // Fsync before rename to ensure visibility on distributed filesystems + fsync_dir(tmp_dir_path); + + // Atomically rename the temporary directory to the final cache path + // NOTES: if another rank already created dir_path, rename will fail — that's fine + make_dirs(dir_path.parent_path()); + std::error_code error_code; + std::filesystem::rename(tmp_dir_path, dir_path, error_code); + if (error_code) { + // Another rank beat us, then clean up our dir and use the existing one + // NOTES: avoid `std::filesystem::remove_all` here — it can segfault on + // distributed filesystems, when concurrent processes operate + // on the same parent directory, causing stale directory entries + safe_remove_all(tmp_dir_path); + } + + // Put into the runtime cache + const auto runtime = kernel_runtime_cache->get(dir_path); + DG_HOST_ASSERT(runtime != nullptr); + return runtime; + } + + static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) { + // Disassemble the CUBIN file to SASS + const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str()); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running cuobjdump command: %s\n", command.c_str()); + const auto [return_code, output] = call_external_command(command); + if (return_code != 0) { + printf("cuobjdump failed: %s\n", output.c_str()); + DG_HOST_ASSERT(false and "cuobjdump failed"); + } + } + + virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional &ptx_path = std::nullopt) const = 0; +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path); + +class NVCCCompiler final: public Compiler { + std::filesystem::path nvcc_path; + + std::pair get_nvcc_version() const { + DG_HOST_ASSERT(std::filesystem::exists(nvcc_path)); + + // Call the version command + const auto command = std::string(nvcc_path) + " --version"; + const auto [return_code, output] = call_external_command(command); + DG_HOST_ASSERT(return_code == 0); + + // The version should be at least 12.3, for the best performance with 12.9 + int major, minor; + std::smatch match; + DG_HOST_ASSERT(std::regex_search(output, match, std::regex(R"(release (\d+\.\d+))"))); + std::sscanf(match[1].str().c_str(), "%d.%d", &major, &minor); + DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVCC version should be >= 12.3"); + if (major == 12 and minor < 9) + printf("Warning: please use at least NVCC 12.9 for the best DeepGEMM performance\n"); + return {major, minor}; + } + +public: + NVCCCompiler() { + // Override the compiler signature + nvcc_path = cuda_home / "bin" / "nvcc"; + if (const auto env_nvcc_path = get_env("DG_JIT_NVCC_COMPILER"); not env_nvcc_path.empty()) + nvcc_path = env_nvcc_path; + const auto [nvcc_major, nvcc_minor] = get_nvcc_version(); + signature = fmt::format("NVCC{}.{}", nvcc_major, nvcc_minor); + + // The override the compiler flags + // Only NVCC >= 12.9 supports arch-specific family suffix + const auto arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); + flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " + "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " + "-O3 --expt-relaxed-constexpr --expt-extended-lambda", + flags, library_include_path.c_str(), arch); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { + // Write the code into the cache directory + const auto code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Compile + // Avoid cwd files shadowing C++ standard library headers + const auto compile_dir = make_tmp_dir(); + const auto command = fmt::format("cd {} && {} {} -cubin -o {} {}", + compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running NVCC command: %s\n", command.c_str()); + const auto [return_code, output] = call_external_command(command); + if (return_code != 0) { + printf("NVCC compilation failed: %s\n", output.c_str()); + DG_HOST_ASSERT(false and "NVCC compilation failed"); + } + + // Compile to PTX if needed + if (ptx_path.has_value()) { + const auto ptx_command = fmt::format("cd {} && {} {} -ptx -o {} {}", + compile_dir.c_str(), nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running NVCC PTX command: %s\n", ptx_command.c_str()); + const auto [ptx_return_code, ptx_output] = call_external_command(ptx_command); + if (ptx_return_code != 0) { + printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str()); + DG_HOST_ASSERT(false and "NVCC PTX compilation failed"); + } + } + + // Check local memory usage + if (get_env("DG_JIT_PTXAS_CHECK", 0)) + DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)"))); + + // Print PTXAS log + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PTXAS_VERBOSE", 0)) + printf("%s", output.c_str()); + } +}; + +class NVRTCCompiler final: public Compiler { +public: + NVRTCCompiler() { + // Override the compiler signature + int major, minor; + DG_NVRTC_CHECK(nvrtcVersion(&major, &minor)); + signature = fmt::format("NVRTC{}.{}", major, minor); + DG_HOST_ASSERT((major > 12 or (major == 12 and minor >= 3)) and "NVRTC version should be >= 12.3"); + + // Build include directories list + std::string include_dirs; + include_dirs += fmt::format("-I{} ", library_include_path.string()); + include_dirs += fmt::format("-I{} ", (cuda_home / "include").string()); + + // Add PCH support for version 12.8 and above + // NOTES: PCH is vital for compilation speed + std::string pch_flags; + if (major > 12 or minor >= 8) { + pch_flags = "--pch "; + if (get_env("DG_JIT_DEBUG", 0)) + pch_flags += "--pch-verbose=true "; + } + + // Override the compiler flags + // Only NVRTC >= 12.9 supports arch-specific family suffix + const auto arch = device_runtime->get_arch(false, major > 12 or minor >= 9); + flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128", + flags, include_dirs, arch, pch_flags); + } + + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { + // Write the code into the cache directory + const auto code_path = dir_path / "kernel.cu"; + put(code_path, code); + + // Parse compilation options + std::istringstream iss(flags); + std::vector options; + std::string option; + while (iss >> option) + options.push_back(option); + + // Convert to C-style string array for NVRTC + std::vector option_cstrs; + for (const auto& opt: options) + option_cstrs.push_back(opt.c_str()); + + // Print compiler command if requested + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) { + printf("Compiling JIT runtime with NVRTC options: "); + for (const auto& opt: options) + printf("%s ", opt.c_str()); + printf("\n"); + } + + // Create NVRTC program and compile + nvrtcProgram program; + DG_NVRTC_CHECK(nvrtcCreateProgram(&program, code.c_str(), "kernel.cu", 0, nullptr, nullptr)); + const auto compile_result = nvrtcCompileProgram(program, static_cast(option_cstrs.size()), option_cstrs.data()); + + // Get and print compiler log + size_t log_size; + DG_NVRTC_CHECK(nvrtcGetProgramLogSize(program, &log_size)); + if (get_env("DG_JIT_DEBUG", 0) or compile_result != NVRTC_SUCCESS) { + if (compile_result != NVRTC_SUCCESS) + DG_HOST_ASSERT(log_size > 1); + if (log_size > 1) { + std::string compilation_log(log_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetProgramLog(program, compilation_log.data())); + printf("NVRTC log: %s\n", compilation_log.c_str()); + } + } + + if (ptx_path.has_value()) { + // Get PTX size and data if needed + size_t ptx_size; + DG_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size)); + std::string ptx_data(ptx_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetPTX(program, ptx_data.data())); + + // Write into the file system + put(ptx_path.value(), ptx_data); + } + + // Get CUBIN size and data + size_t cubin_size; + DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size)); + std::string cubin_data(cubin_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetCUBIN(program, cubin_data.data())); + + // Write into the file system + put(cubin_path, cubin_data); + + // Cleanup + DG_NVRTC_CHECK(nvrtcDestroyProgram(&program)); + } +}; + +static auto compiler = LazyInit([]() -> std::shared_ptr { + if (get_env("DG_JIT_USE_NVRTC", 0)) { + return std::make_shared(); + } else { + return std::make_shared(); + } +}); + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit/device_runtime.hpp b/third_party/DeepGEMM/csrc/jit/device_runtime.hpp new file mode 100644 index 00000000..2321aded --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit/device_runtime.hpp @@ -0,0 +1,138 @@ +#pragma once + +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/lazy_init.hpp" + +#define PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3)) + +namespace deep_gemm { + +class DeviceRuntime { + int num_sms = 0, tc_util = 0; + bool enable_pdl = false; + std::shared_ptr cached_prop; + + // cuBLASLt utils + static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024; + +public: + // Create the cuBLASLt handle ourselves + cublasLtHandle_t cublaslt_handle; + torch::Tensor cublaslt_workspace; + bool use_pytorch_managed_cublaslt_handle; + bool use_temp_cublaslt_workspace; + + explicit DeviceRuntime() { + + // Whether to use PyTorch cuBLASLt + // By default, we don't use it, + // as `at::cuda::getCurrentCUDABlasLtHandle` has large CPU overhead with some PyTorch versions + use_pytorch_managed_cublaslt_handle = get_env("DG_USE_PYTORCH_CUBLASLT_HANDLE", 0) > 0; +#if not PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE + DG_HOST_ASSERT(not use_pytorch_managed_cublaslt_handle and "PyTorch does not support to get cuBLASLt handle"); +#endif + + // Whether to create workspace tensor on each call instead of holding one. + // Enabled by compute-sanitizer tests, which trigger `cudaErrorCudartUnloading` + // when the workspace tensor is destructed after CUDA driver shutdown. + use_temp_cublaslt_workspace = get_env("DG_USE_TEMP_CUBLASLT_WORKSPACE", 0) > 0; + + if (not use_pytorch_managed_cublaslt_handle) + DG_CUBLASLT_CHECK(cublasLtCreate(&cublaslt_handle)); + + if (not use_temp_cublaslt_workspace) + cublaslt_workspace = torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)); + } + + ~DeviceRuntime() noexcept(false) { + if (not use_pytorch_managed_cublaslt_handle) + DG_CUBLASLT_CHECK(cublasLtDestroy(cublaslt_handle)); + } + + cublasLtHandle_t get_cublaslt_handle() const { +#if PYTORCH_SUPPORTS_GET_CUBLASLT_HANDLE + if (use_pytorch_managed_cublaslt_handle) + return at::cuda::getCurrentCUDABlasLtHandle(); +#endif + + // Self-managed handle + return cublaslt_handle; + } + + torch::Tensor get_cublaslt_workspace() const { + if (use_temp_cublaslt_workspace) + return torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)); + return cublaslt_workspace; + } + + std::shared_ptr get_prop() { + if (cached_prop == nullptr) { + int device_idx; + cudaDeviceProp prop; + DG_CUDA_RUNTIME_CHECK(cudaGetDevice(&device_idx)); + DG_CUDA_RUNTIME_CHECK(cudaGetDeviceProperties(&prop, device_idx)); + cached_prop = std::make_shared(prop); + } + return cached_prop; + } + + std::pair get_arch_pair() { + const auto prop = get_prop(); + return {prop->major, prop->minor}; + } + + std::string get_arch(const bool& number_only = false, + const bool& support_arch_family = false) { + const auto [major, minor] = get_arch_pair(); + if (major == 10 and minor != 1) { + if (number_only) + return "100"; + return support_arch_family ? "100f" : "100a"; + } + return std::to_string(major * 10 + minor) + (number_only ? "" : "a"); + } + + int get_arch_major() { + return get_arch_pair().first; + } + + void set_num_sms(const int& new_num_sms) { + DG_HOST_ASSERT(0 <= new_num_sms and new_num_sms <= get_prop()->multiProcessorCount); + num_sms = new_num_sms; + } + + int get_num_sms() { + if (num_sms == 0) + num_sms = get_prop()->multiProcessorCount; + return num_sms; + } + + int get_l2_cache_size() { + return get_prop()->l2CacheSize; + } + + void set_tc_util(const int& new_tc_util) { + DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100); + tc_util = new_tc_util; + } + + int get_tc_util() const { + return tc_util == 0 ? 100 : tc_util; + } + + void set_pdl(const bool& new_enable_pdl) { + enable_pdl = new_enable_pdl; + } + + bool get_pdl() const { + return enable_pdl; + } +}; + +static auto device_runtime = LazyInit([](){ return std::make_shared(); }); + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit/handle.hpp b/third_party/DeepGEMM/csrc/jit/handle.hpp new file mode 100644 index 00000000..be3bc31c --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit/handle.hpp @@ -0,0 +1,222 @@ +#pragma once + +#include +#include +#include +#include + +#include "../utils/exception.hpp" +#include "../utils/compatibility.hpp" + +namespace deep_gemm { + +// Lazy loading all driver symbols +static void* get_driver_handle() { + static void* handle = nullptr; + if (handle == nullptr) { + handle = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_LOCAL); + DG_HOST_ASSERT(handle != nullptr and "Failed to load CUDA driver `libcuda.so.1`"); + } + return handle; +} + +// Macro to define wrapper functions named `lazy_cu{API name}` +#define DECL_LAZY_CUDA_DRIVER_FUNCTION(name) \ +template \ +static auto lazy_##name(Args&&... args) -> decltype(name(args...)) { \ + using FuncType = decltype(&(name)); \ + static FuncType func = nullptr; \ + if (func == nullptr) { \ + func = reinterpret_cast(dlsym(get_driver_handle(), #name)); \ + DG_HOST_ASSERT(func != nullptr and "Failed to load CUDA driver API"); \ + } \ + return func(std::forward(args)...); \ +} + +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorName); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuGetErrorString); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuFuncSetAttribute); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryLoadFromFile); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryUnload); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuKernelGetFunction); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx); +DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled); + +#if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API) + +// Use CUDA runtime API +using LibraryHandle = cudaLibrary_t; +using KernelHandle = cudaKernel_t; +using LaunchConfigHandle = cudaLaunchConfig_t; +using LaunchAttrHandle = cudaLaunchAttribute; + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_RUNTIME_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel{}; + DG_CUDA_RUNTIME_CHECK(cudaLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + DG_CUDA_RUNTIME_CHECK(cudaLibraryGetKernel(&kernel, library, func_name.c_str())); + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { + const auto error = cudaLibraryUnload(library); + DG_HOST_ASSERT(error == cudaSuccess or error == cudaErrorCudartUnloading); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) { + if (smem_size > 0) + DG_CUDA_RUNTIME_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + LaunchConfigHandle config; + config.gridDim = grid_dim; + config.blockDim = block_dim; + config.dynamicSmemBytes = smem_size; + config.stream = stream; + + // Create attributes + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attrs[2]; + config.numAttrs = 0; + config.attrs = attrs; + + // Cluster size + if (cluster_dim > 1) { + auto& attr = attrs[config.numAttrs ++]; + attr.id = cudaLaunchAttributeClusterDimension; + attr.val.clusterDim = {static_cast(cluster_dim), 1, 1}; + } + + // Dependent kernel launch + if (enable_pdl) { + auto& attr = attrs[config.numAttrs ++]; + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = 1; + } + + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return cudaLaunchKernelExC(&config, kernel, ptr_args); +} + +#else + +// Use CUDA driver API +using KernelHandle = CUfunction; +using LaunchConfigHandle = CUlaunchConfig; +using LaunchAttrHandle = CUlaunchAttribute; + +// `cuLibraryEnumerateKernels` is supported since CUDA Driver API 12.4 +#if CUDA_VERSION >= 12040 + #define DG_JIT_USE_LIBRARY_ENUM_KERNELS + DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryGetKernelCount); + DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLibraryEnumerateKernels); + using LibraryHandle = CUlibrary; +#else + using LibraryHandle = CUmodule; +#endif + +#define DG_CUDA_UNIFIED_CHECK DG_CUDA_DRIVER_CHECK + +static KernelHandle load_kernel(const std::filesystem::path& cubin_path, const std::string& func_name, + LibraryHandle *library_opt = nullptr) { + LibraryHandle library; + KernelHandle kernel; + +#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS + DG_CUDA_DRIVER_CHECK(lazy_cuLibraryLoadFromFile(&library, cubin_path.c_str(), nullptr, nullptr, 0, nullptr, nullptr, 0)); + unsigned int num_kernels; + DG_CUDA_DRIVER_CHECK(lazy_cuLibraryGetKernelCount(&num_kernels, library)); + if (num_kernels != 1) { + const auto dir_path = cubin_path.parent_path(); + printf("Corrupted JIT cache directory (expected 1 kernel, found %u): %s, " + "please run `rm -rf %s` and restart your task.\n", + num_kernels, dir_path.c_str(), dir_path.c_str()); + DG_HOST_ASSERT(false and "Corrupted JIT cache directory"); + } + + CUkernel cu_kernel; + DG_CUDA_DRIVER_CHECK(lazy_cuLibraryEnumerateKernels(&cu_kernel, 1, library)); + DG_CUDA_DRIVER_CHECK(lazy_cuKernelGetFunction(&kernel, cu_kernel)); +#else + DG_CUDA_DRIVER_CHECK(lazy_cuModuleLoad(&library, cubin_path.c_str())); + DG_CUDA_DRIVER_CHECK(lazy_cuModuleGetFunction(&kernel, library, func_name.c_str())); +#endif + + if (library_opt != nullptr) + *library_opt = library; + return kernel; +} + +static void unload_library(const LibraryHandle& library) { +#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS + const auto error = lazy_cuLibraryUnload(library); +#else + const auto error = lazy_cuModuleUnload(library); +#endif + DG_HOST_ASSERT(error == CUDA_SUCCESS or error == CUDA_ERROR_DEINITIALIZED); +} + +static LaunchConfigHandle construct_launch_config(const KernelHandle& kernel, + const cudaStream_t& stream, const int& smem_size, + const dim3& grid_dim, const dim3& block_dim, const int& cluster_dim, const bool& enable_pdl) { + if (smem_size > 0) + DG_CUDA_DRIVER_CHECK(lazy_cuFuncSetAttribute(kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem_size)); + + LaunchConfigHandle config; + config.gridDimX = grid_dim.x; + config.gridDimY = grid_dim.y; + config.gridDimZ = grid_dim.z; + config.blockDimX = block_dim.x; + config.blockDimY = block_dim.y; + config.blockDimZ = block_dim.z; + config.sharedMemBytes = smem_size; + config.hStream = stream; + + // Create attributes + // NOTES: must use `static` or the `attr` will be deconstructed + static LaunchAttrHandle attrs[2]; + config.numAttrs = 0; + config.attrs = attrs; + + // Cluster size + if (cluster_dim > 1) { + auto& attr = attrs[config.numAttrs ++]; + attr.id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION; + attr.value.clusterDim.x = static_cast(cluster_dim); + attr.value.clusterDim.y = 1; + attr.value.clusterDim.z = 1; + } + + // Dependent kernel launch + if (enable_pdl) { + auto& attr = attrs[config.numAttrs ++]; + attr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION; + attr.value.programmaticStreamSerializationAllowed = 1; + } + + return config; +} + +template +static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& config, ActTypes&&... args) { + void *ptr_args[] = { &args... }; + return lazy_cuLaunchKernelEx(&config, kernel, ptr_args, nullptr); +} +#endif + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit/include_parser.hpp b/third_party/DeepGEMM/csrc/jit/include_parser.hpp new file mode 100644 index 00000000..99f2663c --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit/include_parser.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include + +#include "../utils/format.hpp" +#include "../utils/system.hpp" + +namespace deep_gemm { + +class IncludeParser { + std::unordered_map> cache; + + static std::vector get_includes(const std::string& code, const std::filesystem::path& file_path = "") { + std::vector includes; + const std::regex pattern(R"(#\s*include\s*[<"][^>"]+[>"])"); + std::sregex_iterator iter(code.begin(), code.end(), pattern); + const std::sregex_iterator end; + + // TODO: parse relative paths as well + for (; iter != end; ++ iter) { + const auto include_str = iter->str(); + const int len = include_str.length(); + if (include_str.substr(0, 10) == "#include <" and include_str[len - 1] == '>' and include_str[10] != ' ' and include_str[len - 2] != ' ') { + std::string filename = include_str.substr(10, len - 11); + if (filename.substr(0, 9) == "deep_gemm") // We only parse `` + includes.push_back(filename); + } else { + std::string error_info = fmt::format("Non-standard include: {}", include_str); + if (file_path != "") + error_info += fmt::format(" ({})", file_path.string()); + DG_HOST_UNREACHABLE(error_info); + } + } + return includes; + } + +public: + static std::filesystem::path library_include_path; + + static void prepare_init(const std::string& library_root_path) { + library_include_path = std::filesystem::path(library_root_path) / "include"; + } + + std::string get_hash_value(const std::string& code, const bool& exclude_code = true) { + std::stringstream ss; + for (const auto& i: get_includes(code)) + ss << get_hash_value_by_path(library_include_path / i) << "$"; + if (not exclude_code) + ss << "#" << get_hex_digest(code); + return get_hex_digest(ss.str()); + } + + std::string get_hash_value_by_path(const std::filesystem::path& path) { + // Check whether hit in cache + // ReSharper disable once CppUseAssociativeContains + if (cache.count(path) > 0) { + const auto opt = cache[path]; + if (not opt.has_value()) + DG_HOST_UNREACHABLE(fmt::format("Circular include may occur: {}", path.string())); + return opt.value(); + } + + // Read file and calculate hash recursively + std::ifstream in(path); + if (not in.is_open()) + DG_HOST_UNREACHABLE(fmt::format("Failed to open: {}", path.string())); + std::string code((std::istreambuf_iterator(in)), std::istreambuf_iterator()); + cache[path] = std::nullopt; + return (cache[path] = get_hash_value(code, false)).value(); + } +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(IncludeParser, library_include_path); + +static auto include_parser = std::make_shared(); + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit/kernel_runtime.hpp b/third_party/DeepGEMM/csrc/jit/kernel_runtime.hpp new file mode 100644 index 00000000..40597fb4 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit/kernel_runtime.hpp @@ -0,0 +1,165 @@ +#pragma once + +#include + +#include "../utils/exception.hpp" +#include "../utils/format.hpp" +#include "../utils/system.hpp" +#include "device_runtime.hpp" +#include "handle.hpp" +#include "include_parser.hpp" + +namespace deep_gemm { + +struct LaunchArgs { + std::pair grid_dim; + int num_threads; + int smem_size; + int cluster_dim; + bool enable_pdl; + + LaunchArgs(const int& grid_dim_x, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true): + grid_dim({grid_dim_x, 1}), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {} + + LaunchArgs(const std::pair& grid_dim, const int& num_threads, const int& smem_size = 0, const int& cluster_dim = 1, const bool& enable_pdl = true): + grid_dim(grid_dim), num_threads(num_threads), smem_size(smem_size), cluster_dim(cluster_dim), enable_pdl(enable_pdl) {} +}; + +class KernelRuntime final { +public: + static std::filesystem::path cuda_home; + + LibraryHandle library; + KernelHandle kernel; + + explicit KernelRuntime(const std::filesystem::path& dir_path) { + // Check `prepare_init` + DG_HOST_ASSERT(not cuda_home.empty()); + + // NOLINT(*-pro-type-member-init) + const auto cuobjdump_path = cuda_home / "bin" / "cuobjdump"; + const auto cubin_path = dir_path / "kernel.cubin"; + if (get_env("DG_JIT_DEBUG")) + printf("Loading CUBIN: %s\n", cubin_path.c_str()); + + // Record start time + std::chrono::high_resolution_clock::time_point start_time; + if (get_env("DG_JIT_DEBUG") or get_env("DG_JIT_PRINT_LOAD_TIME")) + start_time = std::chrono::high_resolution_clock::now(); + +#ifdef DG_JIT_USE_LIBRARY_ENUM_KERNELS + // Load from the library + kernel = load_kernel(cubin_path, {}, &library); +#else + // Find the only symbol + // TODO: use kernel enumeration for newer drivers + const std::vector illegal_names = {"vprintf", "__instantiate_kernel", "__internal", "__assertfail"}; + const auto [exit_code, symbols] = call_external_command(fmt::format("{} -symbols {}", cuobjdump_path.c_str(), cubin_path.c_str())); + DG_HOST_ASSERT(exit_code == 0); + std::istringstream iss(symbols); + std::vector symbol_names; + for (std::string line; std::getline(iss, line); ) { + if (line.find("STT_FUNC") == 0 and line.find("STO_ENTRY") != std::string::npos and + std::none_of(illegal_names.begin(), illegal_names.end(), + [&](const auto name) { return line.find(name) != std::string::npos; })) { + const auto last_space = line.rfind(' '); + symbol_names.push_back(line.substr(last_space + 1)); + } + } + + // Print symbols + if (symbol_names.size() != 1 or get_env("DG_JIT_DEBUG")) { + printf("Symbols: "); + printf(" > CUBIN: %s\n", cubin_path.c_str()); + printf(" > Raw symbols: %s\n", symbols.c_str()); + printf(" > Parsed symbols:\n"); + for (const auto& symbol: symbol_names) + printf(" > %s, ", symbol.c_str()); + } + DG_HOST_ASSERT(symbol_names.size() == 1); + + // Load from the library + kernel = load_kernel(cubin_path, symbol_names[0], &library); +#endif + + // Print load time + if (get_env("DG_JIT_DEBUG") or get_env("DG_JIT_PRINT_LOAD_TIME")) { + std::chrono::duration load_time = std::chrono::high_resolution_clock::now() - start_time; + printf("Load time (%s): %.2lf ms\n", dir_path.c_str(), load_time.count()); + } + } + + static void prepare_init(const std::string& cuda_home_path_by_python) { + cuda_home = cuda_home_path_by_python; + } + + static bool check_validity(const std::filesystem::path& dir_path) { + if (not std::filesystem::exists(dir_path)) + return false; + + // NOTES: if the directory exists, `kernel.cu` and `kernel.cubin` must both exist, + // because the directory is created atomically via rename + if (not std::filesystem::exists(dir_path / "kernel.cu") or + not std::filesystem::exists(dir_path / "kernel.cubin")) { + printf("Corrupted JIT cache directory (missing kernel.cu or kernel.cubin): %s, " + "please run `rm -rf %s` and restart your task.\n", + dir_path.c_str(), dir_path.c_str()); + DG_HOST_ASSERT(false and "Corrupted JIT cache directory"); + } + return true; + } + + ~KernelRuntime() noexcept(false) { + unload_library(library); + } +}; + +DG_DECLARE_STATIC_VAR_IN_CLASS(KernelRuntime, cuda_home); + +template +class LaunchRuntime { +public: + template + static std::string generate(const Args& args) { + auto code = Derived::generate_impl(args); + + // NOTES: we require that `generate_impl`'s includes never change + static std::string include_hash; + if (include_hash.empty()) + include_hash = include_parser->get_hash_value(code); + + // TODO: optimize string concat performance + code = fmt::format("// Includes' hash value: {}\n{}", include_hash, code); + if (get_env("DG_JIT_DEBUG")) + printf("Generated kernel code:\n%s\n", code.c_str()); + return code; + } + + template + static void launch(const std::shared_ptr& kernel_runtime, const Args& args) { + const auto kernel = kernel_runtime->kernel; + const auto stream = at::cuda::getCurrentCUDAStream(); + LaunchArgs launch_args = args.launch_args; + + // Allow runtime override from Python. + // NOTES: the default is enabled. + launch_args.enable_pdl = device_runtime->get_pdl(); + + const dim3 grid_dim = {static_cast(launch_args.grid_dim.first), + static_cast(launch_args.grid_dim.second), + 1}; + const dim3 block_dim = {static_cast(launch_args.num_threads), 1, 1}; + auto config = construct_launch_config(kernel, stream, launch_args.smem_size, + grid_dim, block_dim, launch_args.cluster_dim, launch_args.enable_pdl); + + // Launch in the derived class + if (get_env("DG_JIT_DEBUG")) { + printf("Launch kernel with {%d, %d} x %d, shared memory: %d bytes, cluster: %d, pdl: %d, stream: %ld\n", + launch_args.grid_dim.first, launch_args.grid_dim.second, launch_args.num_threads, + launch_args.smem_size, launch_args.cluster_dim, launch_args.enable_pdl, stream.id()); + } + Derived::launch_impl(kernel, config, args); + } +}; + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/heuristics/common.hpp b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/common.hpp new file mode 100644 index 00000000..2b79a8b7 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/common.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include +#include + +#include "config.hpp" +#include "runtime.hpp" +#include "../../utils/layout.hpp" +#include "../../utils/system.hpp" + +namespace deep_gemm { + +template +static GemmConfig get_best_config(const GemmDesc& desc) { + desc.check_validity(); + + // Choose the best layout + const auto layout_candidates = ArchSpec::get_layout_candidates(desc); + DG_HOST_ASSERT(not layout_candidates.empty()); + auto layout = layout_candidates[0]; + auto layout_info = ArchSpec::get_layout_info(desc, layout); + for (int i = 1; i < static_cast(layout_candidates.size()); ++ i) { + const auto candidate_info = ArchSpec::get_layout_info(desc, layout_candidates[i]); + if (ArchSpec::compare(candidate_info, layout_info)) + layout = layout_candidates[i], layout_info = candidate_info; + } + + // Infer other configs + const auto storage_config = ArchSpec::get_storage_config(desc, layout); + const auto pipeline_config = ArchSpec::get_pipeline_config(desc, layout, storage_config); + const auto launch_config = ArchSpec::get_launch_config(desc, layout); + const auto gemm_config = GemmConfig { + .layout = layout, + .storage_config = storage_config, + .pipeline_config = pipeline_config, + .launch_config = launch_config + }; + + // Print configs for the first time + if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { + std::stringstream ss; + ss << desc; + const auto key = ss.str(); + + static std::unordered_set printed; + if (printed.count(key) == 0) { + std::cout << desc << ": " << gemm_config << ", " << layout_info << std::endl; + printed.insert(key); + } + } + return gemm_config; +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/heuristics/config.hpp b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/config.hpp new file mode 100644 index 00000000..c06f2f16 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/config.hpp @@ -0,0 +1,171 @@ +#pragma once + +#include +#include +#include + +#include "../../utils/math.hpp" + +namespace deep_gemm { + +/// GEMM descriptors +struct GemmDesc { + GemmType gemm_type; + KernelType kernel_type; + int m, n, k, num_groups; + at::ScalarType a_dtype, b_dtype, cd_dtype; + cute::UMMA::Major major_a; + cute::UMMA::Major major_b; + bool with_accumulation; + + // Requirements from users + int num_sms, tc_util; + std::string compiled_dims; + + // Shape for heuristic generation + int expected_m = 0, expected_n = 0, expected_k = 0, expected_num_groups = 0; + int get_expected_m() const { return expected_m > 0 ? expected_m : m; } + int get_expected_n() const { return expected_n > 0 ? expected_n : n; } + int get_expected_k() const { return expected_k > 0 ? expected_k : k; } + int get_expected_num_groups() const { return expected_num_groups > 0 ? expected_num_groups : num_groups; } + + MmaKind get_mma_kind() const { + return a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4; + } + + void check_validity() const { + if (get_mma_kind() == MmaKind::BF16) { + DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16); + } else { + DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4); + DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4); + } + DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); + DG_HOST_ASSERT(num_sms % 2 == 0); + } + + friend std::ostream& operator << (std::ostream& os, const GemmDesc& desc) { + MmaKind mma_kind = desc.get_mma_kind(); + os << "GemmDesc(gemm_type=" << static_cast(desc.gemm_type) + << ", kernel_type=" << static_cast(desc.kernel_type) + << ", m=" << desc.m << ", n=" << desc.n << ", k=" << desc.k + << ", num_groups=" << desc.num_groups + << ", major_a=" << static_cast(desc.major_a) + << ", major_b=" << static_cast(desc.major_b) + << ", mma_kind=" << static_cast(mma_kind) + << ", a_dtype=" << c10::toString(desc.a_dtype) + << ", b_dtype=" << c10::toString(desc.b_dtype) + << ", cd_dtype=" << c10::toString(desc.cd_dtype) + << ", with_accumulation=" << static_cast(desc.with_accumulation) + << ", num_sms=" << desc.num_sms + << ", tc_util=" << desc.tc_util + << ", compiled_dims=" << desc.compiled_dims + << ", expected_m=" << desc.expected_m + << ", expected_n=" << desc.expected_n + << ", expected_k=" << desc.expected_k + << ", expected_num_groups=" << desc.expected_num_groups << ")"; + return os; + } +}; + +/// GEMM configs +struct Layout { + int swap_ab; + int block_m, block_n, block_k; + int cluster_m, cluster_n; + + int get_cluster_size() const { + return cluster_m * cluster_n; + } + + friend std::ostream& operator << (std::ostream& os, const Layout& layout) { + os << "Layout(swap_ab=" << layout.swap_ab + << ", block_m=" << layout.block_m << ", block_n=" << layout.block_n << ", block_k=" << layout.block_k + << ", cluster_m=" << layout.cluster_m << ", cluster_n=" << layout.cluster_n << ")"; + return os; + } +}; + +struct StorageConfig { + int load_block_m, load_block_n; + int store_block_m, store_block_n; + + int swizzle_a_mode, swizzle_b_mode; + int swizzle_cd_mode; + + friend std::ostream& operator << (std::ostream& os, const StorageConfig& config) { + os << "StorageConfig(" + << "load_block_m=" << config.load_block_m << ", load_block_n=" << config.load_block_n + << ", store_block_m=" << config.store_block_m << ", store_block_n=" << config.store_block_n + << ", swizzle_a_mode=" << config.swizzle_a_mode << ", swizzle_b_mode=" << config.swizzle_b_mode + << ", swizzle_cd_mode=" << config.swizzle_cd_mode << ")"; + return os; + } +}; + +struct PipelineConfig { + int smem_size; + int num_stages; + + friend std::ostream& operator << (std::ostream& os, const PipelineConfig& config) { + os << "PipelineConfig(" + << "smem_size=" << config.smem_size + << ", num_stages=" << config.num_stages << ")"; + return os; + } +}; + +struct LaunchConfig { + int num_sms; + int num_sms_per_cluster; + int num_threads; + + int num_tma_threads; + int num_math_threads; + int num_non_epilogue_threads; + int num_epilogue_threads; + + friend std::ostream& operator << (std::ostream& os, const LaunchConfig& config) { + os << "LaunchConfig(" + << "num_sms=" << config.num_sms << ", num_sms_per_cluster=" << config.num_sms_per_cluster + << ", num_threads=" << config.num_threads + << ", num_tma_threads=" << config.num_tma_threads << ", num_math_threads=" << config.num_math_threads + << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads + << ", num_epilogue_threads=" << config.num_epilogue_threads << ")"; + return os; + } +}; + +struct GemmConfig { + Layout layout; + StorageConfig storage_config; + PipelineConfig pipeline_config; + LaunchConfig launch_config; + + friend std::ostream& operator << (std::ostream& os, const GemmConfig& config) { + os << "GemmConfig(" + << "layout=" << config.layout + << ", storage_config=" << config.storage_config + << ", pipeline_config=" << config.pipeline_config + << ", launch_config=" << config.launch_config << ")"; + return os; + } +}; + +/// Config comparators +struct LayoutInfo { + int num_waves; + int last_wave_util; + int64_t num_cycles; + Layout layout; + + friend std::ostream& operator << (std::ostream& os, const LayoutInfo& config) { + os << "LayoutInfo(" + << "num_waves=" << config.num_waves + << ", last_wave_util=" << config.last_wave_util + << ", num_cycles=" << config.num_cycles << ")"; + return os; + } +}; + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/heuristics/mega_moe.hpp b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/mega_moe.hpp new file mode 100644 index 00000000..b1ba6bd7 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/mega_moe.hpp @@ -0,0 +1,240 @@ +#pragma once + +#include +#include + +#include + +#include "../../utils/exception.hpp" +#include "../../utils/math.hpp" +#include "../../utils/system.hpp" +#include "sm100.hpp" + +namespace deep_gemm { + +struct MegaMoEConfig { + // Block tiling + int block_m, block_n, block_k; + int load_block_m, load_block_n; + int store_block_m; + + // SF block sizes (UTCCP 128-aligned) + int sf_block_m, sf_block_n; + + // Pool capacity and SF-padded token count + int num_max_pool_tokens; + int num_padded_sf_pool_tokens; + + // Swizzle modes for TMA descriptors + int swizzle_acts_mode, swizzle_weights_mode; + + // Number of experts to process per wave + int num_experts_per_wave; + + // Pipeline stages and shared memory + int num_stages, smem_size; + + // Thread layout + int num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads; + + friend std::ostream& operator << (std::ostream& os, const MegaMoEConfig& config) { + os << "MegaMoEConfig(" + << "block_m=" << config.block_m << ", block_n=" << config.block_n << ", block_k=" << config.block_k + << ", load_block_m=" << config.load_block_m << ", load_block_n=" << config.load_block_n + << ", store_block_m=" << config.store_block_m + << ", sf_block_m=" << config.sf_block_m << ", sf_block_n=" << config.sf_block_n + << ", num_max_pool_tokens=" << config.num_max_pool_tokens + << ", num_padded_sf_pool_tokens=" << config.num_padded_sf_pool_tokens + << ", swizzle_acts_mode=" << config.swizzle_acts_mode << ", swizzle_weights_mode=" << config.swizzle_weights_mode + << ", num_experts_per_wave=" << config.num_experts_per_wave + << ", num_stages=" << config.num_stages << ", smem_size=" << config.smem_size + << ", num_dispatch_threads=" << config.num_dispatch_threads + << ", num_non_epilogue_threads=" << config.num_non_epilogue_threads + << ", num_epilogue_threads=" << config.num_epilogue_threads << ")"; + return os; + } +}; + +static std::tuple get_block_config_for_mega_moe( + const int& num_ranks, const int& num_experts, + const int& num_max_tokens_per_rank, const int& num_topk, + const int& num_tokens) { + const auto& [cluster_size, block_m, store_block_m, num_epilogue_warpgroups] = [&]() -> std::tuple { + float num_expected_tokens_per_expert = static_cast(num_tokens) * num_ranks * num_topk / num_experts; + if (num_expected_tokens_per_expert <= 8.5) { + // Really small token-per-expert (e.g. RL long-tail rollout), use the smallest block_m + return {2, 16, 8, 2}; + } else if (num_expected_tokens_per_expert <= 16.5) { + // Small batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 128 + return {2, 32, 16, 2}; + } else if (num_expected_tokens_per_expert <= 32.5) { + // Medium batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 256 + return {2, 64, 32, 1}; + } else if (num_expected_tokens_per_expert <= 64.5) { + // Large batch size, small EP, decoding, e.g. 6/384 experts, EP8, bsz 512 + return {2, 96, 16, 2}; + } else if (num_expected_tokens_per_expert <= 96.5) { + // Medium batch size, Medium EP, decoding, e.g. 6/384 experts, EP16, bsz 256, or EP32, bsz128 + return {2, 128, 32, 2}; + } else { + // Prefill, or large EP decoding + return {2, 192, 32, 2}; + } + }(); + + // Check whether our `block_m` lies in `kCandidateBlockM` + DG_HOST_ASSERT(std::any_of( + layout::kCandidateBlockM, layout::kCandidateBlockM + layout::kNumCandidateBlockMs, + [=](const auto& candidate) { return candidate == block_m; }) + ); + + // Return configs + return {cluster_size, block_m, store_block_m, num_epilogue_warpgroups * 128}; +} + +static int get_num_experts_per_wave_for_mega_moe( + const int& num_experts_per_rank, const int& num_tokens, const int& num_topk, + const int& intermediate_hidden, const int& block_m, const int& block_n, const int& num_sms) { + + float expected_tokens_per_expert = static_cast(num_tokens) * num_topk / num_experts_per_rank; + if (expected_tokens_per_expert < 1) { + // Most experts don't have tokens, calculate all experts at once + return num_experts_per_rank; + } + + // Reduce per-expert block count by this factor since uneven routing leaves some experts with fewer tokens + constexpr int kImbalanceFactor = 2; + + // Count L1 blocks per expert assuming tokens are evenly spread across experts + const int num_m_blocks = ceil_div(static_cast(std::ceil(expected_tokens_per_expert)), block_m); + const int num_n_blocks = (2 * intermediate_hidden) / block_n; + const int num_l1_blocks_per_expert = num_m_blocks * num_n_blocks; + + // Pick the smallest value whose total blocks (after imbalance reduction) can keep all SMs busy + int num_experts_per_wave = num_l1_blocks_per_expert > 0 + ? ceil_div(kImbalanceFactor * num_sms, num_l1_blocks_per_expert) : 1; + num_experts_per_wave = std::min(num_experts_per_wave, num_experts_per_rank); + + // Round up to the nearest divisor of num_experts_per_rank so every wave processes the same count + while (num_experts_per_wave < num_experts_per_rank and num_experts_per_rank % num_experts_per_wave != 0) + ++ num_experts_per_wave; + + return num_experts_per_wave; +} + +static std::pair get_pipeline_config_for_mega_moe( + const int& smem_capacity, + const int& num_experts, const int& hidden, + const int& block_m, const int& block_n, const int& block_k, const int& store_block_m, + const int& sf_block_m, const int& sf_block_n, + const int& num_dispatch_warps, const int& num_epilogue_warps) { + constexpr int kSmemAlignment = 1024; + constexpr int kNumEpilogueStages = 2; + constexpr int kNumTMAStoreStages = 2; + + // Always multicast on A + const int load_block_m = block_m / 2; + + // Dispatch region + const int smem_expert_count_size = align( + num_experts * static_cast(sizeof(uint32_t)), kSmemAlignment); + const int smem_send_buffers_size = align( + static_cast(layout::Buffer(layout::Data(hidden), num_dispatch_warps, 1).get_num_bytes()), + kSmemAlignment); + const int smem_dispatch_size = smem_expert_count_size + smem_send_buffers_size; + + // C/D output region: max of L1 FP8 (2 TMA stages, BLOCK_N/2 post-SwiGLU) and L2 BF16 (1 stage) + const auto num_epilogue_warpgroups = num_epilogue_warps / 4; + const int smem_cd_l1 = num_epilogue_warpgroups * store_block_m * (block_n / 2) * kNumTMAStoreStages; + const int smem_cd_l2 = num_epilogue_warpgroups * store_block_m * block_n * static_cast(sizeof(nv_bfloat16)); + const int smem_cd = std::max(smem_cd_l1, smem_cd_l2); + + // Barriers (stage-independent): dispatch + tensor memory full/empty + combine (2 per epilogue warp) + const int smem_barriers = (num_dispatch_warps + kNumEpilogueStages * 2 + num_epilogue_warps * 2) * 8; + + // Amax reduction + const int smem_amax_reduction = store_block_m * num_epilogue_warps * static_cast(sizeof(float)); + + // Tensor memory pointer + const int smem_tmem_ptr = 4; + + // SF is aligned to UTCCP 128-element granularity + const int smem_sfa_per_stage = sf_block_m * 4; + const int smem_sfb_per_stage = sf_block_n * 4; + + // Per-stage: A tile + B tile + SFA tile + SFB tile + full/empty barriers + const int smem_per_stage = load_block_m * block_k + block_n * block_k + smem_sfa_per_stage + smem_sfb_per_stage + 2 * 8; + + // Fixed total + const int smem_fixed = smem_dispatch_size + smem_cd + smem_amax_reduction + smem_barriers + smem_tmem_ptr; + + // Select maximum num_stages + const int num_stages = (smem_capacity - smem_fixed) / smem_per_stage; + DG_HOST_ASSERT(num_stages >= 2); + + return {num_stages, smem_fixed + num_stages * smem_per_stage}; +} + +static MegaMoEConfig get_mega_moe_config( + const int& num_ranks, const int& num_experts, const int& num_experts_per_rank, + const int& num_max_tokens_per_rank, const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const int& num_padded_sf_pool_tokens) { + // Block config + const auto [cluster_size, block_m, store_block_m, num_epilogue_threads] = + get_block_config_for_mega_moe(num_ranks, num_experts, num_max_tokens_per_rank, num_topk, num_tokens); + const int block_n = 128; + const int block_k = 128; + const int load_block_m = block_m / 2; + const int load_block_n = block_n; + const auto [sf_block_m, sf_block_n] = SM100ArchSpec::get_sf_uttcp_aligned_block_sizes(block_m, block_n, MmaKind::MXFP8FP4); + const int num_max_pool_tokens = layout::get_num_max_pool_tokens( + num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); + // NOTES: FP8 activations and FP4 weights (unpacked to 8-bit in smem) both use 128B swizzle + const int swizzle_acts_mode = 128; + const int swizzle_weights_mode = 128; + + // Waves + const int num_sms = device_runtime->get_num_sms(); + const int num_experts_per_wave = get_num_experts_per_wave_for_mega_moe( + num_experts_per_rank, num_tokens, num_topk, + intermediate_hidden, block_m, block_n, num_sms); + + // Thread layout + const int num_dispatch_threads = 128; + const int num_non_epilogue_threads = 128; + + // Pipeline + const auto [num_stages, smem_size] = get_pipeline_config_for_mega_moe( + SM100ArchSpec::smem_capacity, + num_experts, hidden, + block_m, block_n, block_k, store_block_m, + sf_block_m, sf_block_n, + num_dispatch_threads / 32, num_epilogue_threads / 32); + + const auto config = MegaMoEConfig { + block_m, block_n, block_k, + load_block_m, load_block_n, store_block_m, + sf_block_m, sf_block_n, + num_max_pool_tokens, num_padded_sf_pool_tokens, + swizzle_acts_mode, swizzle_weights_mode, + num_experts_per_wave, + num_stages, smem_size, + num_dispatch_threads, num_non_epilogue_threads, num_epilogue_threads + }; + + // Print configs for the first time + if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { + const auto key = fmt::format( + "MegaMoEConfig(num_ranks={}, num_experts={}, hidden={}, intermediate_hidden={}, num_max_tokens_per_rank={}, num_tokens={}, num_topk={})", + num_ranks, num_experts, hidden, intermediate_hidden, num_max_tokens_per_rank, num_tokens, num_topk); + static std::unordered_set printed; + if (printed.count(key) == 0) { + std::cout << key << ": " << config << std::endl; + printed.insert(key); + } + } + return config; +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/heuristics/runtime.hpp b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/runtime.hpp new file mode 100644 index 00000000..93f2a23a --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/runtime.hpp @@ -0,0 +1,62 @@ +#pragma once + +#include "../../jit/device_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/lazy_init.hpp" + +namespace deep_gemm { + +class HeuristicsRuntime { + static constexpr int kLegacyMKAlignmentForContiguousLayout = 128; + + bool ignore_compile_dims = false; + int block_m_multiple_of = 1; + int block_n_multiple_of = 1; + int mk_alignment_for_contiguous_layout = kLegacyMKAlignmentForContiguousLayout; + +public: + void set_ignore_compile_dims(const bool& new_value) { + ignore_compile_dims = new_value; + } + + bool get_ignore_compile_dims() const { + return ignore_compile_dims; + } + + void set_block_size_multiple_of(const int& new_block_m_multiple_of, const int& new_block_n_multiple_of) { + block_m_multiple_of = new_block_m_multiple_of; + block_n_multiple_of = new_block_n_multiple_of; + } + + int get_block_m_multiple_of() const { + return block_m_multiple_of; + } + + int get_block_n_multiple_of() const { + return block_n_multiple_of; + } + + void set_mk_alignment_for_contiguous_layout(const int& new_value) { + mk_alignment_for_contiguous_layout = new_value; + } + + int get_mk_alignment_for_contiguous_layout() const { + return mk_alignment_for_contiguous_layout; + } + + static int get_theoretical_mk_alignment_for_contiguous_layout(const std::optional& expected_m) { + if (device_runtime->get_arch_major() != 10) + return kLegacyMKAlignmentForContiguousLayout; + + int block_m = 240, mma_step = 16; + if (expected_m.has_value()) { + // Reduce `block_m` while ensuring it covers `m` + for (; block_m > 32 and block_m - mma_step >= expected_m.value(); block_m -= mma_step); + } + return block_m; + } +}; + +static auto heuristics_runtime = LazyInit([](){ return std::make_shared(); }); + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/heuristics/sm100.hpp b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/sm100.hpp new file mode 100644 index 00000000..c8e9e2e0 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/sm100.hpp @@ -0,0 +1,269 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" +#include "runtime.hpp" +#include "utils.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +struct SM100ArchSpec { + static constexpr int smem_capacity = 232448; + + static std::pair get_sf_uttcp_aligned_block_sizes( + const int& block_m, const int& block_n, const MmaKind& mma_kind) { + constexpr int num_utccp_aligned_elems = 128; + switch (mma_kind) { + case MmaKind::BF16: return {0, 0}; + case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; + default: DG_HOST_UNREACHABLE("Unknown dtype"); + } + } + + static std::vector get_layout_candidates(const GemmDesc& desc) { + // Block K is always in a fixed manner + const int block_k = 128 / get_element_size(desc.get_mma_kind()); + + // Always enable swap A/B (and multicasting if possible) for m-grouped GEMMs + if (desc.gemm_type == GemmType::MGroupedContiguous or + desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout or + desc.gemm_type == GemmType::MGroupedMasked) { + const bool swap_ab = true; + const auto block_n = 128; + const auto block_m = heuristics_runtime->get_mk_alignment_for_contiguous_layout(); + const auto cluster_m = 1; + const auto cluster_n = ceil_div(desc.n, block_n) % 2 == 0 and desc.num_sms % 2 == 0 ? 2 : 1; + const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n}; + std::vector candidates = {layout}; + return candidates; + } + + // Enumerate all candidates + std::vector candidates; + for (int swap_ab = 0; swap_ab < 2; ++ swap_ab) { + // Block M/N candidates + std::vector block_m_candidates; + std::vector block_n_candidates; + if (swap_ab) { + int step = std::lcm(16, heuristics_runtime->get_block_m_multiple_of()); + int end = 256; + for (int i = step; i <= end; i += step) + block_m_candidates.push_back(i); + + // TODO: consider other block N + block_n_candidates = {128}; + } else { + // NOTES: smaller block M can avoid TMA L2 OOB bound + // TODO: consider block M = 256 + if (desc.m <= 32) block_m_candidates = {32}; + else if (desc.m <= 64) block_m_candidates = {64}; + else block_m_candidates = {128}; + + // Small block size for small shape + if (16 % heuristics_runtime->get_block_n_multiple_of() == 0) + block_n_candidates.push_back(16); + int step = std::lcm(32, heuristics_runtime->get_block_n_multiple_of()); + // For small K, fewer store blocks improve store/compute overlap and reduce epilogue bottleneck + int end = desc.k <= 256 ? 128 : 256; + for (int i = step; i <= end; i += step) + block_n_candidates.push_back(i); + } + + for (int cluster_m = 1; cluster_m <= 2; ++ cluster_m) { + // After swapping, layout A/D can only do on cluster N + if (swap_ab == 1 and cluster_m > 1) + continue; + + for (int cluster_n = 1; cluster_n <= 2; ++ cluster_n) { + // We only support cluster 2 + if (cluster_m * cluster_n > 2) + continue; + + // Only support layout A/D + if (swap_ab == 0 and cluster_n > 1) + continue; + + // SM count must be divisible + if (desc.num_sms % (cluster_m * cluster_n) != 0) + continue; + + for (int block_m: block_m_candidates) { + // Ensure large swizzle sizes (32B swizzle yields poor performance) + const auto swizzle_a_requirement = desc.a_dtype == kPackedFP4 ? 128 : 64; + // Enforce swizzle alignment for MN major; otherwise check base MMA shape + const auto load_block_m_requirement = desc.major_a == cute::UMMA::Major::MN ? swizzle_a_requirement : 8; + if ((block_m / cluster_n) % load_block_m_requirement != 0) + continue; + + // Shape must be divisible for multicast + if (ceil_div(desc.m, block_m) % cluster_m != 0) + continue; + + for (int block_n: block_n_candidates) { + // Ensure large swizzle sizes (32B swizzle yields poor performance) + const auto swizzle_b_requirement = desc.b_dtype == kPackedFP4 ? 128 : 64; + // Enforce swizzle alignment for MN major; otherwise check base MMA shape + const auto load_block_n_requirement = desc.major_b == cute::UMMA::Major::MN ? swizzle_b_requirement : 8; + if ((block_n / cluster_m) % load_block_n_requirement != 0) + continue; + + // Shape must be divisible for multicast + if (ceil_div(desc.n, block_n) % cluster_n != 0) + continue; + + // SwapAB requires block N is layout A/D' UMMA M + constexpr int layout_ad_m = 128; + if (swap_ab and block_n != layout_ad_m) + continue; + + // Check tensor memory capacity + const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, desc.get_mma_kind()); + const auto tmem_sf_cols = desc.get_mma_kind() == MmaKind::MXFP8FP4 ? sf_block_m / 32 + sf_block_n / 32 : 0; + const auto umma_n = swap_ab ? block_m : block_n; + if (2 * umma_n + tmem_sf_cols > 512) + continue; + + const auto layout = Layout{swap_ab, block_m, block_n, block_k, cluster_m, cluster_n}; + + // When neither A nor B is MN major, 128B swizzle is always feasible + if (desc.major_a == cute::UMMA::Major::K or desc.major_b == cute::UMMA::Major::K) { + const auto storage_config = get_storage_config(desc, layout); + if (storage_config.swizzle_a_mode != 128 or storage_config.swizzle_b_mode != 128) + continue; + } + + candidates.push_back(layout); + } + } + } + } + } + + DG_HOST_ASSERT(not candidates.empty()); + return candidates; + } + + static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) { + constexpr int layout_ad_m = 128; + constexpr int umma_step_n = 16; + + // Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms) + const auto load_block_m = layout.block_m / layout.cluster_n; + const auto load_block_n = layout.block_n / layout.cluster_m; + const auto store_block_m = layout.swap_ab ? umma_step_n : std::min(layout_ad_m, layout.block_m); + const auto store_block_n = layout.block_n; + + // Decide swizzling by the inner dim + // TODO: support FP4 sub-byte + const auto swizzle_mode_a = get_swizzle_mode( + desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m, c10::elementSize(desc.a_dtype)); + const auto swizzle_mode_b = get_swizzle_mode( + desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n, c10::elementSize(desc.b_dtype)); + const auto swizzle_mode_cd = get_swizzle_mode( + store_block_n, c10::elementSize(desc.cd_dtype)); + + return { + load_block_m, load_block_n, + store_block_m, store_block_n, + swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd + }; + } + + static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) { + constexpr int kNumMaxStages = 32; + + // C/D for TMA stores + const int smem_cd = layout.swap_ab ? storage_config.store_block_m * storage_config.store_block_n * c10::elementSize(desc.cd_dtype) * 2 + : storage_config.store_block_m * storage_config.swizzle_cd_mode * 2; + + // TODO: remove SF barriers for BF16 GEMMs + // TMA full/empty barriers, with-SF full barriers, tensor memory full/empty barriers + // NOTES: some shapes may only have 1 epilogue stage, but we still allocate space for 2 stages + // NOTES: the last barrier is for tensor core utilization control + const int smem_barriers = kNumMaxStages * 8 * 3 + 2 * 8 * 2 + 8; + + // Tensor memory pointer + const int smem_tmem_ptr = 4; + + // Calculate A/B per stages + // TODO: consider FP4 + const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype); + const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype); + + // Calculate SF A/B per stages + int smem_sfa_per_stage = 0; + int smem_sfb_per_stage = 0; + if (desc.kernel_type == KernelType::Kernel1D1D) { + const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes( + layout.block_m, layout.block_n, desc.get_mma_kind()); + smem_sfa_per_stage = sf_block_m * 4; + smem_sfb_per_stage = sf_block_n * 4; + } + + // Calculate stages + int smem_extra = smem_cd + smem_barriers + smem_tmem_ptr; + int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage; + int num_stages = std::min( + (smem_capacity - smem_extra) / smem_per_stage, + kNumMaxStages); + return { + smem_extra + num_stages * smem_per_stage, + num_stages + }; + } + + static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) { + return { + desc.num_sms, + layout.get_cluster_size(), + 256, + 32, 128, 128, 128 + }; + } + + static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) { + const auto num_blocks = + ceil_div(desc.get_expected_m(), layout.block_m) * + ceil_div(desc.get_expected_n(), layout.block_n) * + desc.get_expected_num_groups(); + const auto num_waves = ceil_div(num_blocks, desc.num_sms); + const auto num_last_blocks = num_blocks % desc.num_sms; + const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks; + // TODO: calculate expected cycles + return {num_waves, last_wave_util, 0, layout}; + } + + // A regular comparator + static bool compare(const LayoutInfo& a, const LayoutInfo& b) { + // Single wave is always better + if ((a.num_waves == 1 or b.num_waves == 1) and a.num_waves != b.num_waves) + return a.num_waves < b.num_waves; + + // Doing multicast is better + if (a.layout.get_cluster_size() != b.layout.get_cluster_size()) + return a.layout.get_cluster_size() > b.layout.get_cluster_size(); + + // Smaller number of waves is better + if (a.num_waves != b.num_waves) + return a.num_waves < b.num_waves; + + // Larger last wave utilization is better + if (a.last_wave_util != b.last_wave_util) + return a.last_wave_util > b.last_wave_util; + + // More stages is better + // Same block M, smaller block N is better + // Same block N, smaller block M is better + if (a.layout.block_m + a.layout.block_n != b.layout.block_m + b.layout.block_n) + return a.layout.block_m + a.layout.block_n < b.layout.block_m + b.layout.block_n; + + // Less shared memory C/D, more stages is better + return a.layout.block_m * a.layout.block_n < b.layout.block_m * b.layout.block_n; + } +}; + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/heuristics/sm90.hpp b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/sm90.hpp new file mode 100644 index 00000000..c411fb7e --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/sm90.hpp @@ -0,0 +1,246 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" +#include "utils.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +struct SM90ArchSpec { + static constexpr int smem_capacity = 232448; + + static std::vector get_layout_candidates(const GemmDesc& desc) { + // Block M candidates + std::vector block_m_candidates; + if (desc.gemm_type == GemmType::Normal or + desc.gemm_type == GemmType::Batched or + desc.gemm_type == GemmType::KGroupedContiguous) { + // TODO: check 256's performance + block_m_candidates = {64, 128}; + // NOTES: smaller block M can avoid TMA L2 OOB bound + if (desc.m <= 16) block_m_candidates.push_back(16); + if (desc.m <= 32) block_m_candidates.push_back(32); + + // BF16 output GEMM supports 256 + if (desc.cd_dtype != torch::kFloat) + block_m_candidates.push_back(256); + } else if (desc.gemm_type == GemmType::MGroupedContiguous or + desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) { + block_m_candidates = std::vector{heuristics_runtime->get_mk_alignment_for_contiguous_layout()}; + } else if (desc.gemm_type == GemmType::MGroupedMasked) { + block_m_candidates = {64, 128}; + } + + // Block N candidates + std::vector block_n_candidates; + int step = std::lcm(16, heuristics_runtime->get_block_n_multiple_of()); + int start = step; + // Avoid bank conflicts for 1D1D kernel FP32 output + if (desc.kernel_type == KernelType::Kernel1D1D and desc.cd_dtype == torch::kFloat) { + DG_HOST_ASSERT(desc.major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(desc.major_b == cute::UMMA::Major::K); + start = 24; + block_n_candidates.push_back(16); + } + // Register spills + int end = 256; + if (desc.kernel_type == KernelType::Kernel1D2D) + end = 192; + if (desc.kernel_type == KernelType::Kernel1D1D) + end = 160; + // Enumerate + for (int i = start; i <= end; i += step) + block_n_candidates.push_back(i); + + // Block K is always in a fixed manner + const int block_k = 128 / get_element_size(desc.get_mma_kind()); + + // Disable multicast for performance + const bool disable_multicast = + // The number of k-groups is large (a heuristic) + (desc.gemm_type == GemmType::KGroupedContiguous and desc.num_groups > 4) or + // Not supported + (desc.gemm_type == GemmType::Batched); + + // Enumerate all candidates + std::vector candidates; + for (int cluster_m = 1; cluster_m <= (disable_multicast ? 1 : 2); ++ cluster_m) { + for (int cluster_n = 1; cluster_n <= (disable_multicast ? 1 : 2); ++ cluster_n) { + // We only support cluster 2 + if (cluster_m * cluster_n > 2) + continue; + + // SM count must be divisible + if (desc.num_sms % (cluster_m * cluster_n) != 0) + continue; + + for (int block_m: block_m_candidates) { + for (int block_n: block_n_candidates) { + // 1D2D kernel unroll requirement + if (desc.kernel_type == KernelType::Kernel1D2D and block_n > block_k and (block_n % (block_n - block_k) != 0 and block_k % (block_n - block_k) != 0)) + continue; + + // Multicast legality for masked layout + // TODO: add some comments about it + if ((desc.gemm_type == GemmType::MGroupedMasked or desc.gemm_type == GemmType::MGroupedContiguousWithPsumLayout) and + ceil_div(desc.n, block_n) % (cluster_m * cluster_n) != 0) + continue; + + // The block sizes cannot be too large (for enough registers), so at least one dim less than 128 + if (block_m > 128 and block_n > 128) + continue; + + // Calculate swizzling + const auto layout = Layout{0, block_m, block_n, block_k, cluster_m, cluster_n}; + const auto storage_config = get_storage_config(desc, layout); + + // Make sure swizzling is large enough (32B's performance is low) + if (storage_config.swizzle_a_mode % 64 != 0 or storage_config.swizzle_b_mode % 64 != 0) + continue; + + // To hide TMA latency, the stage count should be at least 3; for small matrices, at least 4 + int num_stages = get_pipeline_config(desc, layout, storage_config).num_stages; + if (num_stages < 3 or (block_m * block_n < 128 * 192 and num_stages < 4)) + continue; + + candidates.push_back(layout); + } + } + } + } + + DG_HOST_ASSERT(not candidates.empty()); + return candidates; + } + + static StorageConfig get_storage_config(const GemmDesc& desc, const Layout& layout) { + constexpr int wgmma_m = 64; + + // Load/store block sizes (w/o consideration of swizzling atoms, w/ consideration of loop atoms) + // TODO: support swap AB + DG_HOST_ASSERT(layout.swap_ab == 0); + const auto load_block_m = layout.block_m; + const auto load_block_n = layout.block_n; + // 1D1D kernel will do single warp-group stores + const auto store_block_m = desc.kernel_type == KernelType::Kernel1D1D ? wgmma_m : layout.block_m; + const auto store_block_n = layout.block_n; + + // Decide swizzling by the inner dim + const auto swizzle_mode_a = get_swizzle_mode( + desc.major_a == cute::UMMA::Major::K ? layout.block_k : load_block_m, c10::elementSize(desc.a_dtype)); + const auto swizzle_mode_b = get_swizzle_mode( + desc.major_b == cute::UMMA::Major::K ? layout.block_k : load_block_n, c10::elementSize(desc.b_dtype)); + // We only enable swizzling for non-FP32 outputs + const auto swizzle_mode_cd = desc.cd_dtype != torch::kFloat ? + get_swizzle_mode(store_block_n, c10::elementSize(desc.cd_dtype)) : 0; + + return { + load_block_m, load_block_n, + store_block_m, store_block_n, + swizzle_mode_a, swizzle_mode_b, swizzle_mode_cd + }; + } + + static PipelineConfig get_pipeline_config(const GemmDesc& desc, const Layout& layout, const StorageConfig& storage_config) { + constexpr int kNumMaxStages = 16; + + // TODO: consider swap AB + // C/D for TMA stores + // NOTES: 1024 is for TMA swizzling alignment requirement + const int smem_cd = + align(layout.block_m * layout.block_n * static_cast(c10::elementSize(desc.cd_dtype)), 1024); + const int smem_barriers = kNumMaxStages * 8 * 2; + + // Calculate A/B per stages + const int smem_a_per_stage = storage_config.load_block_m * layout.block_k * c10::elementSize(desc.a_dtype); + const int smem_b_per_stage = storage_config.load_block_n * layout.block_k * c10::elementSize(desc.b_dtype); + + // Calculate SF A/B per stages + const int smem_sfa_per_stage = desc.kernel_type == KernelType::KernelNoSF ? + 0 : align(layout.block_m * static_cast(sizeof(float)), 128); + const int smem_sfb_per_stage = desc.kernel_type != KernelType::Kernel1D1D ? + 0 : align(layout.block_n * static_cast(sizeof(float)), 128); + + // Extra SFB sizes for 1D2D kernels + const int use_uniform_sfb = layout.block_k % layout.block_n == 0 ? 1 : 2; + const int smem_extra_sfb = desc.kernel_type != KernelType::Kernel1D2D ? + 0 : align(ceil_div(desc.k, layout.block_k) * static_cast(sizeof(float)) * use_uniform_sfb, 8); + + // Extra tensormap for 1D1D kernels + const int smem_tensormap = + desc.gemm_type == GemmType::KGroupedContiguous ? 4 * static_cast(sizeof(CUtensorMap)) : 0; + + // Calculate stages + const int smem_extra = smem_cd + smem_barriers + smem_extra_sfb + smem_tensormap; + const int smem_per_stage = smem_a_per_stage + smem_b_per_stage + smem_sfa_per_stage + smem_sfb_per_stage; + const int num_stages = std::min( + (smem_capacity - smem_extra) / smem_per_stage, + kNumMaxStages); + return { + smem_extra + num_stages * smem_per_stage, + num_stages + }; + } + + static LaunchConfig get_launch_config(const GemmDesc& desc, const Layout& layout) { + const int num_tma_threads = 128; + const int num_math_threads = layout.block_m <= 64 ? 128 : 256; + return { + desc.num_sms, + layout.get_cluster_size(), + num_tma_threads + num_math_threads, + num_tma_threads, num_math_threads, + 0, 0 // Meaningless for SM90 + }; + } + + static LayoutInfo get_layout_info(const GemmDesc& desc, const Layout& layout) { + const auto num_blocks = + ceil_div(desc.get_expected_m(), layout.block_m) * + ceil_div(desc.get_expected_n(), layout.block_n) * + desc.get_expected_num_groups(); + const auto num_waves = ceil_div(num_blocks, desc.num_sms); + const auto num_last_blocks = num_blocks % desc.num_sms; + const auto last_wave_util = num_last_blocks == 0 ? desc.num_sms : num_last_blocks; + + // Utils + const int l2_bandwidth_per_cycle = std::min(64. * desc.num_sms, 8e6 / (1.3e3)); // B/cycle + const int l1_bandwidth_per_cycle = 128 * desc.num_sms; // B/cycle + const int wgmma_m = 64; + const int elem_size_ab = c10::elementSize(desc.a_dtype); + const int elem_size_cd = c10::elementSize(desc.cd_dtype); + DG_HOST_ASSERT(desc.a_dtype == desc.b_dtype); + + // Data movement per block + int64_t expected_k = desc.get_expected_k(); + int64_t num_bytes_l2_ab = expected_k * (layout.block_m / layout.cluster_n + layout.block_n / layout.cluster_m) * elem_size_ab; + int64_t num_bytes_l1_ab = expected_k * (layout.block_m + layout.block_n) * elem_size_ab; + int64_t num_bytes_l1_tc = expected_k * (std::max(wgmma_m, layout.block_m) + layout.block_n) * elem_size_ab + + layout.block_m * layout.block_n * elem_size_cd; + int64_t num_bytes_l1_l2_cd = layout.block_m * layout.block_n * elem_size_cd * (desc.with_accumulation ? 2 : 1); + + // HBM bandwidth and total compute (Tensor/CUDA cores) are constant across configs + // We only model L1/L2 cycles as they are the primary variables between configs + int64_t num_l2_cycles = (num_bytes_l2_ab + num_bytes_l1_l2_cd) * num_blocks / l2_bandwidth_per_cycle; + int64_t num_l1_cycles = (num_bytes_l1_ab + num_bytes_l1_tc + num_bytes_l1_l2_cd) * num_blocks / l1_bandwidth_per_cycle; + float wave_efficiency = static_cast(num_blocks) / (num_waves * desc.num_sms); + int64_t num_cycles = std::max(num_l1_cycles, num_l2_cycles) / wave_efficiency; + + // Disable multicasting if only one wave exists + if (layout.cluster_n * layout.cluster_m > 1 and num_waves <= 1) + num_cycles = std::numeric_limits::max(); + + return {num_waves, last_wave_util, num_cycles, layout}; + } + + // A regular comparator + static bool compare(const LayoutInfo& a, const LayoutInfo& b) { + return a.num_cycles < b.num_cycles; + } +}; + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/heuristics/utils.hpp b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/utils.hpp new file mode 100644 index 00000000..17d2ae07 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/heuristics/utils.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include +// Reuse some types in the JIT modules +#include + +#include "common.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +template +static int get_swizzle_mode(const int& block_size, const size_type_t& elem_size) { + // `> 0` means interleaving + // 16B actually means non-swizzling (but interleaving) + for (const int& mode: {128, 64, 32, 16}) { + if ((block_size * static_cast(elem_size)) % mode == 0) + return mode; + } + DG_HOST_UNREACHABLE("Unreachable"); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/epilogue.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/epilogue.hpp new file mode 100644 index 00000000..1003df4c --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/epilogue.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +static std::string get_default_epilogue_type(const std::optional& epilogue_type) { + return epilogue_type.value_or("epilogue::transform::EpilogueIdentity"); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/runtime_utils.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/runtime_utils.hpp new file mode 100644 index 00000000..72a76f0d --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/runtime_utils.hpp @@ -0,0 +1,267 @@ +#pragma once + +#include +#include + +#include "../heuristics/sm90.hpp" +#include "../../jit/handle.hpp" +#include "../../utils/math.hpp" +#include "../../utils/system.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +static std::pair get_inner_outer_dims(const cute::UMMA::Major& major, const int& k, const int& mn) { + return major == cute::UMMA::Major::K ? std::make_pair(k, mn) : std::make_pair(mn, k); +} + +static int get_non_contiguous_dim(const cute::UMMA::Major& major) { + return major == cute::UMMA::Major::K ? -2 : -1; +} + +static int get_compiled_dim(const int& dim, const char& name, const std::string& compiled_dims) { + if (heuristics_runtime->get_ignore_compile_dims()) + return 0; + + for (const char& c: compiled_dims) { + if (name == c) + return dim; + } + return 0; +} + +static std::string to_string(const cute::UMMA::Major& major) { + switch (major) { + case cute::UMMA::Major::K: return "cute::UMMA::Major::K"; + case cute::UMMA::Major::MN: return "cute::UMMA::Major::MN"; + } + DG_HOST_UNREACHABLE("Unknown major"); +} + +static std::string to_string(const GemmType& type) { + switch (type) { + case GemmType::Normal: return "GemmType::Normal"; + case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; + case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked"; + case GemmType::MGroupedContiguousWithPsumLayout: return "GemmType::MGroupedContiguousWithPsumLayout"; + case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; + case GemmType::Batched: return "GemmType::Batched"; + } + DG_HOST_UNREACHABLE("Unknown GEMM type"); +} + +static std::string to_string(const at::ScalarType& dtype) { + switch (dtype) { + case torch::kInt: return "int"; + case torch::kFloat: return "float"; + case torch::kBFloat16: return "cutlass::bfloat16_t"; + case torch::kFloat8_e4m3fn: return "cutlass::float_e4m3_t"; + case kPackedFP4: return "cutlass::detail::float_e2m1_unpacksmem_t"; + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static std::string to_string(const float& v) { + if (std::isfinite(v)) { + return fmt::format(R"({:a}f)", v); + } else if (std::isinf(v)) { + return v > 0 ? "cute::numeric_limits::infinity()" + : "-cute::numeric_limits::infinity()"; + } + DG_HOST_UNREACHABLE("NaN input is not supported"); +} + +static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& dtype, + const bool& allow_tf32, + const bool& fp4_unpacked_smem) { + if (allow_tf32 and dtype == torch::kFloat) + return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; + + switch (dtype) { + case torch::kInt: return CU_TENSOR_MAP_DATA_TYPE_INT32; + case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; +#if CUDA_VERSION >= 12080 + case kPackedFP4: return fp4_unpacked_smem ? CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B + : CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; +#endif + default: DG_HOST_UNREACHABLE("Unsupported dtype"); + } +} + +static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) { +#if CUDA_VERSION >= 12080 + if (base != 0) { + DG_HOST_ASSERT(base == 32 and mode == 128); + return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; + } +#endif + + DG_HOST_ASSERT(base == 0); + switch (mode) { + case 0: + case 16: return CU_TENSOR_MAP_SWIZZLE_NONE; + case 32: return CU_TENSOR_MAP_SWIZZLE_32B; + case 64: return CU_TENSOR_MAP_SWIZZLE_64B; + case 128: return CU_TENSOR_MAP_SWIZZLE_128B; + default: DG_HOST_UNREACHABLE("Unsupported swizzling mode"); + } +} + +static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, + int gmem_inner_dim, int gmem_outer_dim, + int smem_inner_dim, int smem_outer_dim, + const int& gmem_outer_stride, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false, + const bool& fp4_unpacked_smem = true) { + const auto elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + smem_inner_dim = swizzle_mode / elem_size; + + if (t.scalar_type() == kPackedFP4) { + // Inner dim must be a multiple of 64B for .b4x16_p64 + DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_inner_dim % 128 == 0); + + // Fix FP4 packed smem + if (not fp4_unpacked_smem and swizzle_mode != 0) + smem_inner_dim = swizzle_mode * 2; + } + + CUtensorMap tensor_map; + const cuuint64_t gmem_dims[2] = {static_cast(gmem_inner_dim), static_cast(gmem_outer_dim)}; + const cuuint32_t smem_dims[2] = {static_cast(smem_inner_dim), static_cast(smem_outer_dim)}; + const cuuint64_t gmem_strides[1] = {static_cast(gmem_outer_stride * elem_size), }; + const cuuint32_t elem_strides[2] = {1, 1}; + if (get_env("DG_JIT_DEBUG")) { + printf("Making TMA desc: global memory: %d %d, shared memory: %d %d, outer stride: %d, swizzle: %d (base: %d), elem size: %d, pointer: %llu\n", + gmem_inner_dim, gmem_outer_dim, smem_inner_dim, smem_outer_dim, + gmem_outer_stride, swizzle_mode, swizzle_base, elem_size, + reinterpret_cast(t.data_ptr())); + } + DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem), + 2, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensor_map; +} + +static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, + int gmem_dim_0, int gmem_dim_1, int gmem_dim_2, + int smem_dim_0, int smem_dim_1, int smem_dim_2, + const int& gmem_stride_0, const int& gmem_stride_1, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false, + const bool& fp4_unpacked_smem = true) { + const auto elem_size = static_cast(t.element_size()); + if (swizzle_mode != 0) + smem_dim_0 = swizzle_mode / elem_size; + + if (t.scalar_type() == kPackedFP4) { + // Inner dim must be a multiple of 64B for .b4x16_p64 + DG_HOST_ASSERT(not fp4_unpacked_smem or gmem_dim_0 % 128 == 0); + + // Fix fp4 packed smem + if (not fp4_unpacked_smem and swizzle_mode != 0) + smem_dim_0 = swizzle_mode * 2; + } + + CUtensorMap tensor_map; + const cuuint64_t gmem_dims[3] = {static_cast(gmem_dim_0), static_cast(gmem_dim_1), static_cast(gmem_dim_2),}; + const cuuint32_t smem_dims[3] = {static_cast(smem_dim_0), static_cast(smem_dim_1), static_cast(smem_dim_2)}; + const cuuint64_t gmem_strides[2] = {static_cast(gmem_stride_0 * elem_size), static_cast(gmem_stride_1 * elem_size)}; + const cuuint32_t elem_strides[3] = {1, 1, 1}; + if (get_env("DG_JIT_DEBUG")) { + printf("Making 3D TMA desc: global memory: %d %d %d, shared memory: %d %d %d, outer stride: %d %d, swizzle: %d, elem size: %d\n", + gmem_dim_0, gmem_dim_1, gmem_dim_2, smem_dim_0, smem_dim_1, smem_dim_2, + gmem_stride_0, gmem_stride_1, swizzle_mode, elem_size); + } + DG_CUDA_DRIVER_CHECK(lazy_cuTensorMapEncodeTiled( + &tensor_map, aten_dtype_to_tensor_map_dtype(t.scalar_type(), allow_tf32, fp4_unpacked_smem), + 3, t.data_ptr(), gmem_dims, gmem_strides, smem_dims, elem_strides, + CU_TENSOR_MAP_INTERLEAVE_NONE, mode_into_tensor_map_swizzle(swizzle_mode, swizzle_base), + CU_TENSOR_MAP_L2_PROMOTION_L2_256B, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensor_map; +} + +static CUtensorMap make_tma_a_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_m, const int& shape_k, + const int& block_m, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + if (num_groups > 1) + DG_HOST_ASSERT(major == cute::UMMA::Major::K); + const auto [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_m * num_groups); + const auto [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_m); + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode, swizzle_base, + allow_tf32); +} + +static CUtensorMap make_tma_b_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + const int& shape_n, const int& shape_k, + const int& block_n, const int& block_k, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + const auto [gmem_inner_dim, gmem_outer_dim] = get_inner_outer_dims(major, shape_k, shape_n); + const auto [smem_inner_dim, smem_outer_dim] = get_inner_outer_dims(major, block_k, block_n); + + // `num_groups` is always applied into the outer dimensions + return make_tma_2d_desc(t, + gmem_inner_dim, gmem_outer_dim * num_groups, + smem_inner_dim, smem_outer_dim, + outer_stride, + swizzle_mode, swizzle_base, + allow_tf32); +} + +static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, + const int& shape_m, const int& shape_n, + const int& block_m, const int& block_n, + const int& outer_stride, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + // Swizzling requires the inner box dim to be less or equal than `kSwizzleCDMode` + // bytes, so `BLOCK_N * sizeof(T) / kSwizzleCDMode` TMA stores are required + return make_tma_2d_desc(t, + shape_n, shape_m * num_groups, + block_n, block_m, + outer_stride, + swizzle_mode, swizzle_base, + allow_tf32); +} + +static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, + const torch::Tensor& t, + int shape_mn, int shape_k, + const int& block_mn, const int& gran_k, + const int& num_groups, + const int& swizzle_mode, const int& swizzle_base = 0, + const bool& allow_tf32 = false) { + DG_HOST_ASSERT(major == cute::UMMA::Major::MN); + + // TODO: maybe swizzle SF as well + DG_HOST_ASSERT(swizzle_mode == 0); + + shape_mn = get_tma_aligned_size(shape_mn, static_cast(t.element_size())); + return make_tma_2d_desc(t, + shape_mn, ceil_div(shape_k, gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, + block_mn, 1, + shape_mn, + swizzle_mode, swizzle_base, + allow_tf32); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp new file mode 100644 index 00000000..26219b0c --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -0,0 +1,415 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + GemmDesc gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_bf16_gemm_impl< + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b), + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_desc.num_groups, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, + args.gemm_config.layout.swap_ab, + to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, to_string(args.gemm_desc.cd_dtype), + args.gemm_desc.tc_util); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_cd)); + } +}; + +static void sm100_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_gemm", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto gemm_type = use_psum_layout ? + GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1 + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_m_grouped_gemm_contiguous", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_m_grouped_gemm_masked", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0; + for (const auto k: ks) { + sum_k += k; + DG_HOST_ASSERT(k % 128 == 0); + } + const auto num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto max_k = *std::max_element(ks.begin(), ks.end()); + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(0)), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(0)), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(1)), num_groups, + config.storage_config.swizzle_cd_mode); + + // Launch kernel + const SM100BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_k_grouped_gemm", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = d, .k = r, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.layout.block_k, config.storage_config.load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.layout.block_k, config.storage_config.load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + config.storage_config.store_block_n, config.storage_config.store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.storage_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_bhr_hdr_bhd", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = r, .k = d, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.layout.block_k, config.storage_config.load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.storage_config.load_block_n, config.layout.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + config.storage_config.store_block_n, config.storage_config.store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.storage_config.swizzle_cd_mode); + + // Launch + const SM100BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_bf16_bhd_hdr_bhr", code); + SM100BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp new file mode 100644 index 00000000..65c9d501 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BmkBnkMnRuntime final: public LaunchRuntime { +public: + struct Args { + int s, m, n, k; + int block_m, block_n, block_k; + int split_factor; + int swizzle_ab_mode, swizzle_cd_mode; + int num_stages; + int num_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_bmn_bnk_mn_gemm_impl< + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, + {}, {} + >); +}}; +)", + args.m, args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.split_factor, + args.swizzle_ab_mode, args.swizzle_cd_mode, + args.num_stages, args.num_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.s, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d)); + } +}; + + +static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a, + const torch::Tensor &b, + const torch::Tensor &d, + const int &s, const int &m, const int &n, const int &k) { + constexpr int block_m = 128; + constexpr int block_n = 128; + constexpr int block_k = 64; + constexpr int num_threads = 128; + DG_HOST_ASSERT(k % block_k == 0); + DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0); + DG_HOST_ASSERT(static_cast(s) * static_cast(std::max(m, n)) <= std::numeric_limits::max()); + + const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast(a.element_size())); + const int swizzle_cd_mode = get_swizzle_mode(block_n, static_cast(d.element_size())); + + // Get best config + const int num_sms = device_runtime->get_num_sms(); + const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n); + const int num_sk_blocks = s * (k / block_k); + const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1)); + + // Select best number of stages + // NOTES: we select 4 as start, as it is tested to be faster than values > 4 + int num_stages = 4, smem_size = 0; + while (true) { + const int smem_cd = block_m * swizzle_cd_mode * 2; + const int smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); + const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int smem_barrier = num_stages * 8 * 3 + 2 * 8 * 2 + 8; + const int smem_tmem_ptr = 4; + + smem_size = 0; + smem_size += smem_cd; + smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages; + smem_size += smem_barrier; + smem_size += smem_tmem_ptr; + if (smem_size <= SM100ArchSpec::smem_capacity) + break; + + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("S: %d, M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split-K factor: %d" + "stages: %d, shared memory: %d, swizzle AB: %d, swizzle CD: %d\n", + s, m, n, k, block_m, block_n, block_k, split_factor, + num_stages, smem_size, swizzle_ab_mode, swizzle_cd_mode); + } + + const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + const auto tensor_map_d = make_tma_2d_desc(d, n, m, block_n, block_m, n, swizzle_cd_mode); + + const SM100BmkBnkMnRuntime::Args args = { + .s = s, .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .split_factor = split_factor, + .swizzle_ab_mode = swizzle_ab_mode, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_threads = num_threads, + .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d + }; + const auto code = SM100BmkBnkMnRuntime::generate(args); + const auto runtime = compiler->build("sm100_bmn_bnk_mn_gemm", code); + SM100BmkBnkMnRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp new file mode 100644 index 00000000..b8826361 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_fp4_gemm_1d1d.hpp @@ -0,0 +1,459 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" + +#include "epilogue.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + GemmDesc gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + // TODO: move into descriptor + const std::optional epilogue_type; + + // TODO: move into descriptor + int gran_k_a, gran_k_b; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + // TODO: rename files + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_fp4_gemm_1d1d_impl< + {}, {}, + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, + {}, {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b), + args.gran_k_a, args.gran_k_b, + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_desc.num_groups, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_non_epilogue_threads, args.gemm_config.launch_config.num_epilogue_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, + args.gemm_config.layout.swap_ab, + to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, + to_string(args.gemm_desc.a_dtype), to_string(args.gemm_desc.b_dtype), to_string(args.gemm_desc.cd_dtype), + get_default_epilogue_type(args.epilogue_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_cd)); + } +}; + +static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const auto cd = c.value_or(d); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, 1, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = epilogue_type, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto gemm_type = use_psum_layout ? + GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1 + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, num_groups, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const int& gran_k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); + const int gran_k_a = gran_k; + const int gran_k_b = gran_k; + + int sum_k = 0, sum_sf_k = 0; + for (const auto k: ks) { + sum_k += k, sum_sf_k += ceil_div(k, gran_k * 4); + DG_HOST_ASSERT(k % gran_k == 0); + } + const auto num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto max_k = *std::max_element(ks.begin(), ks.end()); + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(0)), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(0)), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(1)), num_groups, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * gran_k_a * 4, + config.layout.block_m, gran_k_a, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * gran_k_b * 4, + config.layout.block_n, gran_k_b, 1, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = batch_size, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), + .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const int load_block_m = config.storage_config.load_block_m; + const auto [inner_dim_a, outer_dim_a] = get_inner_outer_dims(major_a, k, m); + const auto [inner_block_a, outer_block_a] = get_inner_outer_dims(major_a, config.layout.block_k, load_block_m); + const auto tensor_map_a = make_tma_3d_desc(a, inner_dim_a, outer_dim_a, batch_size, + inner_block_a, outer_block_a, 1, + a.stride(major_a == cute::UMMA::Major::K ? 1 : 2), + a.stride(0), + config.storage_config.swizzle_a_mode); + + const int load_block_n = config.storage_config.load_block_n; + const auto [inner_dim_b, outer_dim_b] = get_inner_outer_dims(major_b, k, n); + const auto [inner_block_b, outer_block_b] = get_inner_outer_dims(major_b, config.layout.block_k, load_block_n); + const auto tensor_map_b = make_tma_3d_desc(b, inner_dim_b, outer_dim_b, batch_size, + inner_block_b, outer_block_b, 1, + b.stride(major_b == cute::UMMA::Major::K ? 1 : 2), + b.stride(0), + config.storage_config.swizzle_b_mode); + + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_cd = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.storage_config.swizzle_cd_mode); + + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, gran_k_a, batch_size, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, gran_k_b, batch_size, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp new file mode 100644 index 00000000..4d912569 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_fp4_mega_moe.hpp @@ -0,0 +1,220 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "runtime_utils.hpp" + +#include +#include + +#include "../heuristics/mega_moe.hpp" + +namespace deep_gemm { + +class SM100FP8FP4MegaMoERuntime final : public LaunchRuntime { +public: + struct Args { + // Templated arguments + int num_max_tokens_per_rank; + int hidden, intermediate_hidden; + int num_experts, num_topk; + int num_ranks; + float activation_clamp; + bool fast_math; + MegaMoEConfig config; + + // Runtime arguments + void* y; + int* cumulative_local_expert_recv_stats; + int num_tokens; + layout::SymBuffer<> sym_buffer_ptrs; + + // Tensormap + CUtensorMap tensor_map_l1_acts; + CUtensorMap tensor_map_l1_acts_sf; + CUtensorMap tensor_map_l1_weights; + CUtensorMap tensor_map_l1_weights_sf; + CUtensorMap tensor_map_l1_output; + CUtensorMap tensor_map_l2_acts; + CUtensorMap tensor_map_l2_acts_sf; + CUtensorMap tensor_map_l2_weights; + CUtensorMap tensor_map_l2_weights_sf; + + // Launch configs + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_fp4_mega_moe_impl< + {}, + {}, {}, + {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, + {}, + {}, + {}, {}, {}, + {}, {}, + {}, + {} + >); +}}; +)", args.num_max_tokens_per_rank, + args.hidden, args.intermediate_hidden, + args.num_experts, args.num_topk, + args.config.num_experts_per_wave, + args.config.block_m, args.config.block_n, args.config.block_k, + args.config.store_block_m, + args.config.sf_block_m, args.config.sf_block_n, + args.config.num_max_pool_tokens, + args.config.num_padded_sf_pool_tokens, + args.config.num_stages, + args.config.num_dispatch_threads, args.config.num_non_epilogue_threads, args.config.num_epilogue_threads, + args.launch_args.grid_dim.first, args.num_ranks, + to_string(args.activation_clamp), + args.fast_math ? "true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.y, + args.cumulative_local_expert_recv_stats, + args.num_tokens, + args.sym_buffer_ptrs, + args.tensor_map_l1_acts, + args.tensor_map_l1_acts_sf, + args.tensor_map_l1_weights, + args.tensor_map_l1_weights_sf, + args.tensor_map_l1_output, + args.tensor_map_l2_acts, + args.tensor_map_l2_acts_sf, + args.tensor_map_l2_weights, + args.tensor_map_l2_weights_sf + )); + } +}; + +static void sm100_fp8_fp4_mega_moe( + const torch::Tensor& y, + const torch::Tensor& l1_acts, const torch::Tensor& l1_acts_sf, + const torch::Tensor& l2_acts, const torch::Tensor& l2_acts_sf, + const torch::Tensor& l1_weights, const torch::Tensor& l2_weights, + const torch::Tensor& l1_weights_sf, const torch::Tensor& l2_weights_sf, + const std::optional cumulative_local_expert_recv_stats, + const std::vector& sym_buffer_ptrs, + const int& rank_idx, const int& num_max_tokens_per_rank, + const int& num_experts_per_rank, + const int& num_tokens, const int& num_topk, + const int& hidden, const int& intermediate_hidden, + const float& activation_clamp, + const bool& fast_math +) { + const auto num_ranks = static_cast(sym_buffer_ptrs.size()); + const auto num_experts = num_experts_per_rank * num_ranks; + const auto num_padded_sf_pool_tokens = static_cast(l1_acts_sf.size(0)); + + // Heuristics + const auto config = get_mega_moe_config( + num_ranks, num_experts, num_experts_per_rank, + num_max_tokens_per_rank, num_tokens, num_topk, hidden, intermediate_hidden, num_padded_sf_pool_tokens); + + // Make tensormap + constexpr int kGranK = 32; + const auto tensor_map_l1_acts = make_tma_2d_desc(l1_acts, + hidden, config.num_max_pool_tokens, + config.block_k, config.load_block_m, + static_cast(l1_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l1_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_acts_sf, + config.num_padded_sf_pool_tokens, hidden, + config.sf_block_m, kGranK, + 1, 0); + const auto tensor_map_l1_weights = make_tma_2d_desc(l1_weights, + hidden, num_experts_per_rank * intermediate_hidden * 2, + config.block_k, config.load_block_n, + static_cast(l1_weights.stride(-2)), + config.swizzle_weights_mode); + const auto tensor_map_l1_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l1_weights_sf, + intermediate_hidden * 2, hidden, + config.block_n, kGranK, + num_experts_per_rank, 0); + // NOTES: L1 output and L2 activations are essentially the same tensor. + // Post-SwiGLU output has half the N width (`BLOCK_N / 2` per input tile), + // so the swizzle mode is also halved (128 -> 64). + const auto tensor_map_l1_output = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_n / 2, config.store_block_m, + static_cast(l2_acts.stride(-2)), + config.swizzle_acts_mode / 2); + const auto tensor_map_l2_acts = make_tma_2d_desc(l2_acts, + intermediate_hidden, config.num_max_pool_tokens, + config.block_k, config.load_block_m, + static_cast(l2_acts.stride(-2)), + config.swizzle_acts_mode); + const auto tensor_map_l2_acts_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_acts_sf, + config.num_padded_sf_pool_tokens, intermediate_hidden, + config.sf_block_m, kGranK, + 1, 0); + const auto tensor_map_l2_weights = make_tma_2d_desc(l2_weights, + intermediate_hidden, num_experts_per_rank * hidden, + config.block_k, config.load_block_n, + static_cast(l2_weights.stride(-2)), + config.swizzle_weights_mode); + const auto tensor_map_l2_weights_sf = make_tma_sf_desc(cute::UMMA::Major::MN, l2_weights_sf, + hidden, intermediate_hidden, + config.block_n, kGranK, + num_experts_per_rank, 0); + + // Stats can be optional + int* cumulative_local_expert_recv_stats_ptr = nullptr; + if (cumulative_local_expert_recv_stats.has_value()) + cumulative_local_expert_recv_stats_ptr = cumulative_local_expert_recv_stats->data_ptr(); + + // Launch + const auto num_sms = device_runtime->get_num_sms(); + const SM100FP8FP4MegaMoERuntime::Args args = { + .num_max_tokens_per_rank = num_max_tokens_per_rank, + .hidden = hidden, .intermediate_hidden = intermediate_hidden, + .num_experts = num_experts, .num_topk = num_topk, + .num_ranks = num_ranks, + .activation_clamp = activation_clamp, + .fast_math = fast_math, + .config = config, + .y = y.data_ptr(), + .cumulative_local_expert_recv_stats = cumulative_local_expert_recv_stats_ptr, + .num_tokens = num_tokens, + .sym_buffer_ptrs = layout::SymBuffer<>(sym_buffer_ptrs, rank_idx), + .tensor_map_l1_acts = tensor_map_l1_acts, + .tensor_map_l1_acts_sf = tensor_map_l1_acts_sf, + .tensor_map_l1_weights = tensor_map_l1_weights, + .tensor_map_l1_weights_sf = tensor_map_l1_weights_sf, + .tensor_map_l1_output = tensor_map_l1_output, + .tensor_map_l2_acts = tensor_map_l2_acts, + .tensor_map_l2_acts_sf = tensor_map_l2_acts_sf, + .tensor_map_l2_weights = tensor_map_l2_weights, + .tensor_map_l2_weights_sf = tensor_map_l2_weights_sf, + .launch_args = LaunchArgs(num_sms, + config.num_dispatch_threads + config.num_non_epilogue_threads + config.num_epilogue_threads, + config.smem_size, 2) + }; + + const auto code = SM100FP8FP4MegaMoERuntime::generate(args); + const auto runtime = compiler->build("sm100_fp8_fp4_mega_moe", code); + SM100FP8FP4MegaMoERuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp new file mode 100644 index 00000000..07a977d7 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -0,0 +1,416 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" + +#include "epilogue.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k, num_groups; + int gran_k_a, gran_k_b; + const std::string& compiled_dims; + const std::optional& epilogue_type; + + GemmConfig gemm_config; + LaunchArgs launch_args; + + void* grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d1d_impl< + {}, {}, + {}, {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {}, {}, {}, + {} + >); +}}; +)", + to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + args.gran_k_a, args.gran_k_b, + get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), + args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, + args.num_groups, + args.gemm_config.smem_config.swizzle_a_mode, args.gemm_config.smem_config.swizzle_b_mode, args.gemm_config.smem_config.swizzle_cd_mode, + args.gemm_config.num_stages, + args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, + args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, + args.gemm_config.num_sms, + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, + to_string(args.gemm_config.a_dtype), to_string(args.gemm_config.b_dtype), to_string(args.gemm_config.cd_dtype), + get_default_epilogue_type(args.epilogue_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, args.m, args.n, args.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_cd)); + } +}; + +static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { + const auto& config = get_best_config( + GemmType::Normal, KernelType::Kernel1D1D, + m, n, k, 1, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const auto& cd = c.value_or(d); + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, 1, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = 1, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = epilogue_type, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; + const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1; + + const auto& config = get_best_config( + gemm_type, KernelType::Kernel1D1D, + m_for_config, n, k, num_groups_for_config, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), 1, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = grouped_layout.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::MGroupedMasked, KernelType::Kernel1D1D, + expected_m, n, k, num_groups, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(major_b, b, n, k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(-2)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, gran_k_a, num_groups, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, gran_k_b, num_groups, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0, sum_sf_k = 0; + for (const auto& k: ks) { + sum_k += k, sum_sf_k += ceil_div(k, 512); + DG_HOST_ASSERT(k % 128 == 0); + } + const auto& num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto& max_k = *std::max_element(ks.begin(), ks.end()); + const auto& config = get_best_config( + GemmType::KGroupedContiguous, KernelType::Kernel1D1D, + m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Create tensor descriptors + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m), + config.block_k, + static_cast(a.stride(0)), 1, + config.smem_config.swizzle_a_mode); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n), + config.block_k, + static_cast(b.stride(0)), 1, + config.smem_config.swizzle_b_mode); + const auto& tensor_map_cd = make_tma_cd_desc(d, m, n, + SM100ArchSpec::get_cd_store_block_m(config.block_m), + SM100ArchSpec::get_cd_store_block_n(config.block_n), + static_cast(d.stride(1)), num_groups, + config.smem_config.swizzle_cd_mode); + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 512, + config.block_m, config.block_k, 1, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 512, + config.block_n, config.block_k, 1, 0); + + // Launch kernel + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = sum_k, + .num_groups = num_groups, + .gran_k_a = 128, + .gran_k_b = 128, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto& config = get_best_config( + GemmType::Batched, KernelType::Kernel1D1D, + m, n, k, batch_size, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& [inner_dim_a, outer_dim_a] = get_inner_outer_dims(major_a, k, m); + const auto& [inner_block_a, outer_block_a] = get_inner_outer_dims(major_a, config.block_k, load_block_m); + const auto& tensor_map_a = make_tma_3d_desc(a, inner_dim_a, outer_dim_a, batch_size, + inner_block_a, outer_block_a, 1, + a.stride(major_a == cute::UMMA::Major::K ? 1 : 2), + a.stride(0), + config.smem_config.swizzle_a_mode); + + const int& load_block_n = SM100ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& [inner_dim_b, outer_dim_b] = get_inner_outer_dims(major_b, k, n); + const auto& [inner_block_b, outer_block_b] = get_inner_outer_dims(major_b, config.block_k, load_block_n); + const auto& tensor_map_b = make_tma_3d_desc(b, inner_dim_b, outer_dim_b, batch_size, + inner_block_b, outer_block_b, 1, + b.stride(major_b == cute::UMMA::Major::K ? 1 : 2), + b.stride(0), + config.smem_config.swizzle_b_mode); + + const int& store_block_m = SM100ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM100ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_cd = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.smem_config.swizzle_cd_mode); + + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, batch_size, 0); + const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.block_n, config.block_k, batch_size, 0); + + // Launch + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, + .num_groups = batch_size, + .gran_k_a = 128, + .gran_k_b = 128, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd + }; + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..0071e2c5 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,149 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_mma_threads, num_cast_and_reduce_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_mma_threads, args.num_cast_and_reduce_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_mma_threads = 128; + constexpr int num_cast_and_reduce_threads = 128; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + DG_HOST_ASSERT(n <= 128 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = (num_stages * 4 + 1) * 8; + const int smem_tmem_ptr = 4; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers + smem_tmem_ptr; + + if (smem_size <= SM100ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d" + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + // Launch + const SM100BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_mma_threads = num_mma_threads, + .num_cast_and_reduce_threads = num_cast_and_reduce_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto code = SM100BF16HCPrenormGemmRuntime::generate(args); + const auto runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code); + SM100BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp new file mode 100644 index 00000000..1d29d855 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -0,0 +1,432 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16GemmRuntime final: public LaunchRuntime { +public: + struct Args { + GemmDesc gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_bf16_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", + // TODO: add CD dtype + to_string(args.gemm_desc.major_a), to_string(args.gemm_desc.major_b), + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_desc.num_groups, + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_config.storage_config.swizzle_a_mode, + args.gemm_config.storage_config.swizzle_b_mode, + args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads, + // TODO: refactor with cluster M/N + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, + to_string(args.gemm_desc.gemm_type), args.gemm_desc.with_accumulation, + to_string(args.gemm_desc.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.grouped_layout, + args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_cd)); + } +}; + +static void sm90_bf16_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_gemm", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto gemm_type = use_psum_layout ? + GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); + + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1 + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + DG_HOST_ASSERT(k % 64 == 0); + + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = 0, .expected_k = 0, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); + + int sum_k = 0; + for (const auto k: ks) { + sum_k += k; + DG_HOST_ASSERT(k % 128 == 0); + } + const auto num_groups = static_cast(ks.size()); + + // Get config using max K for better performance + const auto max_k = *std::max_element(ks.begin(), ks.end()); + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::KernelNoSF, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Create tensor descriptors + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::MN, a, m, sum_k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(0)), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::MN, b, n, sum_k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(0)), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(1)), num_groups, + config.storage_config.swizzle_cd_mode); + + // Launch kernel + const SM90BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_k_grouped_gemm", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = d, .k = r, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::K, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const int load_block_m = config.storage_config.load_block_m; + const auto tensor_map_a = make_tma_3d_desc(tensor_a, r, b, h, + config.layout.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const int load_block_n = config.storage_config.load_block_n; + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + config.layout.block_k, load_block_n, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, d, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.storage_config.swizzle_cd_mode); + + // Launch + const SM90BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_bhr_hdr_bhd", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, + const torch::Tensor& tensor_b, + const torch::Tensor& tensor_d, + const int& b, const int& h, const int& r, const int& d, + const std::string& compiled_dims = "nk") { + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::KernelNoSF, + .m = b, .n = r, .k = d, .num_groups = h, + .a_dtype = tensor_a.scalar_type(), .b_dtype = tensor_b.scalar_type(), + .cd_dtype = tensor_d.scalar_type(), + .major_a = cute::UMMA::Major::K, .major_b = cute::UMMA::Major::MN, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + const int load_block_m = config.storage_config.load_block_m; + const auto tensor_map_a = make_tma_3d_desc(tensor_a, d, b, h, + config.layout.block_k, load_block_m, 1, + tensor_a.stride(0), tensor_a.stride(1), + config.storage_config.swizzle_a_mode); + const int load_block_n = config.storage_config.load_block_n; + const auto tensor_map_b = make_tma_3d_desc(tensor_b, r, d, h, + load_block_n, config.layout.block_k, 1, + tensor_b.stride(1), tensor_b.stride(0), + config.storage_config.swizzle_b_mode); + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_cd = make_tma_3d_desc(tensor_d, r, b, h, + store_block_n, store_block_m, 1, + tensor_d.stride(0), tensor_d.stride(1), + config.storage_config.swizzle_cd_mode); + // Launch + const SM90BF16GemmRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90BF16GemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_bf16_bhd_hdr_bhr", code); + SM90BF16GemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp new file mode 100644 index 00000000..473677b7 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_bmk_bnk_mn.hpp @@ -0,0 +1,131 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BmkBnkMnRuntime final: public LaunchRuntime { +public: + struct Args { + int s, m, n, k; + int block_m, block_n, block_k; + int split_factor; + int num_stages; + int num_tma_threads, num_math_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + float* d; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_bmn_bnk_mn_gemm_impl< + {}, {}, {}, + {}, {}, {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.m, args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.split_factor, + args.num_stages, + args.num_tma_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.s, args.tensor_map_a, args.tensor_map_b, args.d)); + } +}; + + +static void sm90_bmn_bnk_mn_gemm(const torch::Tensor &a, + const torch::Tensor &b, + const torch::Tensor &d, + const int &s, const int &m, const int &n, const int &k) { + constexpr int block_m = 128; + constexpr int block_n = 128; + constexpr int block_k = 64; + constexpr int num_tma_threads = 128; + constexpr int num_math_threads = 256; + DG_HOST_ASSERT(k % block_k == 0); + DG_HOST_ASSERT(m % 64 == 0 and n % 64 == 0); + DG_HOST_ASSERT(static_cast(s) * static_cast(std::max(m, n)) <= std::numeric_limits::max()); + + const int swizzle_ab_mode = get_swizzle_mode(block_k, static_cast(a.element_size())); + DG_HOST_ASSERT(swizzle_ab_mode == 128); + + // Get best config + const int num_sms = device_runtime->get_num_sms(); + const int num_mn_blocks = ceil_div(m, block_m) * ceil_div(n, block_n); + const int num_sk_blocks = s * (k / block_k); + const int split_factor = ceil_div(num_sk_blocks, std::max(num_sms / num_mn_blocks, 1)); + + // Select best number of stages + int num_stages = 4, smem_size = 0; + while (true) { + const int smem_a_per_stage = block_m * block_k * sizeof(cutlass::bfloat16_t); + const int smem_b_per_stage = block_n * block_k * sizeof(cutlass::bfloat16_t); + const int smem_barrier = num_stages * 8 * 2; + + smem_size = 0; + smem_size += (smem_a_per_stage + smem_b_per_stage) * num_stages; + smem_size += smem_barrier; + + if (smem_size <= SM90ArchSpec::smem_capacity) + break; + + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("S: %d, M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split-K factor: %d" + "stages: %d, shared memory: %d, swizzle AB: %d\n", + s, m, n, k, block_m, block_n, block_k, split_factor, + num_stages, smem_size, swizzle_ab_mode); + } + + const auto tensor_map_a = make_tma_2d_desc(a, k, s * m, block_k, block_m, k, swizzle_ab_mode); + const auto tensor_map_b = make_tma_2d_desc(b, k, s * n, block_k, block_n, k, swizzle_ab_mode); + + const SM90BmkBnkMnRuntime::Args& args = { + .s = s, .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .split_factor = split_factor, + .num_stages = num_stages, + .num_tma_threads = num_tma_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_mn_blocks * ceil_div(num_sk_blocks, split_factor), num_tma_threads + num_math_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .d = d.data_ptr() + }; + const auto code = SM90BmkBnkMnRuntime::generate(args); + const auto runtime = compiler->build("sm90_bmn_bnk_mn_gemm", code); + SM90BmkBnkMnRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp new file mode 100644 index 00000000..9d903d48 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -0,0 +1,229 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8Gemm1D1DRuntime final: public LaunchRuntime { +public: + struct Args { + GemmDesc gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + + void *gmem_a_ptr; + void *gmem_b_ptr; + void *grouped_layout; + void *tensor_map_buffer; + CUtensorMap tensor_map_a_base; + CUtensorMap tensor_map_b_base; + CUtensorMap tensor_map_sfa; + CUtensorMap tensor_map_sfb; + CUtensorMap tensor_map_cd; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d1d_impl< + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_desc.num_groups, + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type), + to_string(args.gemm_desc.cd_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.gmem_a_ptr, args.gmem_b_ptr, + args.grouped_layout, + args.tensor_map_buffer, + args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, + args.tensor_map_a_base, args.tensor_map_b_base, + args.tensor_map_sfa, args.tensor_map_sfb, + args.tensor_map_cd)); + } +}; + +static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, k, 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, k, 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, + config.layout.block_n, config.layout.block_k, 1, 0); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + 0); + + // Launch + const SM90FP8Gemm1D1DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .gmem_a_ptr = nullptr, + .gmem_b_ptr = nullptr, + .grouped_layout = nullptr, + .tensor_map_buffer = nullptr, + .tensor_map_a_base = tensor_map_a, + .tensor_map_b_base = tensor_map_b, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + + SM90FP8Gemm1D1DRuntime::launch(runtime, args); +} + +static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const torch::Tensor& tensor_map_buffer, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { + DG_HOST_ASSERT(c.has_value() and d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + // TODO: refactor with the mk alignment function + const auto num_groups = static_cast(ks.size()); + int first_k = 0, sum_k = 0, sum_sf_k = 0, max_k = 0; + for (int i = 0; i < num_groups; ++ i) { + if (first_k == 0 and ks[i] != 0) + first_k = ks[i]; + sum_k += ks[i], sum_sf_k += ceil_div(ks[i], 128); + max_k = std::max(max_k, ks[i]); + DG_HOST_ASSERT(ks[i] % 128 == 0); + } + + // Get config using max K for better performance + const auto desc = GemmDesc { + .gemm_type = GemmType::KGroupedContiguous, + .kernel_type = KernelType::Kernel1D1D, + .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = m, .expected_n = n, .expected_k = max_k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + + const auto tensor_map_a_base = make_tma_a_desc(major_a, a, m, first_k, + config.storage_config.load_block_m, + config.layout.block_k, first_k, 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b_base = make_tma_b_desc(major_b, b, n, first_k, + config.storage_config.load_block_n, + config.layout.block_k, first_k, 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, sum_sf_k * 128, + config.layout.block_m, config.layout.block_k, 1, 0); + const auto tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, sum_sf_k * 128, + config.layout.block_n, config.layout.block_k, 1, 0); + const auto tensor_map_cd = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + + // Launch + const SM90FP8Gemm1D1DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .gmem_a_ptr = a.data_ptr(), + .gmem_b_ptr = b.data_ptr(), + .grouped_layout = ks_tensor.data_ptr(), + .tensor_map_buffer = tensor_map_buffer.data_ptr(), + .tensor_map_a_base = tensor_map_a_base, + .tensor_map_b_base = tensor_map_b_base, + .tensor_map_sfa = tensor_map_sfa, + .tensor_map_sfb = tensor_map_sfb, + .tensor_map_cd = tensor_map_cd, + }; + const auto code = SM90FP8Gemm1D1DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d1d", code); + + SM90FP8Gemm1D1DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp new file mode 100644 index 00000000..96b5cd0b --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -0,0 +1,361 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../heuristics/sm90.hpp" + +#include "epilogue.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90FP8Gemm1D2DRuntime final: public LaunchRuntime { +public: + struct Args { + GemmDesc gemm_desc; + GemmConfig gemm_config; + LaunchArgs launch_args; + // TODO: move this into `gemm_desc` + const std::optional& epilogue_type; + + cute::UMMA::Major major_sfb; + void *sfb, *grouped_layout; + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + CUtensorMap tensor_map_sfa; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_fp8_gemm_1d2d_impl< + {}, + {}, {}, {}, + {}, + {}, {}, {}, + {}, {}, {}, + {}, + {}, {}, + {}, {}, + {}, {}, + {} + >); +}}; +)", + // TODO: add CD dtype + to_string(args.major_sfb), + get_compiled_dim(args.gemm_desc.m, 'm', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.n, 'n', args.gemm_desc.compiled_dims), + get_compiled_dim(args.gemm_desc.k, 'k', args.gemm_desc.compiled_dims), + args.gemm_desc.num_groups, + args.gemm_config.layout.block_m, args.gemm_config.layout.block_n, args.gemm_config.layout.block_k, + args.gemm_config.storage_config.swizzle_a_mode, args.gemm_config.storage_config.swizzle_b_mode, args.gemm_config.storage_config.swizzle_cd_mode, + args.gemm_config.pipeline_config.num_stages, + args.gemm_config.launch_config.num_tma_threads, args.gemm_config.launch_config.num_math_threads, + args.gemm_config.layout.get_cluster_size(), args.gemm_config.layout.cluster_n > 1, + args.gemm_config.launch_config.num_sms, to_string(args.gemm_desc.gemm_type), + get_default_epilogue_type(args.epilogue_type)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sfb, args.grouped_layout, + args.gemm_desc.m, args.gemm_desc.n, args.gemm_desc.k, + args.tensor_map_a, args.tensor_map_b, + args.tensor_map_d, args.tensor_map_sfa)); + } +}; + +static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { + DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto desc = GemmDesc { + .gemm_type = GemmType::Normal, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = 1, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), 1, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, static_cast(d.size(-1)), + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = epilogue_type, + .major_sfb = major_sfb, + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& m_indices, + const int& num_groups, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto gemm_type = use_psum_layout ? + GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // Only psum layout can use expected m + if (expected_m_for_psum_layout) + DG_HOST_ASSERT(use_psum_layout); + + const auto desc = GemmDesc { + .gemm_type = gemm_type, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m_for_psum_layout.value_or(m), + .expected_n = n, .expected_k = k, + .expected_num_groups = expected_m_for_psum_layout.has_value() ? num_groups : 1 + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), 1, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), 1, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, 1, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .major_sfb = major_sfb, + .sfb = sfb.data_ptr(), + .grouped_layout = m_indices.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto desc = GemmDesc { + .gemm_type = GemmType::MGroupedMasked, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = num_groups, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = false, + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims, + .expected_m = expected_m, .expected_n = n, .expected_k = k, .expected_num_groups = num_groups + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const auto tensor_map_a = make_tma_a_desc(major_a, a, m, k, + config.storage_config.load_block_m, + config.layout.block_k, + static_cast(a.stride(get_non_contiguous_dim(major_a))), num_groups, + config.storage_config.swizzle_a_mode); + const auto tensor_map_b = make_tma_b_desc(major_b, b, n, k, + config.storage_config.load_block_n, + config.layout.block_k, + static_cast(b.stride(get_non_contiguous_dim(major_b))), num_groups, + config.storage_config.swizzle_b_mode); + const auto tensor_map_d = make_tma_cd_desc(d, m, n, + config.storage_config.store_block_m, + config.storage_config.store_block_n, + static_cast(d.stride(-2)), num_groups, + config.storage_config.swizzle_cd_mode); + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, num_groups, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .major_sfb = major_sfb, + .sfb = sfb.data_ptr(), + .grouped_layout = masked_m.data_ptr(), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto desc = GemmDesc { + .gemm_type = GemmType::Batched, + .kernel_type = KernelType::Kernel1D2D, + .m = m, .n = n, .k = k, .num_groups = batch_size, + .a_dtype = a.scalar_type(), .b_dtype = b.scalar_type(), + .cd_dtype = d.scalar_type(), + .major_a = major_a, .major_b = major_b, + .with_accumulation = c.has_value(), + .num_sms = device_runtime->get_num_sms(), + .tc_util = device_runtime->get_tc_util(), .compiled_dims = compiled_dims + }; + const auto config = get_best_config(desc); + + // Requires no TMA splits + DG_HOST_ASSERT(config.storage_config.swizzle_a_mode == config.layout.block_k); + DG_HOST_ASSERT(config.storage_config.swizzle_b_mode == config.layout.block_k); + const int load_block_m = config.storage_config.load_block_m; + const auto tensor_map_a = make_tma_3d_desc(a, k, m, batch_size, + config.layout.block_k, load_block_m, 1, + a.stride(1), + a.stride(0), + config.storage_config.swizzle_a_mode); + + const int load_block_n = config.storage_config.load_block_n; + const auto tensor_map_b = make_tma_3d_desc(b, k, n, batch_size, + config.layout.block_k, load_block_n, 1, + b.stride(1), + b.stride(0), + config.storage_config.swizzle_b_mode); + + const int store_block_m = config.storage_config.store_block_m; + const int store_block_n = config.storage_config.store_block_n; + const auto tensor_map_d = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.storage_config.swizzle_cd_mode); + + const auto tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.layout.block_m, config.layout.block_k, batch_size, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .gemm_desc = desc, + .gemm_config = config, + .launch_args = LaunchArgs(config.launch_config.num_sms, config.launch_config.num_threads, + config.pipeline_config.smem_size, + config.layout.get_cluster_size()), + .epilogue_type = std::nullopt, + .major_sfb = major_sfb, + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..c17d1b55 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_math_threads, num_tma_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm90_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_math_threads, args.num_tma_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_math_threads = 128; + constexpr int num_tma_threads = 128; + constexpr int num_threads = num_math_threads + num_tma_threads; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + // Only support small N for now + DG_HOST_ASSERT(n <= 32 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto tensor_map_d = num_splits == 1 ? make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = num_stages * 2 * 8; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers; + + if (smem_size <= SM90ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d" + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + smem_size = SM90ArchSpec::smem_capacity; + + // Launch + const SM90BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_math_threads = num_math_threads, + .num_tma_threads = num_tma_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto code = SM90BF16HCPrenormGemmRuntime::generate(args); + const auto runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code); + SM90BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_clean_logits.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_clean_logits.hpp new file mode 100644 index 00000000..ebe4c7a6 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_clean_logits.hpp @@ -0,0 +1,81 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" + +namespace deep_gemm { + +class SMXXCleanLogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int next_n; + int seq_len; + int seq_len_kv; + uint64_t stride_logits; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + void* logits; + at::ScalarType logits_dtype; + + int block_kv; + int num_warps; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&smxx_clean_logits< + {}, {}, {}, {} + >); +}}; +)", args.next_n, args.block_kv, args.num_warps, to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, static_cast(args.stride_logits), + args.cu_seq_len_k_start, args.cu_seq_len_k_end, args.logits + )); + } +}; + +static void smxx_clean_logits(const torch::Tensor& logits, + const std::optional& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const int& next_n, + const int& seq_len, const int& seq_len_kv, + const uint64_t &stride_logits) { + const int block_kv = 8192; + const int num_warps = 8; + const int smem_size = block_kv * sizeof(float); + + // Launch + const SMXXCleanLogitsRuntime::Args& args = { + .next_n = next_n, + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .stride_logits = stride_logits, + .cu_seq_len_k_start = cu_seq_len_k_start.has_value() ? cu_seq_len_k_start.value().data_ptr() : nullptr, + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .logits_dtype = logits.scalar_type(), + .block_kv = block_kv, + .num_warps = num_warps, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_warps * 32, smem_size) + }; + const auto code = SMXXCleanLogitsRuntime::generate(args); + const auto runtime = compiler->build("smxx_clean_logits", code); + SMXXCleanLogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_cublaslt.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_cublaslt.hpp new file mode 100644 index 00000000..7f29b0a5 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_cublaslt.hpp @@ -0,0 +1,151 @@ +#pragma once + +#include +#include +#include +#include + +#include "../../jit/device_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/compatibility.hpp" + +namespace deep_gemm { + +static auto get_cublaslt_layout(const cudaDataType& type, const int& rows, const int& cols, const int& ld, + const std::optional& batch_count = std::nullopt, + const std::optional& batch_offset = std::nullopt) { + cublasLtMatrixLayout_t layout; + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutCreate(&layout, type, rows, cols, ld)); + if (batch_count.has_value()) { + DG_HOST_ASSERT(batch_offset.has_value()); + + const int64_t batch_offset_int64 = batch_offset.value(); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch_count.value(), sizeof(batch_count.value()))); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutSetAttribute(layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset_int64, sizeof(batch_offset_int64))); + } + return layout; +} + +static void call_cublaslt_api(const cublasOperation_t& trans_a, + const cublasOperation_t& trans_b, + const cublasLtMatrixLayout_t& layout_a, + const cublasLtMatrixLayout_t& layout_b, + const cublasLtMatrixLayout_t& layout_d, + const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const bool& accumulate) { + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + cudaDataType_t scale_type = CUDA_R_32F; + + // Operation description + cublasLtMatmulDesc_t desc; + DG_CUBLASLT_CHECK(cublasLtMatmulDescCreate(&desc, compute_type, scale_type)); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a))); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b))); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + +#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE + const int math_sms = device_runtime->get_num_sms(); + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms))); +#endif + +#if DG_FP8_COMPATIBLE and DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE + bool fp8_fast_accumulate = false; + if (a.scalar_type() == torch::kFloat8_e4m3fn) + DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate))); +#endif + + // Get cuBLASLt handle, workspace, and stream + const auto handle = device_runtime->get_cublaslt_handle(); + const auto workspace = device_runtime->get_cublaslt_workspace(); + const auto workspace_bytes = workspace.nbytes(); + const auto stream = at::cuda::getCurrentCUDAStream(); + + // Algorithm selection + cublasLtMatmulPreference_t pref; + cublasLtMatmulHeuristicResult_t heuristic; + int num_heuristic_results = 0; + uint32_t reduction_scheme_mask = CUBLASLT_REDUCTION_SCHEME_NONE | CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE; + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceCreate(&pref)); + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_bytes, sizeof(workspace_bytes))); + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceSetAttribute(pref, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, + &reduction_scheme_mask, sizeof(reduction_scheme_mask))); + DG_CUBLASLT_CHECK(cublasLtMatmulAlgoGetHeuristic(handle, desc, layout_a, layout_b, layout_d, layout_d, + pref, 1, &heuristic, &num_heuristic_results)); + DG_HOST_ASSERT(num_heuristic_results == 1 and "Unable to find any algorithm for the GEMM"); + + // Call: D = alpha * (A @ B) + beta * C + const float alpha = 1.0, beta = accumulate ? 1.0 : 0.0; + DG_CUBLASLT_CHECK(cublasLtMatmul(handle, // Light handle + desc, // Operation description + &alpha, // Alpha + b.data_ptr(), layout_a, // A + a.data_ptr(), layout_b, // B + &beta, // Beta + d.data_ptr(), layout_d, // C + d.data_ptr(), layout_d, // D + &heuristic.algo, // Algorithm + workspace.data_ptr(), workspace_bytes, // Workspace + stream)); // Stream + + // Free memory + DG_CUBLASLT_CHECK(cublasLtMatmulPreferenceDestroy(pref)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_a)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_b)); + DG_CUBLASLT_CHECK(cublasLtMatrixLayoutDestroy(layout_d)); + DG_CUBLASLT_CHECK(cublasLtMatmulDescDestroy(desc)); +} + +static void cublaslt_gemm(const torch::Tensor& lhs, const torch::Tensor& rhs, + const torch::Tensor& out, + const int& m, const int& n, const int& k, + const cute::UMMA::Major& a_major, const cute::UMMA::Major& b_major, + const bool& accumulate) { + const auto trans_a = b_major == cute::UMMA::Major::K ? CUBLAS_OP_T : CUBLAS_OP_N; + const auto trans_b = a_major == cute::UMMA::Major::K ? CUBLAS_OP_N : CUBLAS_OP_T; + + // Matrix layouts + const auto cuda_type_a = at::cuda::ScalarTypeToCudaDataType(rhs.scalar_type()); + const auto cuda_type_b = at::cuda::ScalarTypeToCudaDataType(lhs.scalar_type()); + const auto cuda_type_d = at::cuda::ScalarTypeToCudaDataType(out.scalar_type()); + const auto layout_a = b_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_a, k, n, rhs.stride(0)) + : get_cublaslt_layout(cuda_type_a, n, k, rhs.stride(1)); + const auto layout_b = a_major == cute::UMMA::Major::K ? get_cublaslt_layout(cuda_type_b, k, m, lhs.stride(0)) + : get_cublaslt_layout(cuda_type_b, m, k, lhs.stride(1)); + const auto layout_d = get_cublaslt_layout(cuda_type_d, n, m, out.stride(0)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, accumulate); +} + +static void cublaslt_bhr_hdr_bhd(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, + const int& b, const int& h, const int& r, const int& d) { + const auto m = d, n = b, k = r; + const auto trans_a = CUBLAS_OP_T; + const auto trans_b = CUBLAS_OP_N; + + // Matrix layouts + const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, k, m, rhs.stride(1), h, rhs.stride(0)); + const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); +} + + +static void cublaslt_bhd_hdr_bhr(const torch::Tensor& lhs, const torch::Tensor& rhs, const torch::Tensor& out, + const int& b, const int& h, const int& r, const int& d) { + const auto m = r, n = b, k = d; + const auto trans_a = CUBLAS_OP_N; + const auto trans_b = CUBLAS_OP_N; + + // Matrix layouts + const auto layout_a = get_cublaslt_layout(CUDA_R_16BF, m, k, rhs.stride(1), h, rhs.stride(0)); + const auto layout_b = get_cublaslt_layout(CUDA_R_16BF, k, n, lhs.stride(0), h, lhs.stride(1)); + const auto layout_d = get_cublaslt_layout(CUDA_R_16BF, m, n, out.stride(0), h, out.stride(1)); + + call_cublaslt_api(trans_a, trans_b, layout_a, layout_b, layout_d, lhs, rhs, out, false); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp new file mode 100644 index 00000000..3be10c98 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_fp4_mqa_logits.hpp @@ -0,0 +1,328 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXFP8MQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int seq_len; + int seq_len_kv; + int max_seqlen_k; + int stride_logits; + int num_heads, head_dim; + bool is_compressed_logits; + + int num_q_stages; + int num_kv_stages; + int block_q; + int block_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + void* logits; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_mqa_logits< + {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", arch, arch, + args.num_heads, args.head_dim, + args.is_compressed_logits, + args.block_q, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.launch_args.grid_dim.first, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, + args.max_seqlen_k, args.stride_logits, + args.cu_seq_len_k_start, args.cu_seq_len_k_end, + args.logits, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv, const torch::Tensor& kv_scales, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const torch::Tensor& logits, + const at::ScalarType& logits_dtype, + const int& seq_len, const int& seq_len_kv, + const int& max_seqlen_k, const int& stride_logits, + const int& num_heads, const int& head_dim, + const int& block_q, const int& block_kv) { + constexpr int num_specialized_threads = 128; + constexpr int num_q_stages = 3, num_kv_stages = 3; + const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512); + + // Use compressed logits format when max_seqlen_k is specified + const bool is_compressed_logits = (max_seqlen_k > 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, + head_dim, block_q * num_heads, head_dim, head_dim); + const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv, + head_dim, block_kv, head_dim, head_dim); + // According to the driver API, the minimal alignment is 256 bytes + // So it is safe for us to do a 16-byte OOB + const auto tensor_map_kv_scales = make_tma_2d_desc(kv_scales, + get_tma_aligned_size(seq_len_kv, static_cast(kv_scales.element_size())), + 1, block_kv, 1, 0, 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len, + num_heads, block_q, num_heads, 0); + + // Calculate shared memory size + int smem_size = 0; + const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast(q.element_size()); + const int smem_weight_size_per_stage = block_q * num_heads * static_cast(weights.element_size()); + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv.element_size()); + const int kv_scale_size_per_stage = block_kv * static_cast(kv_scales.element_size()); + smem_size += num_q_stages * smem_q_size_per_stage; + smem_size += num_kv_stages * smem_kv_size_per_stage; + smem_size += num_q_stages * smem_weight_size_per_stage; + smem_size += num_kv_stages * kv_scale_size_per_stage; + smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8; + smem_size += 4; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXFP8MQALogitsRuntime::Args args = { + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .max_seqlen_k = max_seqlen_k, + .stride_logits = stride_logits, + .num_heads = num_heads, .head_dim = head_dim, + .is_compressed_logits = is_compressed_logits, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .block_q = block_q, + .block_kv = block_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SMXXFP8MQALogitsRuntime::generate(args); + const auto runtime = compiler->build("smxx_fp8_mqa_logits", code); + SMXXFP8MQALogitsRuntime::launch(runtime, args); +} + +class SM100FP4MQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int seq_len; + int seq_len_kv; + int max_seqlen_k; + int stride_logits; + int num_heads, head_dim; + bool is_compressed_logits; + + int num_q_stages; + int num_kv_stages; + int block_q; + int block_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + void* logits; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_sf_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_sf_kv; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp4_mqa_logits< + {}, {}, + {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", args.num_heads, args.head_dim, + args.is_compressed_logits, + args.block_q, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.launch_args.grid_dim.first, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, + args.max_seqlen_k, args.stride_logits, + args.cu_seq_len_k_start, args.cu_seq_len_k_end, + args.logits, + args.tensor_map_q, args.tensor_map_sf_q, + args.tensor_map_kv, args.tensor_map_sf_kv, + args.tensor_map_weights + )); + } +}; + +static void sm100_fp4_mqa_logits(const torch::Tensor& q, const torch::Tensor& sf_q, + const torch::Tensor& kv, const torch::Tensor& sf_kv, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const torch::Tensor& logits, + const at::ScalarType& logits_dtype, + const int& seq_len, const int& seq_len_kv, + const int& max_seqlen_k, const int& stride_logits, + const int& num_heads, const int& head_dim, + const int& block_q, const int& block_kv) { + constexpr int num_specialized_threads = 128; + const int num_math_threads = 2 * 128; + constexpr int num_q_stages = 3, num_kv_stages = 6, num_tmem_stages = 3; + + // Use compressed logits format when max_seqlen_k is specified + const bool is_compressed_logits = (max_seqlen_k > 0); + + // Construct TMAs + // `head_dim` must be 128 for 64B swizzling + DG_HOST_ASSERT(head_dim == 128); + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, + head_dim, block_q * num_heads, + static_cast(q.stride(1)), + head_dim / 2, 0, false, false); + const auto tensor_map_sf_q = make_tma_2d_desc(sf_q, num_heads, seq_len, + num_heads, block_q, + static_cast(sf_q.stride(0)), 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len, + num_heads, block_q, + static_cast(weights.stride(0)), 0); + const auto tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv, + head_dim, block_kv, + static_cast(kv.stride(0)), + head_dim / 2, 0, false, false); + // According to the driver API, the minimal alignment is 256 bytes + // So it is safe for us to do a 16-byte OOB + const auto tensor_map_sf_kv = make_tma_2d_desc(sf_kv, + get_tma_aligned_size(seq_len_kv, static_cast(sf_kv.element_size())), 1, + block_kv, 1, 0, 0); + + // Calculate shared memory size + const int smem_q_size_per_stage = block_q * num_heads * head_dim / 2; + const int smem_sf_q_size_per_stage = align(block_q * num_heads, 128) * sizeof(int); + const int smem_kv_size_per_stage = block_kv * head_dim / 2; + const int smem_sf_kv_size_per_stage = align(block_kv, 128) * sizeof(int); + const int smem_weight_size_per_stage = block_q * num_heads * sizeof(float); + + const int smem_barriers = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8; + const int smem_tmem_ptr = 4; + const int smem_size = num_q_stages * (smem_q_size_per_stage + smem_sf_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_sf_kv_size_per_stage) + + smem_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SM100FP4MQALogitsRuntime::Args args = { + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .max_seqlen_k = max_seqlen_k, + .stride_logits = stride_logits, + .num_heads = num_heads, .head_dim = head_dim, + .is_compressed_logits = is_compressed_logits, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .block_q = block_q, + .block_kv = block_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_sf_q = tensor_map_sf_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_sf_kv = tensor_map_sf_kv, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SM100FP4MQALogitsRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp4_mqa_logits", code); + SM100FP4MQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp new file mode 100644 index 00000000..2a3288ee --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_fp4_paged_mqa_logits.hpp @@ -0,0 +1,463 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime { +public: + struct Args { + int aligned_batch_size; + int split_kv; + int num_sms; + bool is_varlen; + + int batch_size; + int next_n; + bool is_context_lens_2d; + int* context_lens; + int* indices; + int* schedule_metadata; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sched::smxx_paged_mqa_logits_metadata< + {}, {}, {}, {} + >); +}}; +)", args.aligned_batch_size, args.split_kv, args.num_sms, args.is_varlen ? "true" : "false"); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.next_n, + args.is_context_lens_2d, + args.context_lens, + args.indices, + args.schedule_metadata + )); + } +}; + +static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, + const torch::Tensor& schedule_metadata, + const int& batch_size, const int& next_n, + const int& block_kv, const int& num_sms, + const bool& is_context_lens_2d, + const bool& is_varlen, const int* indices_ptr) { + constexpr int split_kv = 256; + constexpr int num_threads = 32; + const int aligned_batch_size = align(batch_size, 32); + DG_HOST_ASSERT(split_kv % block_kv == 0); + + // Shared memory: prefix_sum[kAlignedBatchSize] + varlen_atom_token_start/context_len[kAlignedBatchSize] + varlen_num_atoms + const int smem_size = (3 * aligned_batch_size + 1) * static_cast(sizeof(int)); + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXPagedMQALogitsMetadataRuntime::Args& args = { + .aligned_batch_size = aligned_batch_size, + .split_kv = split_kv, + .num_sms = num_sms, + .is_varlen = is_varlen, + .batch_size = batch_size, + .next_n = next_n, + .is_context_lens_2d = is_context_lens_2d, + .context_lens = context_lens.data_ptr(), + .indices = const_cast(indices_ptr), + .schedule_metadata = schedule_metadata.data_ptr(), + .launch_args = LaunchArgs(1, num_threads, smem_size) + }; + const auto code = SMXXPagedMQALogitsMetadataRuntime::generate(args); + const auto runtime = compiler->build("smxx_paged_mqa_logits_metadata", code); + SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args); +} + +class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int batch_size; + int next_n; + int num_heads; + int head_dim; + int block_kv; + bool is_context_lens_2d; + bool is_varlen; + int block_table_stride; + int logits_stride; + + int num_q_stages; + int num_kv_stages; + int split_kv; + + int* context_lens; + void* logits; + int* block_table; + int* indices; + int* schedule_meta; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_paged_mqa_logits< + {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", arch, arch, + args.next_n, args.num_heads, + args.head_dim, args.block_kv, + args.is_context_lens_2d, args.is_varlen ? "true" : "false", + args.num_q_stages, args.num_kv_stages, + args.split_kv, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.logits_stride, args.block_table_stride, + args.context_lens, args.logits, + args.block_table, args.indices, args.schedule_meta, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv_cache, + const torch::Tensor& kv_cache_scales, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& logits, + const torch::Tensor& block_table, + const torch::Tensor& indices, + const torch::Tensor& schedule_meta, + const at::ScalarType& logits_dtype, + const int& batch_size, const int& next_n, + const int& num_heads, const int& head_dim, + const int& num_kv_blocks, const int& block_kv, + const bool& is_context_lens_2d, + const bool& is_varlen, + const int& logits_stride, + const int& block_table_stride, + const int& num_sms, + const int& split_kv) { + const int num_specialized_threads = 128; + const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64); + const int num_math_warp_groups = split_kv / mma_m; + const int num_math_threads = num_math_warp_groups * 128; + const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3); + DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0); + + // Construct TMAs + const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1; + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, + head_dim, next_n_atom * num_heads, + static_cast(q.stride(2)), + head_dim); + const auto tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks, + head_dim, block_kv, 1, + static_cast(kv_cache.stride(1)), + static_cast(kv_cache.stride(0)), + head_dim); + + const auto tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks, + block_kv, 1, + static_cast(kv_cache_scales.stride(0)), 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, batch_size * next_n, + num_heads, next_n_atom, + static_cast(weights.stride(0)), 0); + + // Calculate shared memory size + int smem_size = 0; + if (device_runtime->get_arch_major() == 9) { + const int swizzle_alignment = head_dim * 8; + + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); + const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); + + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); + const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); + const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); + + // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(next_n == 1 or next_n == 2); + } else { + const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim * static_cast(q.element_size()); + const int smem_kv_size_per_stage = split_kv * head_dim * static_cast(kv_cache.element_size()); + const int smem_kv_scale_size_per_stage = split_kv * static_cast(kv_cache_scales.element_size()); + const int smem_weight_size_per_stage = next_n_atom * num_heads * static_cast(weights.element_size()); + + const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8; + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) + + smem_barriers + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + } + + // Launch + const SMXXFP8PagedMQALogitsRuntime::Args args = { + .batch_size = batch_size, + .next_n = next_n, + .num_heads = num_heads, + .head_dim = head_dim, + .block_kv = block_kv, + .is_context_lens_2d = is_context_lens_2d, + .is_varlen = is_varlen, + .block_table_stride = block_table_stride, + .logits_stride = logits_stride, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .split_kv = split_kv, + .context_lens = context_lens.data_ptr(), + .logits = logits.data_ptr(), + .block_table = block_table.data_ptr(), + .indices = is_varlen ? indices.data_ptr() : nullptr, + .schedule_meta = schedule_meta.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_sms, + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SMXXFP8PagedMQALogitsRuntime::generate(args); + const auto runtime = compiler->build("smxx_fp8_paged_mqa_logits", code); + SMXXFP8PagedMQALogitsRuntime::launch(runtime, args); +} + +class SM100FP4PagedMQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int batch_size; + int next_n; + int num_heads; + int head_dim; + int block_kv; + bool is_context_lens_2d; + bool is_varlen; + int block_table_stride; + int logits_stride; + + int num_q_stages; + int num_kv_stages; + int split_kv; + + int* context_lens; + void* logits; + int* block_table; + int* indices; + int* schedule_meta; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_sf_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_sf_kv; + CUtensorMap tensor_map_weights; + at::ScalarType logits_dtype; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_fp4_paged_mqa_logits< + {}, {}, + {}, {}, + {}, {}, + {}, {}, + {}, + {}, {}, + {} + >); +}}; +)", args.next_n, args.num_heads, + args.head_dim, args.block_kv, + args.is_context_lens_2d, args.is_varlen ? "true" : "false", + args.num_q_stages, args.num_kv_stages, + args.split_kv, + args.num_specialized_threads, args.num_math_threads, + to_string(args.logits_dtype)); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.logits_stride, args.block_table_stride, + args.context_lens, args.logits, + args.block_table, args.indices, args.schedule_meta, + args.tensor_map_q, args.tensor_map_sf_q, + args.tensor_map_kv, args.tensor_map_sf_kv, + args.tensor_map_weights + )); + } +}; + +static void sm100_fp4_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& sf_q, + const torch::Tensor& kv_cache, + const torch::Tensor& kv_cache_sf, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& logits, + const torch::Tensor& block_table, + const torch::Tensor& indices, + const torch::Tensor& schedule_meta, + const at::ScalarType& logits_dtype, + const int& batch_size, const int& next_n, + const int& num_heads, const int& head_dim, + const int& num_kv_blocks, const int& block_kv, + const bool& is_context_lens_2d, + const bool& is_varlen, + const int& logits_stride, + const int& block_table_stride, + const int& num_sms, + const int& split_kv) { + const int num_specialized_threads = 128; + const int num_math_threads = 2 * 128; + DG_HOST_ASSERT(split_kv == 256 and logits_stride % split_kv == 0); + + // TODO: tuning num_stages + const int num_q_stages = 3, num_kv_stages = 10, num_tmem_stages = 3; + const int next_n_atom = (is_varlen or next_n >= 2) ? 2 : 1; + + // `head_dim` must be 128 for 64B swizzling + DG_HOST_ASSERT(head_dim == 128); + + // Using 2D TMA as tensor q is asserted contiguous + const auto tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, + head_dim, next_n_atom * num_heads, + static_cast(q.stride(2)), + head_dim / 2, 0, false, false); + // NOTES: `sf_q` is a 3D tensor, while `weights` is a 2D tensor + const auto tensor_map_sf_q = make_tma_2d_desc(sf_q, num_heads, batch_size * next_n, + num_heads, next_n_atom, + static_cast(sf_q.stride(1)), 0); + const auto tensor_map_weights = make_tma_2d_desc(weights, num_heads, batch_size * next_n, + num_heads, next_n_atom, + static_cast(weights.stride(0)), 0); + + const auto tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks, + head_dim, block_kv, 1, + static_cast(kv_cache.stride(1)), + static_cast(kv_cache.stride(0)), + head_dim / 2, 0, false, false); + const auto tensor_map_sf_kv = make_tma_2d_desc(kv_cache_sf, block_kv, num_kv_blocks, + block_kv, 1, + static_cast(kv_cache_sf.stride(0)), 0); + + // Calculate shared memory size + const int smem_q_size_per_stage = next_n_atom * num_heads * head_dim / 2; + const int smem_sf_q_size_per_stage = align(next_n_atom * num_heads, 128) * sizeof(int); + const int smem_kv_size_per_stage = split_kv * head_dim / 2; + const int smem_sf_kv_size_per_stage = align(split_kv, 128) * sizeof(int); + const int smem_weight_size_per_stage = next_n_atom * num_heads * sizeof(float); + + const int smem_barriers = (num_q_stages + num_kv_stages + num_tmem_stages) * 2 * 8; + const int smem_tmem_ptr = 4; + const int smem_size = num_q_stages * (smem_q_size_per_stage + smem_sf_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_sf_kv_size_per_stage) + + smem_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SM100FP4PagedMQALogitsRuntime::Args args = { + .batch_size = batch_size, + .next_n = next_n, + .num_heads = num_heads, + .head_dim = head_dim, + .block_kv = block_kv, + .is_context_lens_2d = is_context_lens_2d, + .is_varlen = is_varlen, + .block_table_stride = block_table_stride, + .logits_stride = logits_stride, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .split_kv = split_kv, + .context_lens = context_lens.data_ptr(), + .logits = logits.data_ptr(), + .block_table = block_table.data_ptr(), + .indices = is_varlen ? indices.data_ptr() : nullptr, + .schedule_meta = schedule_meta.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_sf_q = tensor_map_sf_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_sf_kv = tensor_map_sf_kv, + .tensor_map_weights = tensor_map_weights, + .logits_dtype = logits_dtype, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_sms, + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto code = SM100FP4PagedMQALogitsRuntime::generate(args); + const auto runtime = compiler->build("sm100_fp4_paged_mqa_logits", code); + SM100FP4PagedMQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp new file mode 100644 index 00000000..f3b82e3d --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_mqa_logits.hpp @@ -0,0 +1,164 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXFP8MQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int seq_len; + int seq_len_kv; + int max_seqlen_k; + int stride_logits; + int num_heads, head_dim; + bool is_compressed_logits; + + int num_q_stages; + int num_kv_stages; + int block_q; + int block_kv; + + int* cu_seq_len_k_start; + int* cu_seq_len_k_end; + float* logits; + float softmax_scale; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_mqa_logits< + {}, {}, + {}, + {}, {}, + {}, {}, + {}, {} + >); +}}; +)", arch, arch, + args.num_heads, args.head_dim, + args.is_compressed_logits, + args.block_q, args.block_kv, + args.num_q_stages, args.num_kv_stages, + args.num_specialized_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.seq_len, args.seq_len_kv, + args.max_seqlen_k, static_cast(args.stride_logits), + args.cu_seq_len_k_start, args.cu_seq_len_k_end, + args.logits, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv, const torch::Tensor& kv_scales, + const torch::Tensor& weights, + const torch::Tensor& cu_seq_len_k_start, + const torch::Tensor& cu_seq_len_k_end, + const torch::Tensor& logits, + const int& seq_len, const int& seq_len_kv, + const int& max_seqlen_k, const int& stride_logits, + const int& num_heads, const int& head_dim, + const int& seq_len_alignment) { + constexpr int block_qh = 128; + constexpr int block_kv = 256; + constexpr int num_specialized_threads = 128; + constexpr int num_q_stages = 3, num_kv_stages = 3; + const int num_math_threads = (device_runtime->get_arch_major() == 10 ? 256 : 512); + const int block_q = block_qh / num_heads; + DG_HOST_ASSERT(block_qh % num_heads == 0); + DG_HOST_ASSERT(seq_len_alignment % block_q == 0); + + // Use compressed logits format when max_seqlen_k is specified + const bool is_compressed_logits = (max_seqlen_k > 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, seq_len * num_heads, + head_dim, block_qh, head_dim, head_dim); + const auto& tensor_map_kv = make_tma_2d_desc(kv, head_dim, seq_len_kv, + head_dim, block_kv, head_dim, head_dim); + // According to the driver API, the minimal alignment is 256 bytes + // So it is safe for us to do a 16-byte OOB + const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_scales, + get_tma_aligned_size(seq_len_kv, static_cast(kv_scales.element_size())), + 1, block_kv, 1, 0, 0); + const auto& tensor_map_weights = make_tma_2d_desc(weights, num_heads, seq_len, + num_heads, block_q, num_heads, 0); + + // Calculate shared memory size + int smem_size = 0; + const int smem_q_size_per_stage = block_q * num_heads * head_dim * static_cast(q.element_size()); + const int smem_weight_size_per_stage = block_q * num_heads * static_cast(weights.element_size()); + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv.element_size()); + const int kv_scale_size_per_stage = block_kv * static_cast(kv_scales.element_size()); + smem_size += num_q_stages * smem_q_size_per_stage; + smem_size += num_kv_stages * smem_kv_size_per_stage; + smem_size += num_q_stages * smem_weight_size_per_stage; + smem_size += num_kv_stages * kv_scale_size_per_stage; + smem_size += (num_q_stages * 2 + num_kv_stages * 2 + (num_math_threads / 128) * 2) * 8; + smem_size += 4; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXFP8MQALogitsRuntime::Args& args = { + .seq_len = seq_len, + .seq_len_kv = seq_len_kv, + .max_seqlen_k = max_seqlen_k, + .stride_logits = stride_logits, + .num_heads = num_heads, .head_dim = head_dim, + .is_compressed_logits = is_compressed_logits, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .block_q = block_q, + .block_kv = block_kv, + .cu_seq_len_k_start = cu_seq_len_k_start.data_ptr(), + .cu_seq_len_k_end = cu_seq_len_k_end.data_ptr(), + .logits = logits.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(device_runtime->get_num_sms(), + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto& code = SMXXFP8MQALogitsRuntime::generate(args); + const auto& runtime = compiler->build("smxx_fp8_mqa_logits", code); + SMXXFP8MQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp new file mode 100644 index 00000000..1240aad8 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp @@ -0,0 +1,265 @@ +#pragma once + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SMXXPagedMQALogitsMetadataRuntime final: public LaunchRuntime { +public: + struct Args { + int aligned_batch_size; + int split_kv; + int num_sms; + + int batch_size; + int next_n; + bool is_context_lens_2d; + int* context_lens; + int* schedule_metadata; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&smxx_paged_mqa_logits_metadata< + {}, {}, {} + >); +}}; +)", arch, args.aligned_batch_size, args.split_kv, args.num_sms); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + args.next_n, + args.is_context_lens_2d, + args.context_lens, + args.schedule_metadata + )); + } +}; + +static void smxx_paged_mqa_logits_metadata(const torch::Tensor& context_lens, + const torch::Tensor& schedule_metadata, + const int& batch_size, const int& next_n, + const int& block_kv, const int& num_sms, + const bool& is_context_lens_2d) { + constexpr int num_math_warpgroups = 4; + constexpr int num_threads = 32; + const int aligned_batch_size = align(batch_size, 32); + const int split_kv = block_kv * num_math_warpgroups; + + // Calculate shared memory size + const int smem_size = aligned_batch_size * static_cast(sizeof(int)); + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + + // Launch + const SMXXPagedMQALogitsMetadataRuntime::Args& args = { + .aligned_batch_size = aligned_batch_size, + .split_kv = split_kv, + .num_sms = num_sms, + .batch_size = batch_size, + .next_n = next_n, + .is_context_lens_2d = is_context_lens_2d, + .context_lens = context_lens.data_ptr(), + .schedule_metadata = schedule_metadata.data_ptr(), + .launch_args = LaunchArgs(1, num_threads, smem_size) + }; + const auto& code = SMXXPagedMQALogitsMetadataRuntime::generate(args); + const auto& runtime = compiler->build("smxx_paged_mqa_logits_metadata", code); + SMXXPagedMQALogitsMetadataRuntime::launch(runtime, args); +} + +class SMXXFP8PagedMQALogitsRuntime final: public LaunchRuntime { +public: + struct Args { + int batch_size; + int next_n; + int num_heads; + int head_dim; + int block_kv; + bool is_context_lens_2d; + int block_table_stride; + int logits_stride; + + int num_q_stages; + int num_kv_stages; + int split_kv; + + int* context_lens; + float* logits; + int* block_table; + int* schedule_meta; + + CUtensorMap tensor_map_q; + CUtensorMap tensor_map_kv; + CUtensorMap tensor_map_kv_scales; + CUtensorMap tensor_map_weights; + + int num_specialized_threads; + int num_math_threads; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + // TODO: optimize performance by tuning args + // Block sizes are fixed in this kernel + DG_HOST_ASSERT(128 % args.num_heads == 0); + const auto& arch = device_runtime->get_arch(true); + + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm{}_fp8_paged_mqa_logits< + {}, {}, + {}, {}, + {}, + {}, {}, + {}, + {}, {} + >); +}}; +)", arch, arch, + args.next_n, args.num_heads, + args.head_dim, args.block_kv, + args.is_context_lens_2d, + args.num_q_stages, args.num_kv_stages, + args.split_kv, + args.num_specialized_threads, args.num_math_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.batch_size, + static_cast(args.logits_stride), + static_cast(args.block_table_stride), + args.context_lens, args.logits, + args.block_table, args.schedule_meta, + args.tensor_map_q, args.tensor_map_kv, + args.tensor_map_kv_scales, args.tensor_map_weights + )); + } +}; + +static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, + const torch::Tensor& kv_cache, + const torch::Tensor& kv_cache_scales, + const torch::Tensor& weights, + const torch::Tensor& context_lens, + const torch::Tensor& logits, + const torch::Tensor& block_table, + const torch::Tensor& schedule_meta, + const int& batch_size, const int& next_n, + const int& num_heads, const int& head_dim, + const int& num_kv_blocks, const int& block_kv, + const bool& is_context_lens_2d, + const int& kv_cache_stride_bytes, + const int& logits_stride, + const int& block_table_stride, + const int& num_sms, + const int& split_kv) { + const int num_specialized_threads = 128; + const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64); + const int num_math_warp_groups = split_kv / mma_m; + const int num_math_threads = num_math_warp_groups * 128; + const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3); + DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0); + + // Construct TMAs + DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); + const auto& tensor_map_q = make_tma_2d_desc(q, head_dim, batch_size * next_n * num_heads, + head_dim, next_n * num_heads, head_dim, head_dim); + const auto& tensor_map_kv = make_tma_3d_desc(kv_cache, head_dim, block_kv, num_kv_blocks, + head_dim, block_kv, 1, + head_dim, kv_cache_stride_bytes, head_dim); + // TODO: use 1D TMA + const auto& tensor_map_kv_scales = make_tma_2d_desc(kv_cache_scales, block_kv, num_kv_blocks, + block_kv, 1, kv_cache_stride_bytes / static_cast(sizeof(float)), 0); + const auto& tensor_map_weights = make_tma_2d_desc(weights, next_n * num_heads, batch_size, + next_n * num_heads, 1, next_n * num_heads, 0); + + // Calculate shared memory size + int smem_size = 0; + if (device_runtime->get_arch_major() == 9) { + const int swizzle_alignment = head_dim * 8; + + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); + const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); + + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); + const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); + const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); + + // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + } else { + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int smem_kv_size_per_stage = split_kv * head_dim * static_cast(kv_cache.element_size()); + const int smem_kv_scale_size_per_stage = split_kv * static_cast(kv_cache_scales.element_size()); + const int smem_weight_size_per_stage = next_n * num_heads * static_cast(weights.element_size()); + + const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8; + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; + + smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) + + smem_barriers + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + } + + // Launch + const SMXXFP8PagedMQALogitsRuntime::Args& args = { + .batch_size = batch_size, + .next_n = next_n, + .num_heads = num_heads, + .head_dim = head_dim, + .block_kv = block_kv, + .is_context_lens_2d = is_context_lens_2d, + .block_table_stride = block_table_stride, + .logits_stride = logits_stride, + .num_q_stages = num_q_stages, + .num_kv_stages = num_kv_stages, + .split_kv = split_kv, + .context_lens = context_lens.data_ptr(), + .logits = logits.data_ptr(), + .block_table = block_table.data_ptr(), + .schedule_meta = schedule_meta.data_ptr(), + .tensor_map_q = tensor_map_q, + .tensor_map_kv = tensor_map_kv, + .tensor_map_kv_scales = tensor_map_kv_scales, + .tensor_map_weights = tensor_map_weights, + .num_specialized_threads = num_specialized_threads, + .num_math_threads = num_math_threads, + .launch_args = LaunchArgs(num_sms, + num_specialized_threads + num_math_threads, + smem_size) + }; + const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args); + const auto& runtime = compiler->build("smxx_fp8_paged_mqa_logits", code); + SMXXFP8PagedMQALogitsRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_layout.hpp b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_layout.hpp new file mode 100644 index 00000000..5d1f17b5 --- /dev/null +++ b/third_party/DeepGEMM/csrc/jit_kernels/impls/smxx_layout.hpp @@ -0,0 +1,267 @@ +#pragma once + +#include + +#include "../../jit/kernel_runtime.hpp" +#include "../../jit/compiler.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../../utils/layout.hpp" + +namespace deep_gemm { + +class TransposeFP32Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_fp32< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); + } +}; + +class TransposeAndPackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int mn, sf_k; + int block_mn; + void *sf, *out; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&transpose_and_pack_fp32_into_ue8m0< + {}, {}, {} + >); +}}; +)", args.launch_args.num_threads, args.block_mn, args.sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, args.sf, args.out, static_cast(args.mn))); + } +}; + +class PackFP32IntoUE8M0Runtime final: public LaunchRuntime { +public: + struct Args { + int num_groups, mn, sf_k, packed_sf_k, gran_k; + int block_mn, block_packed_sf_k; + void *sf, *out, *ks; + + LaunchArgs launch_args; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&pack_fp32_into_ue8m0< + {}, {}, {}, {} + >); +}}; +)", args.num_groups, args.launch_args.num_threads, args.block_mn, args.block_packed_sf_k); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.sf, args.out, args.ks, args.mn, args.sf_k, args.packed_sf_k, args.gran_k)); + } +}; + +static std::tuple preprocess_sf(const torch::Tensor& sf) { + // NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA + const auto dim = sf.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + DG_HOST_ASSERT(sf.scalar_type() == torch::kFloat); + const auto batched_sf = dim == 2 ? sf.unsqueeze(0) : sf; + + const auto [num_groups, mn, sf_k] = get_shape<3>(batched_sf); + const auto tma_aligned_mn = get_tma_aligned_size(mn, static_cast(sf.element_size())); + return {dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf}; +} + +static torch::Tensor get_mn_major_tma_aligned_tensor(const torch::Tensor& sf) { + const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + + // The last kernel already gives a column-major TMA aligned layout + if ((batched_sf.stride(0) == tma_aligned_mn * sf_k or dim == 2) and batched_sf.stride(1) == 1 and batched_sf.stride(2) == tma_aligned_mn) + return (dim == 2) ? batched_sf.squeeze(0) : batched_sf; + + const auto out = torch::empty_strided({num_groups, mn, sf_k}, + {tma_aligned_mn * sf_k, 1, tma_aligned_mn}, + batched_sf.options()); + + if (not batched_sf.is_contiguous()) { + // Fallback to PyTorch's slow copy if not contiguous + // ReSharper disable once CppExpressionWithoutSideEffects + out.copy_(batched_sf); + } else { + constexpr int block_mn = 64; + constexpr int num_threads = 512; + const auto smem_size = block_mn * (sf_k + (1 - (sf_k % 2))) * static_cast(sizeof(float)); + const TransposeFP32Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, smem_size) + }; + + const auto code = TransposeFP32Runtime::generate(args); + const auto runtime = compiler->build("transpose_fp32", code); + TransposeFP32Runtime::launch(runtime, args); + } + return (dim == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(const torch::Tensor& sf) { + const auto sf_reshaped = (sf.dim() == 2) ? sf.unsqueeze(0) : sf; + + // First, convert into UE8M0 `uint8_t` + const auto ue8m0_tensor = sf_reshaped.view(torch::kInt32).bitwise_right_shift(23).to(torch::kUInt8); + + // Second, make padded packed tensors + const auto [num_groups, mn, k] = get_shape<3>(sf_reshaped); + const auto aligned_mn = get_tma_aligned_size(mn, 4); + const auto aligned_k = align(k, 4); + + const auto options = torch::TensorOptions().device(sf.device()).dtype(torch::kUInt8); + auto padded = torch::zeros({num_groups, aligned_mn, aligned_k}, options); + // ReSharper disable once CppExpressionWithoutSideEffects + padded.slice(1, 0, mn).slice(2, 0, k).copy_(ue8m0_tensor); + padded = padded.view(-1).view(torch::kInt32).view({num_groups, aligned_mn, aligned_k / 4}); + + // Finally, transpose + auto out = torch::empty_strided({num_groups, aligned_mn, aligned_k / 4}, + {aligned_mn * (aligned_k / 4), 1, aligned_mn}, + at::TensorOptions().device(sf.device()).dtype(torch::kInt32)); + out = out.copy_(padded).slice(1, 0, mn); + return (sf.dim() == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf) { + const auto [dim, num_groups, mn, sf_k, tma_aligned_mn, batched_sf] = preprocess_sf(sf); + const auto packed_sf_k = ceil_div(sf_k, 4); + const auto out = torch::empty_strided({num_groups, mn, packed_sf_k}, + {packed_sf_k * tma_aligned_mn, 1, tma_aligned_mn}, + at::TensorOptions().device(batched_sf.device()).dtype(torch::kInt)); + // Launch the kernel + if (batched_sf.is_contiguous()) { + if ((mn * sf_k) % 4 != 0 and num_groups > 1) + return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + + constexpr int block_mn = 48; + constexpr int num_threads = 512; + const TransposeAndPackFP32IntoUE8M0Runtime::Args& args = { + .mn = mn, + .sf_k = sf_k, + .block_mn = block_mn, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), num_groups}, num_threads, block_mn * sf_k * 4) + }; + + const auto code = TransposeAndPackFP32IntoUE8M0Runtime::generate(args); + const auto runtime = compiler->build("transpose_and_pack_fp32_into_ue8m0", code); + TransposeAndPackFP32IntoUE8M0Runtime::launch(runtime, args); + } else { + if (mn % 4 != 0 or num_groups > 1) + return get_mn_major_tma_aligned_packed_ue8m0_tensor_torch(sf); + DG_HOST_ASSERT(batched_sf.stride(1) == 1 and batched_sf.stride(2) == mn); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = 1, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = batched_sf.data_ptr(), + .out = out.data_ptr(), + .ks = nullptr, + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto code = PackFP32IntoUE8M0Runtime::generate(args); + const auto runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + } + return (dim == 2) ? out.squeeze(0) : out; +} + +static torch::Tensor get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(const torch::Tensor& sf, + const torch::Tensor& ks_tensor, + const std::vector& ks, + const int gran_k) { + DG_HOST_ASSERT(gran_k == 32 or gran_k == 128); + const auto [sf_k, mn] = get_shape<2>(sf); + const auto num_groups = static_cast(ks.size()); + + int ref_sf_k = 0, packed_sf_k = 0; + for (const auto k: ks) + ref_sf_k += ceil_div(k, gran_k), packed_sf_k += ceil_div(k, gran_k * 4); + DG_HOST_ASSERT(sf.is_contiguous()); + DG_HOST_ASSERT(ref_sf_k == sf_k); + DG_HOST_ASSERT(num_groups <= 128 and mn % 4 == 0); + + const auto out = torch::empty({packed_sf_k, mn}, at::TensorOptions().device(sf.device()).dtype(torch::kInt)); + + constexpr int block_mn = 128; + constexpr int block_packed_sf_k = 16; + constexpr int num_threads = 512; + const PackFP32IntoUE8M0Runtime::Args& args = { + .num_groups = num_groups, + .mn = mn, + .sf_k = sf_k, + .packed_sf_k = packed_sf_k, + .gran_k = gran_k, + .block_mn = block_mn, + .block_packed_sf_k = block_packed_sf_k, + .sf = sf.data_ptr(), + .out = out.data_ptr(), + .ks = ks_tensor.data_ptr(), + .launch_args = LaunchArgs({ceil_div(mn, block_mn), ceil_div(packed_sf_k, block_packed_sf_k)}, num_threads) + }; + + const auto code = PackFP32IntoUE8M0Runtime::generate(args); + const auto runtime = compiler->build("pack_fp32_into_ue8m0", code); + PackFP32IntoUE8M0Runtime::launch(runtime, args); + return out; +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/python_api.cpp b/third_party/DeepGEMM/csrc/python_api.cpp new file mode 100644 index 00000000..a966afe1 --- /dev/null +++ b/third_party/DeepGEMM/csrc/python_api.cpp @@ -0,0 +1,28 @@ +#include +#include + +#include "apis/attention.hpp" +#include "apis/einsum.hpp" +#include "apis/hyperconnection.hpp" +#include "apis/gemm.hpp" +#include "apis/layout.hpp" +#include "apis/mega.hpp" +#include "apis/runtime.hpp" + +#ifndef TORCH_EXTENSION_NAME +#define TORCH_EXTENSION_NAME _C +#endif + +// ReSharper disable once CppParameterMayBeConstPtrOrRef +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "DeepGEMM C++ library"; + + // TODO: make SM80 incompatible issues raise errors + deep_gemm::attention::register_apis(m); + deep_gemm::einsum::register_apis(m); + deep_gemm::hyperconnection::register_apis(m); + deep_gemm::gemm::register_apis(m); + deep_gemm::layout::register_apis(m); + deep_gemm::mega::register_apis(m); + deep_gemm::runtime::register_apis(m); +} diff --git a/third_party/DeepGEMM/csrc/utils/compatibility.hpp b/third_party/DeepGEMM/csrc/utils/compatibility.hpp new file mode 100644 index 00000000..9e2d6720 --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/compatibility.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include +#include +#include + +// `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1 +#define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1)) + +// `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1 +#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010) + +// `cublasGetErrorString` is supported since CUDA Runtime API 11.4.2 +#define DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE (CUDART_VERSION >= 11042) + +// `CUBLASLT_MATMUL_DESC_FAST_ACCUM` and `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` are supported since CUDA Runtime API 11.8 +#define DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE (CUDART_VERSION >= 11080) \ No newline at end of file diff --git a/third_party/DeepGEMM/csrc/utils/exception.hpp b/third_party/DeepGEMM/csrc/utils/exception.hpp new file mode 100644 index 00000000..417dd3b4 --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/exception.hpp @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include + +#include "compatibility.hpp" + +namespace deep_gemm { + +class DGException final : public std::exception { + std::string message = {}; + +public: + explicit DGException(const char *name, const char* file, const int line, const std::string& error) { + message = std::string(name) + " error (" + file + ":" + std::to_string(line) + "): " + error; + } + + const char *what() const noexcept override { + return message.c_str(); + } +}; + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_HOST_ASSERT +#define DG_HOST_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + throw DGException("Assertion", __FILE__, __LINE__, #cond); \ + } \ +} while (0) +#endif + +#ifndef DG_HOST_UNREACHABLE +#define DG_HOST_UNREACHABLE(reason) (throw DGException("Assertion", __FILE__, __LINE__, reason)) +#endif + +#ifndef DG_NVRTC_CHECK +#define DG_NVRTC_CHECK(cmd) \ +do { \ + const auto e = (cmd); \ + if (e != NVRTC_SUCCESS) { \ + throw DGException("NVRTC", __FILE__, __LINE__, nvrtcGetErrorString(e)); \ + } \ +} while (0) +#endif + +#ifndef DG_CUDA_DRIVER_CHECK +#define DG_CUDA_DRIVER_CHECK(cmd) \ +do { \ + const auto e = (cmd); \ + if (e != CUDA_SUCCESS) { \ + std::stringstream ss; \ + const char *name, *info; \ + lazy_cuGetErrorName(e, &name), lazy_cuGetErrorString(e, &info); \ + ss << static_cast(e) << " (" << name << ", " << info << ")"; \ + throw DGException("CUDA driver", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +#ifndef DG_CUDA_RUNTIME_CHECK +#define DG_CUDA_RUNTIME_CHECK(cmd) \ +do { \ + const auto e = (cmd); \ + if (e != cudaSuccess) { \ + std::stringstream ss; \ + ss << static_cast(e) << " (" << cudaGetErrorName(e) << ", " << cudaGetErrorString(e) << ")"; \ + throw DGException("CUDA runtime", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +#ifndef DG_CUBLASLT_CHECK + +#if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE +inline const char* cublasGetStatusString(cublasStatus_t status) { + switch(status) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + default: return "Unknown cuBLAS error"; + } +} +#endif + +#define DG_CUBLASLT_CHECK(cmd) \ +do { \ + const auto e = (cmd); \ + if (e != CUBLAS_STATUS_SUCCESS) { \ + std::ostringstream ss; \ + ss << static_cast(e) << " (" << cublasGetStatusString(e) << ")"; \ + throw DGException("cuBLASLt", __FILE__, __LINE__, ss.str()); \ + } \ +} while (0) +#endif + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/utils/format.hpp b/third_party/DeepGEMM/csrc/utils/format.hpp new file mode 100644 index 00000000..bf617372 --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/format.hpp @@ -0,0 +1,6 @@ +#pragma once + +// Just a wrapper for the `fmt` headers +#define FMT_HEADER_ONLY +#include +#include diff --git a/third_party/DeepGEMM/csrc/utils/hash.hpp b/third_party/DeepGEMM/csrc/utils/hash.hpp new file mode 100644 index 00000000..9efe6408 --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/hash.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include + +namespace deep_gemm { + +static uint64_t fnv1a(const std::vector& data, const uint64_t& seed) { + uint64_t h = seed; + const uint64_t prime = 0x100000001b3ull; + for (const char& c: data) { + h ^= static_cast(c); + h *= prime; + } + return h; +} + +static std::string get_hex_digest(const std::vector& data) { + const auto state_0 = fnv1a(data, 0xc6a4a7935bd1e995ull); + const auto state_1 = fnv1a(data, 0x9e3779b97f4a7c15ull); + + // Split-mix 64 + const auto split_mix = [](uint64_t z) { + z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9ull; + z = (z ^ (z >> 27)) * 0x94d049bb133111ebull; + return z ^ (z >> 31); + }; + + std::ostringstream oss; + oss << std::hex << std::setfill('0') + << std::setw(16) << split_mix(state_0) + << std::setw(16) << split_mix(state_1); + return oss.str(); +} + +static std::string get_hex_digest(const std::string& data) { + return get_hex_digest(std::vector{data.begin(), data.end()}); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/utils/layout.hpp b/third_party/DeepGEMM/csrc/utils/layout.hpp new file mode 100644 index 00000000..928472d3 --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/layout.hpp @@ -0,0 +1,119 @@ +#pragma once + +#include +#include + +#include "math.hpp" +#include "exception.hpp" +#include "../jit/device_runtime.hpp" + +namespace deep_gemm { + +// Major-ness stuffs +static void major_check(const torch::Tensor& t) { + const auto dim = t.dim(); + DG_HOST_ASSERT(dim == 2 or dim == 3); + if (dim == 3) + DG_HOST_ASSERT(t.stride(0) == t.size(-2) * t.size(-1)); + DG_HOST_ASSERT(t.stride(-2) == 1 or t.stride(-1) == 1); +} + +static cute::UMMA::Major get_major_type_ab(const torch::Tensor& t) { + major_check(t); + return t.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; +} + +static void check_major_type_cd(const torch::Tensor& t) { + // NOTES: the library only supports row-major output layouts + major_check(t); + DG_HOST_ASSERT(t.stride(-1) == 1); +} + +static bool fp8_requires_k_major() { + return device_runtime->get_arch_major() == 9; +} + +// Tensor utils +template +static auto get_shape(const torch::Tensor& t) { + DG_HOST_ASSERT(t.dim() == N); + return [&t] (std::index_sequence) { + return std::make_tuple(static_cast(t.sizes()[Is])...); + }(std::make_index_sequence()); +} + +static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [mn, k] = get_shape<2>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); + } + return std::make_tuple(mn, k); +} + +static std::tuple check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [num_groups, mn, k] = get_shape<3>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); + } + return std::make_tuple(num_groups, mn, k); +} + +// Recipe +static std::tuple +get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) { + const auto arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat); + return {1, 128, 128}; + } else if (arch_major == 10) { + DG_HOST_ASSERT(sfb_dtype == torch::kFloat or sfb_dtype == torch::kInt); + return sfb_dtype == torch::kFloat ? + std::make_tuple(1, 128, 128): // Legacy format + std::make_tuple(1, 1, 128); // 1D1D kernels + } + DG_HOST_UNREACHABLE("Unknown recipe"); +} + +// SF layouts +static torch::Tensor check_sf_layout(const torch::Tensor& sf, + const int& mn, const int& k, + const int& gran_mn, const int& gran_k, + const std::optional& num_groups, + const bool& tma_stride_check = false, + const bool& sm90_sfb_check = false, + const std::optional& type_check = std::nullopt) { + // Type check + if (type_check.has_value()) + DG_HOST_ASSERT(sf.scalar_type() == type_check.value()); + + // Always do shape checks + const auto sf_dtype = sf.scalar_type(); + DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt); + DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2); + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.size(-3) == num_groups.value()); + DG_HOST_ASSERT(sf.size(-2) == ceil_div(mn, gran_mn)); + DG_HOST_ASSERT(sf.size(-1) == ceil_div(k, gran_k * (sf_dtype == torch::kFloat ? 1 : 4))); + + // TMA stride checks: TMA aligned and MN-major + if (tma_stride_check) { + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.stride(-3) == sf.stride(-1) * sf.size(-1)); + // Check contiguity in the MN direction + DG_HOST_ASSERT(sf.stride(-2) == 1 or mn == 1); + DG_HOST_ASSERT(sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())); + } + + // SM90 SFB must be contiguous, or contiguous after transposing the last two dimensions + if (sm90_sfb_check) { + if (num_groups.has_value()) + DG_HOST_ASSERT(sf.stride(-3) == sf.size(-2) * sf.size(-1)); + DG_HOST_ASSERT((sf.stride(-1) == 1 and sf.stride(-2) == sf.size(-1)) or + (sf.stride(-1) == sf.size(-2) and sf.stride(-2) == 1)); + } + return sf; +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/utils/lazy_init.hpp b/third_party/DeepGEMM/csrc/utils/lazy_init.hpp new file mode 100644 index 00000000..386b1b45 --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/lazy_init.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +#define DG_DECLARE_STATIC_VAR_IN_CLASS(cls, name) decltype(cls::name) cls::name + +namespace deep_gemm { + +template +class LazyInit { +public: + explicit LazyInit(std::function()> factory) + : factory(std::move(factory)) {} + + T* operator -> () { + if (ptr == nullptr) + ptr = factory(); + return ptr.get(); + } + +private: + std::shared_ptr ptr; + std::function()> factory; +}; + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/utils/math.hpp b/third_party/DeepGEMM/csrc/utils/math.hpp new file mode 100644 index 00000000..19a86c38 --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/math.hpp @@ -0,0 +1,29 @@ +// TODO: merge this file with `math.cuh` (the device part) +#pragma once + +#include + +#include "exception.hpp" + +namespace deep_gemm { + +// TODO: use `torch::kFloat4_e2m1fn_x2` +constexpr auto kPackedFP4 = torch::kInt8; + +template +static T ceil_div(const T& a, const T& b) { + return (a + b - 1) / b; +} + +template +static constexpr T align(const T& a, const T& b) { + return ceil_div(a, b) * b; +} + +static int get_tma_aligned_size(const int& x, const int& element_size) { + constexpr int kNumTMAAlignmentBytes = 16; + DG_HOST_ASSERT(kNumTMAAlignmentBytes % element_size == 0); + return align(x, kNumTMAAlignmentBytes / element_size); +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/csrc/utils/system.hpp b/third_party/DeepGEMM/csrc/utils/system.hpp new file mode 100644 index 00000000..fda020be --- /dev/null +++ b/third_party/DeepGEMM/csrc/utils/system.hpp @@ -0,0 +1,128 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "exception.hpp" +#include "format.hpp" + +namespace deep_gemm { + +// ReSharper disable once CppNotAllPathsReturnValue +template +static dtype_t get_env(const std::string& name, const dtype_t& default_value = dtype_t()) { + const auto c_str = std::getenv(name.c_str()); + if (c_str == nullptr) + return default_value; + + // Read the env and convert to the desired type + if constexpr (std::is_same_v) { + return std::string(c_str); + } else if constexpr (std::is_same_v) { + int value; + std::sscanf(c_str, "%d", &value); + return value; + } else { + DG_HOST_ASSERT(false and "Unexpected type"); + } +} + +static std::tuple call_external_command(std::string command) { + command = command + " 2>&1"; + const auto deleter = [](FILE* f) { if (f) pclose(f); }; + std::unique_ptr pipe(popen(command.c_str(), "r"), deleter); + DG_HOST_ASSERT(pipe != nullptr); + + std::array buffer; + std::string output; + while (fgets(buffer.data(), buffer.size(), pipe.get())) + output += buffer.data(); + const auto status = pclose(pipe.release()); + // NOTES: if the child was killed by a signal (e.g., SIGINT from Ctrl+C), + // WEXITSTATUS would incorrectly return 0. Treat signal death as failure. + const auto exit_code = WIFEXITED(status) ? WEXITSTATUS(status) : 128 + WTERMSIG(status); + return {exit_code, output}; +} + +static std::vector collect_files(const std::filesystem::path& root) { + std::vector files; + std::function impl; + impl = [&](const std::filesystem::path& dir) { + for (const auto& entry: std::filesystem::directory_iterator(dir)) { + if (entry.is_directory()) { + impl(entry.path()); + } else if (entry.is_regular_file() and entry.path().extension() == ".cuh") { + files.emplace_back(entry.path()); + } + } + }; + impl(root); + + // Be consistent + std::sort(files.begin(), files.end()); + return files; +} + +static std::filesystem::path make_dirs(const std::filesystem::path& path) { + // OK if existed + std::error_code capture; + const bool created = std::filesystem::create_directories(path, capture); + if (not (created or capture.value() == 0)) { + DG_HOST_UNREACHABLE(fmt::format("Failed to make directory: {}, created: {}, value: {}", + path.c_str(), created, capture.value())); + } + if (created and get_env("DG_JIT_DEBUG")) + printf("Create directory: %s\n", path.c_str()); + return path; +} + +static std::string get_uuid() { + static std::random_device rd; + static std::mt19937 gen([]() { + return rd() ^ std::chrono::steady_clock::now().time_since_epoch().count(); + }()); + static std::uniform_int_distribution dist; + + std::stringstream ss; + ss << getpid() << "-" + << std::hex << std::setfill('0') + << std::setw(8) << dist(gen) << "-" + << std::setw(8) << dist(gen) << "-" + << std::setw(8) << dist(gen); + return ss.str(); +} + +static void safe_remove_all(const std::filesystem::path& path) { + std::error_code ec; + if (not std::filesystem::exists(path, ec) or ec) + return; + + // A single file + if (not std::filesystem::is_directory(path, ec) or ec) { + std::filesystem::remove(path, ec); + return; + } + + // Remove directory + auto it = std::filesystem::directory_iterator(path, + std::filesystem::directory_options::skip_permission_denied, ec); + for (auto end = std::filesystem::directory_iterator(); it != end and not ec;) { + const auto entry_path = it->path(); + + // Increase firstly to avoid failures + it.increment(ec); + if (ec) + break; + + // Recursively clean + safe_remove_all(entry_path); + } + std::filesystem::remove(path, ec); +} + +} // deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/__init__.py b/third_party/DeepGEMM/deep_gemm/__init__.py new file mode 100644 index 00000000..a9542e2f --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/__init__.py @@ -0,0 +1,126 @@ +import os +import subprocess +import torch + +# Set some default environment provided at setup +try: + # noinspection PyUnresolvedReferences + from .envs import persistent_envs + for key, value in persistent_envs.items(): + if key not in os.environ: + os.environ[key] = value +except ImportError: + pass + +# Configs +from . import _C +from ._C import ( + set_num_sms, + get_num_sms, + set_tc_util, + get_tc_util, + set_ignore_compile_dims, + set_block_size_multiple_of, + set_pdl, + get_pdl, +) + +# cuBLASLt Kernels +from ._C import ( + cublaslt_gemm_nt, cublaslt_gemm_nn, + cublaslt_gemm_tn, cublaslt_gemm_tt, +) + +try: + # DeepGEMM Kernels + from ._C import ( + # FP8 FP4 GEMMs + fp8_fp4_gemm_nt, fp8_fp4_gemm_nn, + fp8_fp4_gemm_tn, fp8_fp4_gemm_tt, + m_grouped_fp8_fp4_gemm_nt_contiguous, + m_grouped_fp8_fp4_gemm_nn_contiguous, + m_grouped_fp8_fp4_gemm_nt_masked, + # FP8 GEMMs + fp8_gemm_nt, fp8_gemm_nn, + fp8_gemm_tn, fp8_gemm_tt, + fp8_gemm_nt_skip_head_mid, + m_grouped_fp8_gemm_nt_contiguous, + m_grouped_fp8_gemm_nn_contiguous, + m_grouped_fp8_gemm_nt_masked, + k_grouped_fp8_gemm_nt_contiguous, + k_grouped_fp8_gemm_tn_contiguous, + # BF16 GEMMs + bf16_gemm_nt, bf16_gemm_nn, + bf16_gemm_tn, bf16_gemm_tt, + m_grouped_bf16_gemm_nt_contiguous, + m_grouped_bf16_gemm_nn_contiguous, + m_grouped_bf16_gemm_nt_masked, + k_grouped_bf16_gemm_tn_contiguous, + # Einsum kernels + einsum, + fp8_einsum, + # Attention kernels + fp8_fp4_mqa_logits, + get_paged_mqa_logits_metadata, + fp8_fp4_paged_mqa_logits, + # Attention kernels (legacy) + fp8_mqa_logits, + fp8_paged_mqa_logits, + # Hyperconnection kernels + tf32_hc_prenorm_gemm, + # Layout kernels + transform_sf_into_required_layout, + ) + + # Some alias for legacy supports + # TODO: remove these later + fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked + bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass + +# Mega kernels +from .mega import ( + SymmBuffer, + get_symm_buffer_for_mega_moe, + transform_weights_for_mega_moe, + fp8_fp4_mega_moe, +) + +# Some utils +from . import testing +from . import utils +from .utils import * + +# Legacy Triton kernels for A100 +try: + from . import legacy +except Exception as e: + print(f'Failed to load legacy DeepGEMM A100 Triton kernels: {e}') + +# Initialize CPP modules +def _find_cuda_home() -> str: + # TODO: reuse PyTorch API later + # For some PyTorch versions, the original `_find_cuda_home` will initialize CUDA, which is incompatible with process forks + cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') + if cuda_home is None: + # noinspection PyBroadException + try: + with open(os.devnull, 'w') as devnull: + nvcc = subprocess.check_output(['which', 'nvcc'], stderr=devnull).decode().rstrip('\r\n') + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + cuda_home = '/usr/local/cuda' + if not os.path.exists(cuda_home): + cuda_home = None + assert cuda_home is not None + return cuda_home + + +_C.init( + os.path.dirname(os.path.abspath(__file__)), # Library root directory path + _find_cuda_home() # CUDA home +) + +__version__ = '2.5.0' diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/comm/barrier.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/comm/barrier.cuh new file mode 100644 index 00000000..eb9858d8 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/comm/barrier.cuh @@ -0,0 +1,83 @@ +#pragma once + +#include + +#include +#include +#include + +namespace deep_gemm::comm { + +CUTLASS_DEVICE void cluster_sync_with_relaxed_arrive() { + // Perform cluster_sync with `barrier.cluster.arrive.relaxed` + // This is slightly faster than `cute::cluster_sync` but has weaker memory ordering guarantee + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); +} + +template +CUTLASS_DEVICE void grid_sync(const layout::Workspace& workspace, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope) { + // NOTES: the implementation idea is from `cooperative_groups::this_grid().sync()` + static constexpr uint32_t kFinishSumTag = 0x80000000u; + sync_scope(); + if (thread_idx == 0) { + const auto count_ptr = workspace.get_grid_sync_count_ptr(); + const auto old_value = ptx::atomic_add_rel( + count_ptr, sm_idx == 0 ? (kFinishSumTag - (kNumSMs - 1)) : 1); + uint32_t new_value; + do { + new_value = ptx::ld_acq(count_ptr); + } while (((new_value ^ old_value) & kFinishSumTag) == 0); + } + sync_scope(); +} + +template +CUTLASS_DEVICE void nvlink_barrier(const layout::Workspace& workspace, + const layout::SymBuffer& sym_buffer, + const uint32_t& sm_idx, const uint32_t& thread_idx, + const sync_scope_t& sync_scope, + const bool& sync_prologue = true, + const bool& sync_epilogue = true) { + DG_STATIC_ASSERT(kNumRanks <= kNumThreads, "Insufficient threads"); + + // Grid sync before NVLink signaling + if (sync_prologue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); + + // NVLink cross-rank barrier, only SM 0 participates + if (sm_idx == 0) { + auto* counter_ptr = workspace.get_nvl_barrier_counter_ptr(); + const auto status = (*counter_ptr) & 3; + const auto signal_phase = status & 1, signal_sign = status >> 1; + auto* signal_ptr = workspace.get_nvl_barrier_signal_ptr(signal_phase); + + // Send signals to remote ranks + if (thread_idx < kNumRanks) + ptx::red_add_rel_sys(sym_buffer.map(signal_ptr, thread_idx), signal_sign ? -1 : 1); + sync_scope(); + + // Update status and wait arrival (with 30s timeout, at 2 GHz) + constexpr int64_t kNumTimeoutCycles = 30ll * 2000000000ll; + if (thread_idx == 0) { + ptx::red_add(counter_ptr, 1); + const int target = signal_sign ? 0 : static_cast(kNumRanks); + const auto start_clock = clock64(); + while (ptx::ld_acq_sys(signal_ptr) != target) { + if (clock64() - start_clock >= kNumTimeoutCycles) { + printf("DeepGEMM NVLink barrier timeout (30s): rank=%d, counter=%d, signal=%d, target=%d, phase=%d, sign=%d, tag=%d\n", + sym_buffer.rank_idx, *counter_ptr, ptx::ld_acq_sys(signal_ptr), target, signal_phase, signal_sign, kTag); + DG_DEVICE_ASSERT(false and "NVLink barrier timeout"); + } + } + } + } + + // Grid sync after NVLink completion + if (sync_epilogue) + grid_sync(workspace, sm_idx, thread_idx, sync_scope); +} + +} // namespace deep_gemm::comm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/compile.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/compile.cuh new file mode 100644 index 00000000..e93c43fb --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/compile.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include + +#if defined(__NVCC__) or (defined(__clang__) and defined(__CUDA__)) or defined(__CUDACC_RTC__) or defined(__CLION_IDE__) +#define DG_IN_CUDA_COMPILATION +#endif + +#if defined(__NVCC__) || (defined(__clang__) and defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ __host__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE_NOINLINE __device__ +#define CUTLASS_DEVICE_NOINLINE __device__ +#else +#define CUTLASS_HOST_DEVICE_NOINLINE +#define CUTLASS_DEVICE_NOINLINE +#endif diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/cute_tie.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/cute_tie.cuh new file mode 100644 index 00000000..a3a8b62a --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/cute_tie.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include + +namespace cute { + +struct ignore_t { + template + constexpr const ignore_t& operator=(T&&) const noexcept { + return *this; + } +}; + +inline constexpr ignore_t ignore{}; + +} // namespace cute + +#define CUTE_TIE_CONCAT_IMPL(A, B) A##B +#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B) + +#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define CUTE_TIE_COUNT_ARGS(...) \ + CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get(TUPLE) +#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get(TUPLE) + +#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1); +#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2); +#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); +#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); +#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5); + +#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_DECL, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ) + +#define CUTE_TIE(TUPLE_EXPR, ...) \ + do { \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_ASSIGN, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ); \ + } while (0) diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh new file mode 100644 index 00000000..5f6a7a19 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/epilogue_utils.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +struct EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 + and kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/exception.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/exception.cuh new file mode 100644 index 00000000..78acf747 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/exception.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include +#include + +#ifdef __CLION_IDE__ + +CUTLASS_HOST_DEVICE void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +#ifndef DG_UNIFIED_ASSERT +#ifdef DG_IN_CUDA_COMPILATION +#define DG_UNIFIED_ASSERT(cond) DG_DEVICE_ASSERT(cond) +#else +#define DG_UNIFIED_ASSERT(cond) DG_HOST_ASSERT(cond) +#endif +#endif diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/math.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/math.cuh new file mode 100644 index 00000000..0f0d2504 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/math.cuh @@ -0,0 +1,149 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::math { + +/// Pointer operations +template +CUTLASS_HOST_DEVICE dtype_t* advance_ptr(void* ptr, const uint64_t num_bytes) { + return reinterpret_cast(static_cast(ptr) + num_bytes); +} + +/// Math functions +template +CUTLASS_HOST_DEVICE T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +CUTLASS_HOST_DEVICE T align(T a, T b) { + return (kDoCeilAlignment ? ceil_div(a, b) : (a / b)) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +CUTLASS_HOST_DEVICE constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; +} + +template +CUTLASS_DEVICE void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +#ifdef DG_IN_CUDA_COMPILATION +CUTLASS_DEVICE float2 fma2(const float2& a, const float2& b, const float2& c) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + return __ffma2_rn(a, b, c); +#else + return make_float2( + __fmaf_rn(a.x, b.x, c.x), + __fmaf_rn(a.y, b.y, c.y) + ); +#endif +} + +CUTLASS_HOST_DEVICE float fast_rcp(const float& x) { + float ret; + asm volatile("rcp.approx.ftz.f32 %0, %1;" : "=f"(ret) : "f"(x)); + return ret; +} + +/// Casting +template +CUTLASS_DEVICE int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +CUTLASS_DEVICE float fast_pow2(const int& x) { + uint32_t bits_x = (x + 127) << 23; + return *reinterpret_cast(&bits_x); +} + +CUTLASS_DEVICE int fast_log2_ceil(float x) { + const auto bits = *reinterpret_cast(&x); + const auto exp = bits >> 23; + const auto man = bits & ((1 << 23) - 1); + return exp - 127 + (man != 0); +} + +template +CUTLASS_DEVICE void get_e4m3_sf_and_sf_inv(const float2& amax, float2& sf, float2& sf_inv) { + DG_STATIC_ASSERT(kUseUE8M0, "Must use UE8M0"); + const float2 finfo_factor = {1.0 / 448.0, 1.0 / 448.0}; + const auto scaled = __fmul2_rn(amax, finfo_factor); + const auto exp_x = fast_log2_ceil(scaled.x); + const auto exp_y = fast_log2_ceil(scaled.y); + sf.x = fast_pow2(exp_x), sf_inv.x = fast_pow2(-exp_x); + sf.y = fast_pow2(exp_y), sf_inv.y = fast_pow2(-exp_y); +} + +/// Reduction +CUTLASS_DEVICE uint32_t warp_inclusive_sum(uint32_t value, const uint32_t& lane_idx) { + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t synced = __shfl_up_sync(0xffffffff, value, offset); + if (lane_idx >= offset) + value += synced; + } + return value; +} + +// Operation functors +template struct ReduceSum { CUTLASS_DEVICE T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { CUTLASS_DEVICE T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { CUTLASS_DEVICE T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { CUTLASS_DEVICE T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { CUTLASS_DEVICE T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +CUTLASS_DEVICE T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +CUTLASS_DEVICE T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} +#endif + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/reduction.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/reduction.cuh new file mode 100644 index 00000000..d9e35f73 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/reduction.cuh @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include + +#include + +// Operation functors +template struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? a : b; } }; +template struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +__forceinline__ __device__ T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +__forceinline__ __device__ T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/scheduler.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/scheduler.cuh new file mode 100644 index 00000000..f93b96ee --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -0,0 +1,288 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? + candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for countiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = __ldg(grouped_layout + group_idx); + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k, + int* grouped_layout = nullptr) { + num_m_blocks = ceil_div(shape_m, BLOCK_M); + num_n_blocks = ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = __ldg(grouped_layout); + num_m_blocks = ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto& secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto& group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + current_group_idx)), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = align(current_psum_m, 128u); + current_psum_m = __ldg(grouped_layout + current_group_idx); + current_m_block_cumsum += num_m_blocks; + num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with 128 + m_block_idx += last_psum_m / BLOCK_M; + DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M"); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto& block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/sm100_utils.cuh new file mode 100644 index 00000000..537cbe08 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -0,0 +1,266 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace deep_gemm::sm100 { + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + uint32_t stride_byte_offset, uint32_t leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +__device__ __forceinline__ +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +__device__ __forceinline__ +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +__device__ __forceinline__ +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +template +__device__ __forceinline__ +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto& layout_type = to_umma_layout_type(); + const auto& num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +__device__ __forceinline__ +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +template +__device__ constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if (kNumCols <= 32) return 32; + if (kNumCols <= 64) return 64; + if (kNumCols <= 128) return 128; + if (kNumCols <= 256) return 256; + return 512; +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +__device__ __forceinline__ +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +} // namespace `deep_gemm::sm100` diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/sm90_utils.cuh new file mode 100644 index 00000000..0874b675 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -0,0 +1,332 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm::sm90 { + +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + + template + __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +struct SM90_U32x2_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +__forceinline__ __device__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +__forceinline__ __device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +template +__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +__device__ __forceinline__ +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto& layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +} // namespace `deep_gemm::sm90` diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/tma_copy.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/tma_copy.cuh new file mode 100644 index 00000000..2c5bf708 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/tma_copy.cuh @@ -0,0 +1,92 @@ +#pragma once + +#include +#include +#include + +#include + +namespace deep_gemm::tma { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +CUTLASS_DEVICE void +copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +} // namespace deep_gemm::tma diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/tma_utils.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/tma_utils.cuh new file mode 100644 index 00000000..bd54adc2 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/tma_utils.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +// Tensormap related +__device__ __forceinline__ void tensor_map_release_cta() { + asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +} + +__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +} // namespace `deep_gemm` diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/types.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/types.cuh new file mode 100644 index 00000000..e07df0af --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/types.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr CUTLASS_HOST_DEVICE int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr CUTLASS_HOST_DEVICE bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/types.hpp b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/types.hpp new file mode 100644 index 00000000..410c5469 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/types.hpp @@ -0,0 +1,41 @@ +#pragma once + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/utils.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/utils.cuh new file mode 100644 index 00000000..3a5f7ad6 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/common/utils.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include + +namespace deep_gemm::utils { + +template +struct PatternVisitor { + FuncT func; + + CUTLASS_HOST_DEVICE + explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} + + CUTLASS_HOST_DEVICE + auto operator [](const uint32_t& i) const { + return func(i); + } +}; + +template +struct Vectorized { + static auto zeros() { + // TODO: add `ulonglong4` for SM100 once `__ldg` support this + if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) { + return make_uint4(0, 0, 0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) { + return make_uint2(0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) { + return 0; + } else { + DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization"); + } + } + + using vec_t = decltype(zeros()); +}; + +template +CUTLASS_DEVICE constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if constexpr (kNumCols <= 32) return 32; + if constexpr (kNumCols <= 64) return 64; + if constexpr (kNumCols <= 128) return 128; + if constexpr (kNumCols <= 256) return 256; + return 512; +} + +} // namespace deep_gemm::utils diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh new file mode 100644 index 00000000..bf0e460c --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd.cuh @@ -0,0 +1,137 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M waves + constexpr auto kNumMWaves = BLOCK_M / STORE_BLOCK_M; + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]); + + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = base_m_idx + w * STORE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(base_n_idx + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = tmem_base_addr + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = smem_base_ptr + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared( + smem_ptr, + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]) + ); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_base_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +} + +} // namespace deep_gemm::epilogue diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh new file mode 100644 index 00000000..f3f5351e --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/sm100_store_cd_swap_ab.cuh @@ -0,0 +1,144 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm::epilogue { + +template +CUTLASS_DEVICE void +sm100_store_cd_swap_ab(const utils::PatternVisitor& smem_cd, uint32_t& tma_stage_idx, + const uint32_t& tmem_base_addr, + const uint32_t& base_m_idx, const uint32_t& base_n_idx, const uint32_t& batch_idx, + const uint32_t& effective_m, + const uint32_t& epilogue_warp_idx, const uint32_t& lane_idx, + const cutlass::arch::ClusterTransactionBarrier* tmem_empty_barrier, + const cute::TmaDescriptor& tensor_map_cd) { + // NOTES: The epilogue requires a full warpgroup to read all 128 TMEM rows, + // implying STORE_BLOCK_N must be 128. + DG_STATIC_ASSERT(STORE_BLOCK_N == 128, "STORE_BLOCK_N must be 128 to match TMEM rows"); + + // TMA checks + constexpr uint32_t STORE_BLOCK_N_ATOM = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumSwizzleAtomRows = 8; + DG_STATIC_ASSERT(kSwizzleCDMode == 128, "TMA D must be 128B swizzled"); + DG_STATIC_ASSERT(BLOCK_M % STORE_BLOCK_M == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(STORE_BLOCK_M % kNumSwizzleAtomRows == 0, "Invalid swizzling"); + DG_STATIC_ASSERT(STORE_BLOCK_N % STORE_BLOCK_N_ATOM == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Iterate over M blocks + const auto num_stores = effective_m / STORE_BLOCK_M; + for (uint32_t s = 0; s < num_stores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / kNumSwizzleAtomRows; ++ i) { + uint32_t tmem_addr = tmem_base_addr + + s * STORE_BLOCK_M + // Store stage offset + i * kNumSwizzleAtomRows; // In-block offset + uint32_t values[kNumSwizzleAtomRows]; + + // Warps cooperatively write an atomic block to shared memory + DG_STATIC_ASSERT(STORE_BLOCK_N_ATOM % 32 == 0, "Invalid block sizes"); + constexpr uint32_t kNumWarpsPerAtom = STORE_BLOCK_N_ATOM / 32; + uint32_t outer_atom_offset = (epilogue_warp_idx / kNumWarpsPerAtom) * STORE_BLOCK_M * kSwizzleCDMode; + uint32_t inner_atom_offset = i * kNumSwizzleAtomRows * kSwizzleCDMode; + auto smem_base_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + outer_atom_offset + inner_atom_offset; + + if constexpr (cute::is_same_v) { + // NOTES: Swizzling is not required in this case, but used here for consistency with other cases + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + uint32_t col = lane_idx / 4; + + #pragma unroll + for (uint32_t row = 0; row < kNumSwizzleAtomRows; ++ row) { + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes + + (lane_idx % 4) * sizeof(float); + ptx::st_shared(reinterpret_cast(smem_ptr), values[row]); + } + } else { + // Load from TMEM using `.16x256b` shape to satisfy STSM layout requirements + // Start from lane index 0 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + // Start from lane index 16 + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Destination shared memory address + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + auto smem_ptr = smem_base_ptr + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + + // Store matrix with transposition + ptx::SM90_U32x4_STSM_T::copy(math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (s == num_stores - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barrier->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / STORE_BLOCK_N_ATOM; ++ i) { + auto smem_ptr = smem_cd[tma_stage_idx] + i * STORE_BLOCK_M * STORE_BLOCK_N_ATOM; + uint32_t m_idx = base_m_idx + s * STORE_BLOCK_M; + uint32_t n_idx = epilogue_type_t::apply_index_n(base_n_idx + i * STORE_BLOCK_N_ATOM); + + // Issue 2D or 3D TMA store + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx, batch_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, n_idx, m_idx); + } + } + cute::tma_store_arrive(); + } + __syncwarp(); + } +} + +} // namespace deep_gemm::epilogue diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/transform.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/transform.cuh new file mode 100644 index 00000000..0266f4d4 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/epilogue/transform.cuh @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace deep_gemm::epilogue::transform { + +struct EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + CUTLASS_DEVICE static uint32_t apply_index_n(const uint32_t& n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 and + kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +} // namespace deep_gemm::epilogue::transform diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh new file mode 100644 index 00000000..a60e2de8 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -0,0 +1,437 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `umma_arrive()` overhead + constexpr bool kDoMergeStages = + kNumStages_ >= 8 and kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 8; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // MMA Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 16; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M; + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory out of bound for UMMA"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * UMMA_N; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Synchronize the cluster before 2-CTA TMEM allocation + kNumMulticast > 1 ? cute::cluster_sync() : void(); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive only at the leader CTA + full_barriers[i]->init(kNumMulticast); + // Arrive at all CTAs + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + if constexpr (kTensorCoreUtilControl < 100) + tensor_core_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = (stage_idx + 1) % kNumStages; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc() + : cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + ptx::tcgen05_after_thread_sync(); + + // UMMA and empty barrier arrival alias + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + + // Launch MMAs + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + if (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc); + } + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + // For utilization control + umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + __syncwarp(); + + // Wait for last UMMA to be done + tensor_core_full_barrier->wait(tensor_core_phase); + tensor_core_phase ^= 1; + + // Sleep for certain cycles + constexpr static uint64_t kNumUMMACycles = (2ull * UMMA_M * UMMA_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + ptx::tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx< + (not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } + } + } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh new file mode 100644 index 00000000..13bb0872 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -0,0 +1,271 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumThreads, 1) +sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); + DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Shared memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + + // Fill D/A/B + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); + }); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Block indices + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == 0) { + // TMA load warp + for (uint32_t s = 0; s < num_total_stages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + uint32_t m_idx = BLOCK_M * m_block_idx; + uint32_t n_idx = BLOCK_N * n_block_idx; + uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + uint32_t k_idx = sk_idx % SHAPE_K; + uint32_t s_idx = sk_idx / SHAPE_K; + + // Issue TMAs + if (cute::elect_one_sync()) { + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (cute::elect_one_sync()) + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } else if (warp_idx == 1) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + constexpr uint32_t UMMA_M = LAYOUT_AD_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Wait tensor memory empty barrier arrival + ptx::tcgen05_after_thread_sync(); + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + a_desc.lo = mma::sm100::advance_umma_desc_lo( + a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo( + b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + if (warp_idx == 2) + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + ptx::tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (warp_idx == 0 and cute::elect_one_sync()) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = m_block_idx * BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is doing TMA stores + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh new file mode 100644 index 00000000..b8a99fd0 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp4_mqa_logits.cuh @@ -0,0 +1,457 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, + const uint32_t logits_stride, + const uint32_t* cu_seq_len_k_start, + const uint32_t* cu_seq_len_k_end, + logits_dtype_t* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQ = math::constexpr_align(BLOCK_Q * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(BLOCK_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQ = BLOCK_Q * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(BLOCK_KV == kNumMathWarpGroups * UMMA_M and BLOCK_KV % kNumUTCCPAlignedElems == 0, "Invalid `BLOCK_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQ * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = BLOCK_Q * kNumHeads * kNumTmemStages; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQ / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + + // Allocate tensor memory + if (warp_idx == kSpecWarpStart + 2) + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + __syncthreads(); + + // Scheduler + const uint32_t num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + auto load_schedule = [&](const uint32_t& q_idx) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto row_idx = cute::min(q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cute::min(cu_seq_len_k_start[row_idx], seq_len_kv); + seq_k_end[i] = cute::min(cu_seq_len_k_end[row_idx], seq_len_kv); + start = cute::min(start, seq_k_start[i]); + end = cute::max(end, seq_k_end[i]); + } + // TMA alignment requirements for SF KV + start = start / 4 * 4; + return {start, math::ceil_div(end - start, BLOCK_KV)}; + }; + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + // Enumerate Q blocks + if (cute::elect_one_sync()) { + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_idx * BLOCK_Q); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_idx * BLOCK_Q); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQ * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx], 0, kv_start + kv_idx * BLOCK_KV); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx], + kv_start + kv_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQ / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize umma desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma( + a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this into `deep_gemm/ptx/tcgen05.cuh` + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline. Without this, UMMA can consume + // kNumQStages Q blocks before math warps release any, causing a + // circular dependency: UMMA waits full_q -> TMA_Q waits empty_q + // -> Math waits full_tmem -> UMMA (already moved on). + empty_q_barriers[q_stage_idx]->arrive(); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = threadIdx.x; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr uint32_t N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[BLOCK_Q][kNumHeads]; + + // Enumerate Q blocks + for (uint32_t q_idx = sm_idx; q_idx < num_q_blocks; q_idx += kNumSMs) { + // Load KV block ranges + CUTE_TIE_DECL(load_schedule(q_idx), kv_start, num_kv_blocks); + + // Wait TMA Q arrivals + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + // TODO: optimize bank conflicts + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + // Enumerate KV blocks + for (uint32_t kv_idx = 0; kv_idx < num_kv_blocks; ++ kv_idx) { + // Calculate KV offset in advance + auto kv_offset = kv_start + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + } + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + // TODO: optimize performance + const auto q_offset = (q_idx * BLOCK_Q + i) * static_cast(logits_stride); + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; + } else { + logits[q_offset + kv_offset] = result; + } + __syncwarp(); + } + } + + // Release last Q empty + empty_q_barriers[q_stage_idx]->arrive(); + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh new file mode 100644 index 00000000..d9add534 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp4_paged_mqa_logits.cuh @@ -0,0 +1,510 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp4_paged_mqa_logits(const uint32_t batch_size, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_sf_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_sf_q); + cute::prefetch_tma_descriptor(&tensor_map_weights); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_sf_kv); + } + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); + + // UMMA configs + static constexpr uint32_t kNumTmemStages = 3; + static constexpr uint32_t kNumUTCCPAlignedElems = 128; + static constexpr uint32_t UMMA_M = 128; + static constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; + static constexpr uint32_t UMMA_K = 64; + static constexpr uint32_t kNumSFQAtom = math::constexpr_align(kNextNAtom * kNumHeads, kNumUTCCPAlignedElems); + static constexpr uint32_t kNumSFKV = math::constexpr_align(SPLIT_KV, kNumUTCCPAlignedElems); + static constexpr uint32_t kRealNumSFQAtom = kNextNAtom * kNumHeads; + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == kNumMathWarpGroups * UMMA_M and SPLIT_KV % kNumUTCCPAlignedElems == 0, "Invalid `SPLIT_KV`"); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = 8 * (kHeadDim / 2); + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_Q_SIZE_PER_STAGE = kNumSFQAtom * sizeof(int); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * (kHeadDim / 2); + static constexpr uint32_t SMEM_SF_KV_SIZE_PER_STAGE = kNumSFKV * sizeof(int); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * i; + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i; + }); + const auto smem_sf_ptr = smem_buffer + (SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages); + auto smem_sf_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * i); + }); + auto smem_sf_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_sf_ptr + SMEM_SF_Q_SIZE_PER_STAGE * kNumQStages + SMEM_SF_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto tmem_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + i; }); + auto empty_tmem_barriers = utils::PatternVisitor([&](const uint32_t& i) { return tmem_barrier_ptr + kNumTmemStages + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(tmem_barrier_ptr + kNumTmemStages * 2); + + // Tensor memory configs + constexpr uint32_t kNumAccumTmemCols = kNextNAtom * kNumHeads * kNumTmemStages; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFQ = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFKV = kNumAccumTmemCols + kNumSFQAtom / 32; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 2) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumTmemStages; ++i) { + full_tmem_barriers[i]->init(1); + empty_tmem_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + using Scheduler = sched::PagedMQALogitsScheduler; + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Make Q, KV and TMEM pipeline + auto make_pipeline = [](const uint32_t& num_stages) { + // Return current stage and phase, and advance pipeline by steps + return [iter_idx = 0u, num_stages](const uint32_t& step = 1) mutable -> cute::tuple { + uint32_t current_idx = iter_idx; + iter_idx += step; + return {current_idx % num_stages, (current_idx / num_stages) & 1}; + }; + }; + auto advance_q_pipeline = make_pipeline(kNumQStages); + auto advance_kv_pipeline = make_pipeline(kNumKVStages); + auto advance_tmem_pipeline = make_pipeline(kNumTmemStages); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading Q + cutlass::arch::warpgroup_reg_dealloc(); + + if (cute::elect_one_sync()) { + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + // Initialize outside valid range to indicate no previous task + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, _, __; + while (scheduler.fetch_next_task(q_atom_idx, _, __)) { + // Issue TMA Q when (q_idx, atom_idx) changes + if (q_atom_idx != last_q_atom_idx) { + // Wait Q consumer release + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + const auto q_token_idx = Scheduler::atom_to_token_idx(q_atom_idx); + cute::SM90_TMA_LOAD_2D::copy(&tensor_map_q, reinterpret_cast(full_q_barriers[q_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_q[q_stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_sf_q, full_q_barriers[q_stage_idx], smem_sf_q[q_stage_idx], 0, q_token_idx); + tma::copy(&tensor_map_weights, full_q_barriers[q_stage_idx], smem_weights[q_stage_idx], 0, q_token_idx); + full_q_barriers[q_stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + kRealNumSFQAtom * sizeof(int) + SMEM_WEIGHT_SIZE_PER_STAGE); + } + last_q_atom_idx = q_atom_idx; + } + } + __syncwarp(); + } else if (warp_idx == kSpecWarpStart + 1) { + // TMA warp for loading KV cache + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + // Persistently schedule over blocks + uint32_t kv_block_idx_ptr = 32, kv_block_idx_storage; + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, num_kv; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, num_kv)) { + // Reset block table cache on kv restart + if (q_atom_idx != last_q_atom_idx) + kv_block_idx_ptr = 32; + last_q_atom_idx = q_atom_idx; + + // Coalesced load of block table + if (kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? block_table[block_table_offset + kv_idx + lane_idx] : 0; + } + __syncwarp(); + + // Broadcast KV block indices + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `SPLIT_KV`"); + + // Wait KV consumer release + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + + // Issue TMA KV + if (cute::elect_one_sync()) { + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(&tensor_map_kv, reinterpret_cast(full_kv_barriers[kv_stage_idx]), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim / 2) * i, + 0, 0, kv_block_idx[i]); + tma::copy(&tensor_map_sf_kv, full_kv_barriers[kv_stage_idx], + smem_sf_kv[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_SF_KV_SIZE_PER_STAGE); + } + } + } else if (warp_idx == kSpecWarpStart + 2) { + // UMMA warp + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + // Wait TMA Q arrivals + uint32_t q_stage_idx, q_phase; + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Transpose and copy SF Q + #pragma unroll + for (uint32_t i = 0; i < kNumSFQAtom / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_q[q_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + if (cute::elect_one_sync()) + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFQ + i * 4); + __syncwarp(); + } + } + last_q_atom_idx = q_atom_idx; + + // Wait TMA KV arrivals + CUTE_TIE_DECL(advance_kv_pipeline(), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Transpose + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + utccp_required_smem_warp_transpose(smem_ptr); + cutlass::arch::fence_view_async_shared(); + } + + // UMMA with SF + if (cute::elect_one_sync()) { + // Copy SF KV + #pragma unroll + for (uint32_t i = 0; i < kNumSFKV / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sf_kv[kv_stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute::SM100_UTCCP_4x32dp128bit_1cta::copy(sf_desc, kTmemStartColOfSFKV + i * 4); + } + + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + // Wait TMEM release + CUTE_TIE_DECL(advance_tmem_pipeline(), tmem_stage_idx, tmem_phase); + uint32_t tmem_addr = tmem_stage_idx * UMMA_N; + + empty_tmem_barriers[tmem_stage_idx]->wait(tmem_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA with SF + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto runtime_instr_desc = mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k * 2, k * 2); + // TODO: generalize UMMA desc + DG_STATIC_ASSERT(kHeadDim == 128, "Invalid head dim"); + auto a_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_kv[kv_stage_idx] + i * UMMA_M * (kHeadDim / 2) + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + auto b_desc = mma::sm100::make_smem_desc( + cute::UMMA::LayoutType::SWIZZLE_64B, + smem_q[q_stage_idx] + k * UMMA_K / 2, + 8 * (kHeadDim / 2), 0); + ptx::SM100_MMA_MXF4_SS::fma(a_desc, b_desc, tmem_addr, k, runtime_instr_desc, + kTmemStartColOfSFKV + i * 4, kTmemStartColOfSFQ); + } + // TODO: move this PTX into headers + asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];" + ::"r"(cute::cast_smem_ptr_to_uint(full_tmem_barriers[tmem_stage_idx]))); + } + } + cutlass::arch::umma_arrive(reinterpret_cast(empty_kv_barriers[kv_stage_idx])); + } + } else if (warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + + const auto math_warpgroup_idx = warpgroup_idx; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Math warpgroups process TMEM stages alternately + // Advance pipeline to align with the assigned stage + advance_tmem_pipeline(math_warpgroup_idx); + + // Local register buffers + float accum[kNumHeads]; + float weights[kNextNAtom][kNumHeads]; + + // Persistently schedule over blocks + uint32_t last_q_atom_idx = batch_size * kNumNextNAtoms; + uint32_t q_atom_idx, kv_idx, _; + bool is_paired_atom = false; + while (scheduler.fetch_next_task(q_atom_idx, kv_idx, _)) { + if (q_atom_idx != last_q_atom_idx) { + CUTE_TIE_DECL(advance_q_pipeline(), q_stage_idx, q_phase); + + // Release last Q empty + if (last_q_atom_idx != batch_size * kNumNextNAtoms) + empty_q_barriers[(q_stage_idx + kNumQStages - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrivals + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + float4 raw = ptx::ld_shared((float4*)(smem_weights[q_stage_idx] + i * kNumHeads + j)); + weights[i][j + 0] = raw.x; + weights[i][j + 1] = raw.y; + weights[i][j + 2] = raw.z; + weights[i][j + 3] = raw.w; + } + } + + // Check if this atom pairs two tokens from the same sequence + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(q_atom_idx, batch_size) == 2); + } + } + last_q_atom_idx = q_atom_idx; + + // Calculate KV offset in advance + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV + math_thread_idx; + + // Advance pipeline by `kNumMathWarpGroups` steps + // Wait UMMA arrival + CUTE_TIE_DECL(advance_tmem_pipeline(kNumMathWarpGroups), tmem_stage_idx, tmem_phase); + full_tmem_barriers[tmem_stage_idx]->wait(tmem_phase); + ptx::tcgen05_after_thread_sync(); + + // Reduce over the head dim and store + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + + // Only loop over valid iterations + #pragma unroll + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + uint32_t tmem_addr = tmem_stage_idx * UMMA_N + i * kNumHeads; + tmem_load(cute::Int{}, tmem_addr, accum); + tmem_load(cute::Int{}, tmem_addr + kNumHeads / 2, accum + kNumHeads / 2); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(sum.x + sum.y); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride)] = result; + __syncwarp(); + } + + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_tmem_barriers[tmem_stage_idx]->arrive(); + }; + + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); + } + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh new file mode 100644 index 00000000..0bc6a3fe --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_gemm_1d1d.cuh @@ -0,0 +1,514 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_fp4_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // MMA Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * kNumMulticast; + constexpr uint32_t UMMA_N = kSwapAB ? BLOCK_M : BLOCK_N; + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT((kSwapAB and BLOCK_N == LAYOUT_AD_M) or + (not kSwapAB and (BLOCK_M == 32 or BLOCK_M == 64 or BLOCK_M == LAYOUT_AD_M)), "Invalid block size"); + + // SF configs + constexpr uint32_t kNumUTCCPAlignedElems = 128; + constexpr uint32_t SF_BLOCK_M = math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = math::constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + DG_STATIC_ASSERT((kGemmType != GemmType::KGroupedContiguous) or kGranKA == kGranKB, "K-grouped SF requires kGranKA == kGranKB"); + + // Epilogue configs + // Always enable pipeline for better performance + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + // NOTES: To maximize epilogue threads utilization, process an entire BLOCK_N + // per store stage for swap-AB cases, and an entire BLOCK_M for non-swap cases + constexpr uint32_t STORE_BLOCK_M = kSwapAB ? 16 : cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwapAB ? BLOCK_N : kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = kSwapAB ? kNumEpilogueThreads: STORE_BLOCK_M; + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * STORE_BLOCK_N * sizeof(cd_dtype_t); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + // NOTES: Make sure we have enough shared memory for UMMA padding + constexpr uint32_t UMMA_A_SIZE_PER_STAGE = math::constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Synchronize the cluster before 2-CTA TMEM allocation + kNumMulticast > 1 ? cute::cluster_sync() : void(); + + // Utils + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranKA * 4); + const auto shape_sfb_k = math::ceil_div(shape_k, kGranKB * 4); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = reinterpret_cast(smem_b[kNumStages]); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_sfb[kNumStages]);; + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler( + shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Use dynamic load block M, when swap-AB is enabled + const auto load_block_m = kSwapAB ? scheduler.get_aligned_effective_m_in_block(m_block_idx) / kNumMulticast : LOAD_BLOCK_M; + + // For k-grouped layout, the number of block K is variable + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), sched::IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * load_block_m) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + uint32_t sfa_m_idx = m_block_idx * BLOCK_M; + uint32_t sfa_k_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::SF_K>( + shape_sfa_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad)); + tma::copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = scheduler.template get_global_idx( + shape_sfb_k, 1, math::ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx); + tma::copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + auto instr_desc = kSwapAB ? cute::UMMA::make_instr_desc_block_scaled() + : cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Dynamic update of UMMA N based on effective M, when swap-AB is enabled + if constexpr (kSwapAB) { + uint32_t umma_n = scheduler.get_aligned_effective_m_in_block(m_block_idx); + mma::sm100::update_instr_desc_with_umma_n(instr_desc, umma_n); + } + + // Launch MMAs + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + #pragma unroll 4 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // Do SF copy at certain stages + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + + // Issue UMMA + using mma_t = cute::conditional_t< + kNumMulticast == 1, ptx::SM100_MMA_MXF8F6F4_SS, ptx::SM100_MMA_MXF8F6F4_2x1SM_SS>; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto runtime_instr_desc = kSwapAB ? + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfb_id, sfa_id): + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + a_desc.lo = mma::sm100::advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + if constexpr (kSwapAB) { + mma_t::fma(b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } else { + mma_t::fma(a_desc, b_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFA, kTmemStartColOfSFB); + } + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ptx::ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + ptx::st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + ptx::tcgen05_after_thread_sync(); + + const auto tmem_base_addr = accum_stage_idx * UMMA_N; + const auto base_m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto base_n_idx = n_block_idx * BLOCK_N; + + if constexpr (kSwapAB) { + const auto effective_m = scheduler.get_aligned_effective_m_in_block(m_block_idx); + epilogue::sm100_store_cd_swap_ab< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + effective_m, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } else { + epilogue::sm100_store_cd< + BLOCK_M, BLOCK_N, STORE_BLOCK_M, STORE_BLOCK_N, + kSwizzleCDMode, kNumTMAStoreStages, kNumUMMAStoreThreads, + kGemmType, kWithAccumulation, + cd_dtype_t, epilogue_type_t> + (smem_cd, tma_stage_idx, tmem_base_addr, + base_m_idx, base_n_idx, scheduler.current_group_idx, + epilogue_warp_idx, lane_idx, + tmem_empty_barriers[accum_stage_idx], + tensor_map_cd); + } + } + } + + // TODO: Remove redundant synchronization + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Deallocate tensor memory + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh new file mode 100644 index 00000000..b2adc6c7 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_fp4_mega_moe.cuh @@ -0,0 +1,1380 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template < + uint32_t kNumMaxTokensPerRank, + uint32_t kHidden, uint32_t kIntermediateHidden, + uint32_t kNumExperts, uint32_t kNumTopk, + uint32_t kNumExpertsPerWave, + uint32_t BLOCK_M, uint32_t BLOCK_N, uint32_t BLOCK_K, + uint32_t STORE_BLOCK_M, + uint32_t SF_BLOCK_M, uint32_t SF_BLOCK_N, + uint32_t kNumMaxPoolTokens, + uint32_t kNumPaddedSFPoolTokens, + uint32_t kNumStages, + uint32_t kNumDispatchThreads, uint32_t kNumNonEpilogueThreads, + uint32_t kNumEpilogueThreads, + uint32_t kNumSMs, uint32_t kNumRanks, + float kActivationClamp, + bool kFastMath, + uint32_t L1_SHAPE_N = kIntermediateHidden * 2, + uint32_t L1_SHAPE_K = kHidden, + uint32_t L2_SHAPE_N = kHidden, + uint32_t L2_SHAPE_K = kIntermediateHidden, + uint32_t kNumDispatchWarps = kNumDispatchThreads / 32, + uint32_t kNumMMANonEpilogueWarps = kNumNonEpilogueThreads / 32, + uint32_t kNumEpilogueWarps = kNumEpilogueThreads / 32, + uint32_t kNumEpilogueWarpgroups = kNumEpilogueWarps / 4, + uint32_t kNumThreads = kNumDispatchThreads + kNumNonEpilogueThreads + kNumEpilogueThreads, + uint32_t kNumTokensPerWarp = 32 / kNumTopk, + uint32_t kNumExpertsPerRank = kNumExperts / kNumRanks +> +CUTLASS_GLOBAL __launch_bounds__(kNumThreads, 1) void +sm100_fp8_fp4_mega_moe_impl(void* y, + int* cumulative_local_expert_recv_stats, + const uint32_t num_tokens, + const __grid_constant__ layout::SymBuffer sym_buffer, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_weights_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l1_output, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_acts_sf, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights, + const __grid_constant__ cute::TmaDescriptor tensor_map_l2_weights_sf) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::TMEM::Allocator2Sm; + + // Template checks + DG_STATIC_ASSERT(kNumDispatchThreads % 128 == 0, "Invalid number of dispatch threads"); + DG_STATIC_ASSERT(kNumNonEpilogueThreads == 128, "Invalid number of MMA non-epilogue threads"); + DG_STATIC_ASSERT(kNumEpilogueThreads % 128 == 0, "Invalid number of MMA epilogue and combine threads"); + DG_STATIC_ASSERT(kNumExperts % kNumRanks == 0, "Invalid number of experts or ranks"); + + // Thread indices + const bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const uint32_t sm_idx = blockIdx.x; + const uint32_t thread_idx = threadIdx.x; + const uint32_t warp_idx = cutlass::canonical_warp_idx_sync(); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0) { + cute::prefetch_tma_descriptor(&tensor_map_l1_acts); + cute::prefetch_tma_descriptor(&tensor_map_l1_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights); + cute::prefetch_tma_descriptor(&tensor_map_l1_weights_sf); + cute::prefetch_tma_descriptor(&tensor_map_l1_output); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts); + cute::prefetch_tma_descriptor(&tensor_map_l2_acts_sf); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights); + cute::prefetch_tma_descriptor(&tensor_map_l2_weights_sf); + } + + // Workspaces + const auto workspace = layout::Workspace( + sym_buffer.get_base_ptr(), kNumRanks, kNumExperts, kNumMaxTokensPerRank, kNumTopk); + + // Token and buffer layouts + constexpr auto fp8_token_layout = layout::Data(kHidden); + constexpr auto bf16_token_layout = layout::Data(kHidden * sizeof(nv_bfloat16)); + constexpr auto fp8_intermediate_token_layout = layout::Data(kIntermediateHidden); + constexpr auto fp8_sf_layout = layout::Data(kHidden / 32); + constexpr auto fp8_intermediate_sf_layout = layout::Data(kIntermediateHidden / 32); + constexpr auto input_topk_idx_layout = layout::Data(kNumTopk * sizeof(int64_t), false); + constexpr auto input_topk_weights_layout = layout::Data(kNumTopk * sizeof(float), false); + constexpr auto l1_topk_weights_layout = layout::Data(sizeof(float), false); + + // Registered inputs + const auto input_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxTokensPerRank, + workspace.get_end_ptr()); + const auto input_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumMaxTokensPerRank, + input_token_buffer.get_end_ptr()); + const auto input_topk_idx_buffer = layout::Buffer( + input_topk_idx_layout, 1, kNumMaxTokensPerRank, + input_sf_buffer.get_end_ptr()); + const auto input_topk_weights_buffer = layout::Buffer( + input_topk_weights_layout, 1, kNumMaxTokensPerRank, + input_topk_idx_buffer.get_end_ptr()); + + // SF and its buffer configs + constexpr uint32_t kGranK = 32; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(SF_BLOCK_M == math::constexpr_align(BLOCK_M, kNumUTCCPAlignedElems), "Invalid SF_BLOCK_M"); + DG_STATIC_ASSERT(SF_BLOCK_N == BLOCK_N, "No padding is needed for SFB"); + + // UTCCP 4x32 transpose index mapping within each 128-element group + const auto transform_sf_token_idx = [](const uint32_t& token_idx_in_expert) { + const uint32_t idx = token_idx_in_expert % BLOCK_M; + return token_idx_in_expert / BLOCK_M * SF_BLOCK_M + + (idx & ~127u) + (idx & 31u) * 4 + ((idx >> 5) & 3u); + }; + + // L1 inputs + const auto l1_token_buffer = layout::Buffer( + fp8_token_layout, 1, kNumMaxPoolTokens, + input_topk_weights_buffer.get_end_ptr()); + const auto l1_sf_buffer = layout::Buffer( + fp8_sf_layout, 1, kNumPaddedSFPoolTokens, + l1_token_buffer.get_end_ptr()); + const auto l1_topk_weights_buffer = layout::Buffer( + l1_topk_weights_layout, 1, kNumMaxPoolTokens, + l1_sf_buffer.get_end_ptr()); + + // L2 inputs + const auto l2_token_buffer = layout::Buffer( + fp8_intermediate_token_layout, 1, kNumMaxPoolTokens, + l1_topk_weights_buffer.get_end_ptr() + ); + const auto l2_sf_buffer = layout::Buffer( + fp8_intermediate_sf_layout, 1, kNumPaddedSFPoolTokens, + l2_token_buffer.get_end_ptr() + ); + + // Combine inputs + const auto combine_token_buffer = layout::Buffer( + bf16_token_layout, kNumTopk, kNumMaxTokensPerRank, + l2_sf_buffer.get_end_ptr() + ); + + // Data types + // NOTES: activations are FP8 (e4m3), weights are FP4 (e2m1) + using a_dtype_t = cutlass::float_e4m3_t; + using b_dtype_t = cutlass::detail::float_e2m1_unpacksmem_t; + + // MMA configs + // NOTES: always swap A/B, 2-CTA MMA, and matrices are K-major + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t UMMA_M = LAYOUT_AD_M * 2; + constexpr uint32_t UMMA_N = BLOCK_M; // Swap AB + constexpr uint32_t UMMA_K = 32; + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / 2; // Multicast on A + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N; + DG_STATIC_ASSERT(BLOCK_M % 16 == 0, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_N == LAYOUT_AD_M, "Invalid block N"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + + // Swizzle configs + constexpr uint32_t kSwizzleAMode = BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t kSwizzleBMode = BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t kSwizzleCDMode = 128; + DG_STATIC_ASSERT(BLOCK_N % kSwizzleCDMode == 0, "Invalid block N"); + + // Epilogue configs + constexpr uint32_t kNumEpilogueStages = 2; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Shared memory + constexpr uint32_t kSharedMemoryAlignment = 1024; + extern __shared__ __align__(kSharedMemoryAlignment) uint8_t smem_buffer[]; + + // Shared memory sizes + // NOTES: FP8 CD output for L1 (2 TMA stages, BLOCK_N/2 post-SwiGLU), BF16 output for L2 (no TMA, a single stage) + constexpr uint32_t L1_OUT_BLOCK_N = BLOCK_N / 2; + constexpr uint32_t SMEM_EXPERT_COUNT_SIZE = + math::constexpr_align(kNumExperts * sizeof(uint32_t), kSharedMemoryAlignment); + constexpr uint32_t SMEM_SEND_BUFFER_SIZE = + math::constexpr_align(fp8_token_layout.get_num_bytes() * kNumDispatchWarps, kSharedMemoryAlignment); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + constexpr uint32_t SMEM_CD_L1_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * L1_OUT_BLOCK_N * sizeof(cutlass::float_e4m3_t) * kNumTMAStoreStages; + constexpr uint32_t SMEM_CD_L2_SIZE = + kNumEpilogueWarpgroups * STORE_BLOCK_M * BLOCK_N * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_L1_SIZE > SMEM_CD_L2_SIZE ? SMEM_CD_L1_SIZE : SMEM_CD_L2_SIZE; + constexpr uint32_t SMEM_CD_L1_SIZE_PER_STAGE = SMEM_CD_L1_SIZE / kNumTMAStoreStages; + constexpr uint32_t SMEM_BEFORE_BARRIER_SIZE = + SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + DG_STATIC_ASSERT(SMEM_CD_SIZE % kSharedMemoryAlignment == 0 and + SMEM_A_SIZE_PER_STAGE % kSharedMemoryAlignment == 0 and + SMEM_B_SIZE_PER_STAGE % kSharedMemoryAlignment == 0, + "Shared memory of CD/A/B must be aligned to 1024 bytes"); + + // Tensor memory size + constexpr uint32_t kNumAccumTmemCols = UMMA_N * kNumEpilogueStages; + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Assign shared memory for dispatch warps + const auto smem_expert_count = reinterpret_cast(smem_buffer); + const auto smem_send_buffers = layout::Buffer( + fp8_token_layout, kNumDispatchWarps, 1, + math::advance_ptr(smem_buffer, SMEM_EXPERT_COUNT_SIZE)); + + // GEMM shared memory: C/D, A, B + // NOTES: GEMM shared memory starts after the dispatch region, aligned to 1024 bytes + auto smem_gemm_base = math::advance_ptr( + smem_buffer, SMEM_EXPERT_COUNT_SIZE + SMEM_SEND_BUFFER_SIZE + ); + + // D/A/B shared memory + auto smem_cd = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, i * SMEM_CD_L1_SIZE_PER_STAGE); + }); + auto smem_cd_l2 = smem_cd[0]; + auto smem_a = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([=](const uint32_t& i) { + return math::advance_ptr(smem_gemm_base, SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SF shared memory: SFA and SFB per pipeline stage + auto sf_start_ptr = math::advance_ptr(smem_gemm_base, + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto smem_sfa = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = utils::PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Epilogue amax reduction shared memory + auto smem_amax_reduction = reinterpret_cast(smem_sfb[kNumStages]); + + // Barriers and tensor memory pointer + auto barrier_start_ptr = reinterpret_cast(smem_amax_reduction + STORE_BLOCK_M * kNumEpilogueWarps / 2); + auto dispatch_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages + i); }); + auto tmem_full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + i); }); + auto tmem_empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages + i); }); + auto combine_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + i); }); + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumDispatchWarps + kNumStages * 2 + kNumEpilogueStages * 2 + kNumEpilogueWarps * 2); + + // A cluster sync is essential for 2CTA tensor memory allocation + comm::cluster_sync_with_relaxed_arrive(); + + // Initialization + if (warp_idx == 0) { + // Clean shared memory + if (cute::elect_one_sync()) + ptx::st_shared_bulk(smem_expert_count, kNumExperts * sizeof(uint32_t)); + } else if (warp_idx == 1) { + // Init m-barriers for dispatch + #pragma unroll + for (uint32_t i = lane_idx; i < kNumDispatchWarps; i += 32) + dispatch_barriers[i]->init(1); + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Init GEMM barriers + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(2 * 2); + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(2 * kNumEpilogueThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueWarps * 2; ++ i) + combine_barriers[i]->init(1); + } + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 3) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + // NOTES: Using `.relaxed` is allowed here since `fence_barrier_init` is `.release.cluster`, + // and `barrier.cluster.wait.aligned` is by default `.acquire` + comm::cluster_sync_with_relaxed_arrive(); + + // Task scheduler + auto scheduler = sched::MegaMoEScheduler< + BLOCK_M, BLOCK_N, BLOCK_K, + L1_SHAPE_N, L1_SHAPE_K, + L2_SHAPE_N, L2_SHAPE_K, + kNumExpertsPerRank, + kNumExpertsPerWave, + kNumSMs, kNumRanks>(workspace); + + // MMA pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Intra-SM Barrier indices + constexpr uint32_t kDispatchBarrierIdx = 0; + constexpr uint32_t kDispatchWithEpilogueBarrierIdx = 1; + constexpr uint32_t kEpilogueFullBarrierIdx = 2; + constexpr uint32_t kEpilogueWGBarrierStartIdx = 3; + + // NVLink barrier tags + constexpr uint32_t kBeforeDispatchPullBarrierTag = 1; + constexpr uint32_t kBeforeCombineReduceBarrierTag = 2; + constexpr uint32_t kAfterWorkspaceCleanBarrierTag = 3; + + // Adjust registers + constexpr uint32_t kNumDispatchRegisters = 48; + constexpr uint32_t kNumNonEpilogueRegisters = 40; + constexpr uint32_t kNumEpilogueRegisters = 208; + DG_STATIC_ASSERT(kNumDispatchRegisters * kNumDispatchThreads + + kNumNonEpilogueRegisters * kNumNonEpilogueThreads + + kNumEpilogueRegisters * kNumEpilogueThreads <= 64512, + "Too many registers"); + + // Grid sync index assignments (dispatch and epilogue use separate counters to avoid conflicts) + constexpr uint32_t kDispatchGridSyncIndex = 0; + constexpr uint32_t kEpilogueGridSyncIndex = 1; + + // Different warp roles + if (warp_idx < kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // Dispatch warps + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + constexpr uint32_t kNumActivateLanes = kNumTokensPerWarp * kNumTopk; + const auto read_topk_idx = [&](const auto& process) { + // TODO: figure out better unrolling + // Now, `unroll` is better than `unroll 8` + #pragma unroll + for (uint32_t i = (sm_idx * kNumDispatchWarps + warp_idx) * kNumTokensPerWarp; + i < num_tokens; + i += kNumSMs * kNumDispatchWarps * kNumTokensPerWarp) { + // Allocate slots for each token-topk + int expert_idx = -1; + if (i + (lane_idx / kNumTopk) < num_tokens and lane_idx < kNumActivateLanes) { + expert_idx = static_cast( + __ldg(input_topk_idx_buffer.get_base_ptr() + i * kNumTopk + lane_idx)); + if (expert_idx >= 0) + process(i * kNumTopk + lane_idx, expert_idx); + } + __syncwarp(); + } + }; + + // Count experts' tokens + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + atomicAdd_block(smem_expert_count + expert_idx, 1); + }); + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Get SM offset (~6.5 us) + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const uint64_t send_value = (1ull << 32) | static_cast(smem_expert_count[i]); + smem_expert_count[i] = static_cast( + ptx::atomic_add(workspace.get_expert_send_count_ptr(i), send_value)); + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Write source indices (~2 us with 512 tokens) + read_topk_idx([&](const uint32_t& token_topk_idx, const int& expert_idx) { + const auto dst_rank_idx = expert_idx / kNumExpertsPerRank; + const auto dst_slot_idx = atomicAdd_block(smem_expert_count + expert_idx, 1); + const auto dst_ptr = workspace.get_src_token_topk_idx_ptr( + expert_idx % kNumExpertsPerRank, sym_buffer.rank_idx, dst_slot_idx); + *sym_buffer.map(dst_ptr, dst_rank_idx) = token_topk_idx; + }); + + // Grid sync + comm::grid_sync( + workspace, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); } + ); + + // Write expert count + if (sm_idx == 0) { + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) { + const auto dst_rank_idx = i / kNumExpertsPerRank; + const auto dst_local_expert_idx = i % kNumExpertsPerRank; + const auto expert_status = *workspace.get_expert_send_count_ptr(i); + *sym_buffer.map( + workspace.get_expert_recv_count_ptr(sym_buffer.rank_idx, dst_local_expert_idx), + dst_rank_idx) = expert_status & 0xffffffff; + ptx::atomic_add_sys( + sym_buffer.map(workspace.get_expert_recv_count_sum_ptr(dst_local_expert_idx), dst_rank_idx), + expert_status); + } + } + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Barrier before pulling + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* After the grid sync above, there is no more writes by other SMs (except 0) */ false, + /* After the NVLink barrier, there is a grid sync */ true + ); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Pull token data and SF from remote ranks into local L1 buffer + uint32_t pull_mbarrier_phase = 0; + const auto pull_buffer = smem_send_buffers.get_rank_buffer(warp_idx).get_data_buffer(0); + const auto pull_mbarrier = dispatch_barriers[warp_idx]; + + // Cache expert token counts in registers (same pattern as scheduler) + scheduler.fetch_expert_recv_count(); + + // Per-rank counts for current expert (re-loaded when expert changes) + constexpr uint32_t kNumRanksPerLane = math::constexpr_ceil_div(kNumRanks, 32u); + int current_expert_idx = -1; + uint32_t stored_rank_count[kNumRanksPerLane] = {}; + uint32_t expert_start_idx = 0, expert_end_idx = 0; + uint32_t expert_pool_block_offset = 0; + + constexpr uint32_t kNumGlobalWarps = kNumSMs * kNumDispatchWarps; + for (uint32_t token_idx = sm_idx * kNumDispatchWarps + warp_idx; ; token_idx += kNumGlobalWarps) { + // Advance expert until within the range + int old_expert_idx = current_expert_idx; + while (token_idx >= expert_end_idx) { + if (++ current_expert_idx >= kNumExpertsPerRank) + break; + + // Update pool block offset for the new expert + expert_pool_block_offset += math::ceil_div(expert_end_idx - expert_start_idx, BLOCK_M); + + // Move start and end to the next expert + expert_start_idx = expert_end_idx; + expert_end_idx += scheduler.get_num_tokens(current_expert_idx); + } + + // Finish all tokens + if (current_expert_idx >= kNumExpertsPerRank) + break; + + // Load per-rank counts when expert changes + if (old_expert_idx != current_expert_idx) { + old_expert_idx = current_expert_idx; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t j = i * 32 + lane_idx; + // TODO: this is not coalesced + stored_rank_count[i] = j < kNumRanks ? + static_cast(*workspace.get_expert_recv_count_ptr(j, current_expert_idx)) : 0; + } + } + + // Round-robin rank selection via iterative min-peeling + uint32_t current_rank_in_expert_idx; + uint32_t remaining[kNumRanksPerLane]; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] = stored_rank_count[i]; + uint32_t offset = 0; + uint32_t token_idx_in_expert = token_idx - expert_start_idx; + uint32_t slot_idx = token_idx_in_expert; + uint32_t token_idx_in_rank; + while (true) { + // Compute active count and min across all ranks + // NOTES: reduce within each lane first, then warp-reduce once + uint32_t num_actives_in_lane = 0; + uint32_t min_in_lane = 0xffffffff; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + num_actives_in_lane += remaining[i] > 0; + if (remaining[i] > 0) + min_in_lane = cute::min(min_in_lane, remaining[i]); + } + const uint32_t num_active_ranks = __reduce_add_sync(0xffffffff, num_actives_in_lane); + const uint32_t length = __reduce_min_sync(0xffffffff, min_in_lane); + + // Hit in the current round + const uint32_t num_round_tokens = length * num_active_ranks; + if (slot_idx < num_round_tokens) { + const uint32_t slot_idx_in_round = slot_idx % num_active_ranks; + uint32_t num_seen_ranks = 0; + current_rank_in_expert_idx = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) { + const uint32_t mask = __ballot_sync(0xffffffff, remaining[i] > 0); + const uint32_t num_active_lanes = __popc(mask); + if (slot_idx_in_round >= num_seen_ranks and slot_idx_in_round < num_seen_ranks + num_active_lanes) + current_rank_in_expert_idx = i * 32 + __fns(mask, 0, slot_idx_in_round - num_seen_ranks + 1); + num_seen_ranks += num_active_lanes; + } + token_idx_in_rank = offset + (slot_idx / num_active_ranks); + break; + } + + // Move into the next round + slot_idx -= num_round_tokens; + offset += length; + #pragma unroll + for (uint32_t i = 0; i < kNumRanksPerLane; ++ i) + remaining[i] -= cute::min(remaining[i], length); + } + + // Read source token-topk index (written by remote dispatch via NVLink) + const uint32_t src_token_topk_idx = *workspace.get_src_token_topk_idx_ptr( + current_expert_idx, current_rank_in_expert_idx, token_idx_in_rank); + const uint32_t src_token_idx = src_token_topk_idx / kNumTopk; + const uint32_t src_topk_idx = src_token_topk_idx % kNumTopk; + + // TMA load token from remote rank into shared memory + if (cute::elect_one_sync()) { + ptx::tma_load_1d( + pull_buffer.get_base_ptr(), + sym_buffer.map(input_token_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx), + pull_mbarrier, kHidden); + } + __syncwarp(); + + // Load and store SF (overlaps with TMA token load) + constexpr uint32_t kNumSFUint32 = kHidden / 128; + DG_STATIC_ASSERT(kNumSFUint32 > 0 and kHidden % 128 == 0, "Invalid SF"); + const auto remote_sf_ptr = sym_buffer.map( + input_sf_buffer.get_data_buffer(src_token_idx).get_base_ptr(), + current_rank_in_expert_idx); + const auto local_sf_ptr = l1_sf_buffer.get_base_ptr(); + const auto sf_pool_token_idx = expert_pool_block_offset * SF_BLOCK_M + + transform_sf_token_idx(token_idx_in_expert); + #pragma unroll + for (uint32_t i = 0; i < math::constexpr_ceil_div(kNumSFUint32, 32u); ++ i) { + const uint32_t j = i * 32 + lane_idx; + if (j < kNumSFUint32) + local_sf_ptr[j * kNumPaddedSFPoolTokens + sf_pool_token_idx] = remote_sf_ptr[j]; + } + __syncwarp(); + + // Store weights and token data + const uint32_t pool_token_idx = expert_pool_block_offset * BLOCK_M + token_idx_in_expert; + if (cute::elect_one_sync()) { + // Load weights + const auto weight = *sym_buffer.map( + input_topk_weights_buffer.get_base_ptr() + src_token_topk_idx, + current_rank_in_expert_idx); + *l1_topk_weights_buffer.get_data_buffer(pool_token_idx).get_base_ptr() = weight; + + // Wait for TMA token load to complete + ptx::mbarrier_arrive_and_set_tx(pull_mbarrier, kHidden); + ptx::mbarrier_wait_and_flip_phase(pull_mbarrier, pull_mbarrier_phase); + + // Store token to local L1 buffer via TMA + ptx::tma_store_1d( + l1_token_buffer.get_data_buffer(pool_token_idx).get_base_ptr(), + pull_buffer.get_base_ptr(), pull_buffer.get_num_bytes()); + + // Write source metadata for combine write-back + *workspace.get_token_src_metadata_ptr(pool_token_idx) = + {current_rank_in_expert_idx, src_token_idx, src_topk_idx}; + + // Wait for token TMA store to complete + cute::tma_store_arrive(); + ptx::tma_store_wait<0>(); + ptx::red_add_rel( + workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + token_idx_in_expert / BLOCK_M), 1); + } + __syncwarp(); + } + + // Clean workspace for the next usage, and also do cumulative stats + // NOTES: it is overlapped with combine reduction epilogue + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + DG_STATIC_ASSERT(kNumSMs > 1, "Invalid SM count"); + if (sm_idx == 0) { + // SM 0: clear expert send count + #pragma unroll + for (uint32_t i = thread_idx; i < kNumExperts; i += kNumDispatchThreads) + *workspace.get_expert_send_count_ptr(i) = 0; + } else { + // Other SMs: clean blocks + for (uint32_t i = sm_idx - 1; i < kNumExpertsPerRank; i += kNumSMs - 1) { + // Read expert token count before clearing + const auto num_recv_tokens = static_cast( + *workspace.get_expert_recv_count_sum_ptr(i)); + const auto num_recv_m_blocks = math::ceil_div(num_recv_tokens, BLOCK_M); + + // Compute expert pool block offset + expert_pool_block_offset = scheduler.get_pool_block_offset(i); + + // Wait read count ready + ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); + + // Clean expert token count, and add cumulative results + DG_STATIC_ASSERT(kNumDispatchWarps >= 2, "Not enough dispatch warps"); + if (warp_idx == 0) { + *workspace.get_expert_recv_count_sum_ptr(i) = 0; + } else if (warp_idx == 1) { + if (cute::elect_one_sync() and cumulative_local_expert_recv_stats != nullptr) + ptx::red_add(cumulative_local_expert_recv_stats + i, static_cast(num_recv_tokens)); + __syncwarp(); + } + + // Clean per-rank token count + for (uint32_t j = thread_idx; j < kNumRanks; j += kNumDispatchThreads) + *workspace.get_expert_recv_count_ptr(j, i) = 0; + __syncwarp(); + + // Clean L1 and L2 arrival stuffs + for (uint32_t j = thread_idx; j < num_recv_m_blocks; j += kNumDispatchThreads) { + *workspace.get_l1_arrival_count_ptr(expert_pool_block_offset + j) = 0; + *workspace.get_l2_arrival_mask_ptr(expert_pool_block_offset + j) = 0; + } + __syncwarp(); + } + } + + // Wait for all ranks to finish cleaning + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, thread_idx, + [=]() { ptx::sync_aligned(kNumDispatchThreads, kDispatchBarrierIdx); }, + /* Before the NVLink barrier, there is a grid sync */ true, + /* At the end of kernel does not need to sync */ false + ); + } else if (warp_idx == kNumDispatchWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for tokens with SFA + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_a_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts : &tensor_map_l1_acts; + const auto tensor_map_sfa_ptr = block_phase == sched::BlockPhase::Linear2 + ? &tensor_map_l2_acts_sf : &tensor_map_l1_acts_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_sfa_k = math::ceil_div(shape_k, kGranK * 4u); + + // Compute pool block offset for this expert + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + + // Wait the entire token arrival for linear 1 + if (block_phase == sched::BlockPhase::Linear1) { + const auto ptr = workspace.get_l1_arrival_count_ptr(pool_block_idx); + const auto expected = scheduler.template get_valid_m(); + while (ptx::ld_acq(ptr) != expected); + } else { + // The L1 output's block N is halved into `BLOCK_K / 2`, so we have to wait 2x L1 blocks' arrival + // NOTES: Originally we wait blocks on-demand to overlap L1 calculation + // with L2, but this optimization is negative when `num_experts_per_wave` + // guarantees L1's completion when L2 starts. So we remove it. + // In the future, if `num_experts_per_wave` is not large enough + // due to small `num_experts_per_rank`, we may need to add it back or add a switch + DG_STATIC_ASSERT(BLOCK_K == BLOCK_N, "Invalid block sizes"); + const auto ptr = workspace.get_l2_arrival_mask_ptr(pool_block_idx); + // NOTES: Equivalent to `(1ull << (2 * num_k_blocks)) - 1`, but split into two shifts + // to avoid undefined behavior when `num_k_blocks == 32` + const uint64_t expected = ((1ull << num_k_blocks) << num_k_blocks) - 1; + while (ptx::ld_acq_gpu(ptr) != expected); + } + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute token offset from pool block index + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfa_m_idx = pool_block_idx * SF_BLOCK_M; + uint32_t sfa_k_idx = k_block_idx; + + // Add 2 CTA offsets for non-leader CTA + if (not is_leader_cta) + m_idx += scheduler.template get_valid_m() / 2; + + // TMA copy tokens and SFA, then arrive at full barrier + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_a_ptr, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx, 2); + tma::copy( + tensor_map_sfa_ptr, full_barriers[stage_idx], smem_sfa[stage_idx], sfa_m_idx, sfa_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE * 2 + SF_BLOCK_M * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 1) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM TMA load warp for weights with SF + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + const auto tensor_map_b_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights : &tensor_map_l1_weights; + const auto tensor_map_sfb_ptr = + block_phase == sched::BlockPhase::Linear2 ? &tensor_map_l2_weights_sf : &tensor_map_l1_weights_sf; + + const auto shape_k = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_K : L1_SHAPE_K; + const auto shape_n = block_phase == sched::BlockPhase::Linear2 ? L2_SHAPE_N : L1_SHAPE_N; + const auto shape_sfb_k = math::ceil_div(shape_k, kGranK * 4u); + + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute weight offset + uint32_t n_idx = local_expert_idx * shape_n + n_block_idx * BLOCK_N; + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t sfb_n_idx = n_block_idx * BLOCK_N; + uint32_t sfb_k_idx = local_expert_idx * shape_sfb_k + k_block_idx; + + // TMA copy weights with SF + if (cute::elect_one_sync()) { + tma::copy( + tensor_map_b_ptr, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx, 2); + tma::copy( + tensor_map_sfb_ptr, full_barriers[stage_idx], smem_sfb[stage_idx], sfb_n_idx, sfb_k_idx, 2); + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(SMEM_B_SIZE_PER_STAGE + BLOCK_N * sizeof(uint32_t) * 2); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + __syncwarp(); + } + }); + } else if (warp_idx == kNumDispatchWarps + 2) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + // GEMM MMA issue warp (only the leader CTA will run) + if (is_leader_cta) { + // Make instruction descriptor with block scaling + // NOTES: always swap A/B + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled< + b_dtype_t, a_dtype_t, float, cutlass::float_ue8m0_t, + UMMA_M, UMMA_N, + cute::UMMA::Major::K, cute::UMMA::Major::K + >(); + auto sf_desc = mma::sm100::make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = mma::sm100::make_umma_desc(smem_a[0], 0, 0); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Dynamic update of UMMA N based on effective M + mma::sm100::update_instr_desc_with_umma_n(instr_desc, scheduler.template get_valid_m()); + + // Wait tensor memory empty barrier arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase ^ 1); + ptx::tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + constexpr uint16_t kCTAMask = (1 << 2) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + __syncwarp(); + }; + + // Launch MMAs + #pragma unroll 2 + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA load completion + full_barriers[stage_idx]->wait(phase); + ptx::tcgen05_after_thread_sync(); + + const auto a_desc_base_lo = ptx::exchange(a_desc_lo, stage_idx); + const auto b_desc_base_lo = ptx::exchange(b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + // UTCCP copy SFA and SFB to TMEM + using cute_utccp_t = cute::SM100_UTCCP_4x32dp128bit_2cta; + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + mma::sm100::replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + + // Issue UMMA + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const auto runtime_instr_desc = + mma::sm100::make_runtime_instr_desc_with_sf_id(instr_desc, k, k); + a_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_M, kSwizzleAMode, a_dtype_t>(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = mma::sm100::advance_umma_desc_lo< + cute::UMMA::Major::K, LOAD_BLOCK_N, kSwizzleBMode, b_dtype_t>(b_desc_base_lo, 0, k * UMMA_K); + ptx::SM100_MMA_MXF8F6F4_2x1SM_SS::fma( + b_desc, a_desc, accum_stage_idx * UMMA_N, + k_block_idx > 0 or k > 0, runtime_instr_desc, + kTmemStartColOfSFB, kTmemStartColOfSFA); + } + } + __syncwarp(); + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_k_blocks - 1); + } + }); + + // To safely deconstruct barriers, we need another round of waits + if (current_iter_idx > 0) { + const auto accum_phase_idx = ((current_iter_idx - 1) / kNumEpilogueStages) & 1; + tmem_empty_barriers[(current_iter_idx - 1) % kNumEpilogueStages]->wait(accum_phase_idx); + } + } + } else if (warp_idx == kNumDispatchWarps + 3) { + // Adjust registers + cutlass::arch::warpgroup_reg_dealloc(); + + } else if (warp_idx >= kNumDispatchWarps + kNumMMANonEpilogueWarps) { + // Adjust registers + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // GEMM epilogue warps + const auto epilogue_warp_idx = warp_idx - (kNumDispatchWarps + kNumMMANonEpilogueWarps); + const auto epilogue_wg_idx = epilogue_warp_idx / 4; + const auto epilogue_thread_idx = epilogue_warp_idx * 32 + lane_idx; + const auto warp_idx_in_wg = epilogue_warp_idx % 4; + DG_STATIC_ASSERT((kNumDispatchWarps + kNumMMANonEpilogueWarps) % 4 == 0 and + kNumEpilogueWarps % 4 == 0, "Invalid epilogue warps"); + + // TODO: support effective block M + // NOTES: + // - 2 warpgroups divide the whole BM into BM / 2 + // - 4 warps divide the whole BN into BN / 4 + // - BM / 2 is further divided into stored blocks, i.e. with `STORE_BLOCK_M` size + // - `STORE_BLOCK_M` in further divided into `ATOM_M` + constexpr uint32_t WG_BLOCK_M = BLOCK_M / kNumEpilogueWarpgroups; + constexpr uint32_t ATOM_M = 8; + constexpr uint32_t kNumBankGroupBytes = 16u; + constexpr uint32_t kNumAtomsPerStore = STORE_BLOCK_M / ATOM_M; + DG_STATIC_ASSERT(BLOCK_M % kNumEpilogueWarpgroups == 0, "Invalid block M"); + DG_STATIC_ASSERT(WG_BLOCK_M % STORE_BLOCK_M == 0, "Invalid warpgroup block M"); + DG_STATIC_ASSERT(STORE_BLOCK_M % ATOM_M == 0, "Invalid store block M"); + DG_STATIC_ASSERT(BLOCK_N == 128, "Invalid block N"); + + // Ensure the epilogue barrier cannot run with the pull barrier + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Persistently schedule over blocks + uint32_t current_iter_idx = 0; + scheduler.for_each_block([&](const sched::BlockPhase& block_phase, + const uint32_t& local_expert_idx, + const uint32_t& num_k_blocks, + const uint32_t& m_block_idx, const uint32_t& n_block_idx) { + // Wait UMMA arrival + const auto accum_stage_idx = current_iter_idx % kNumEpilogueStages; + const auto accum_phase = (current_iter_idx ++ / kNumEpilogueStages) & 1; + tmem_full_barriers[accum_stage_idx]->wait(accum_phase); + ptx::tcgen05_after_thread_sync(); + + // Compute offsets + // NOTES: use shuffle here to let NVCC know warp divergence won't happen + const uint32_t valid_m = ptx::exchange(scheduler.template get_valid_m(), 0); + const uint32_t pool_block_idx = scheduler.get_current_pool_block_offset() + m_block_idx; + uint32_t m_idx = pool_block_idx * BLOCK_M; + uint32_t n_idx = n_block_idx * BLOCK_N; + + if (block_phase == sched::BlockPhase::Linear1) { + // Unified L1 epilogue: SwiGLU in-place using granularity 8 interleaved weights + // With `SM100_TMEM_LOAD_16dp256b1x`, gate/up pairs are: + // (values[0], values[2]), (values[1], values[3]), + // (values[4], values[6]), (values[5], values[7]) + float stored_cached_weight = 0; + + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + // Iterate all atoms in the store block + float2 swiglu_values[kNumAtomsPerStore * 2]; + float2 amax_values[kNumAtomsPerStore]; + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + const uint32_t j = s * kNumAtomsPerStore + i; + + // Load weights from global into register cache per 32 tokens + DG_STATIC_ASSERT(32 % ATOM_M == 0, "Invalid block size"); + if ((j * ATOM_M) % 32 == 0 and (WG_BLOCK_M % 32 == 0 or j * ATOM_M + lane_idx < WG_BLOCK_M)) { + stored_cached_weight = *l1_topk_weights_buffer + .get_data_buffer(m_idx + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M + lane_idx) + .get_base_ptr(); + } + + // Load weights from register cache + const float2 weights = { + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 0), + ptx::exchange(stored_cached_weight, (j * ATOM_M) % 32 + (lane_idx % 4) * 2 + 1) + }; + + // Load from TMEM + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + j * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Signal tensor memory consumed on the last atom + if (j == WG_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Apply SwiGLU: silu(gate) * up + // Gate/up pairs: (0, 2), (1, 3), (4, 6), (5, 7) + auto fp32_values = reinterpret_cast(values); + #pragma unroll + for (uint32_t k = 0; k < 2; ++ k) { + auto bf16_gate = __float22bfloat162_rn(make_float2(fp32_values[k * 4], fp32_values[k * 4 + 1])); + auto bf16_up = __float22bfloat162_rn(make_float2(fp32_values[k * 4 + 2], fp32_values[k * 4 + 3])); + + // Clamp + if constexpr (kActivationClamp != cute::numeric_limits::infinity()) { + bf16_gate = __hmin2(bf16_gate, {kActivationClamp, kActivationClamp}); + bf16_up = __hmax2(bf16_up, {-kActivationClamp, -kActivationClamp}); + bf16_up = __hmin2(bf16_up, {kActivationClamp, kActivationClamp}); + } + + // SwiGLU + auto gate = __bfloat1622float2(bf16_gate); + auto neg_gate_exp = make_float2( + kFastMath ? __expf(-gate.x) : expf(-gate.x), + kFastMath ? __expf(-gate.y) : expf(-gate.y)); + const auto denom = __fadd2_rn({1.0f, 1.0f}, neg_gate_exp); + if constexpr (kFastMath) { + gate = __fmul2_rn(gate, {math::fast_rcp(denom.x), math::fast_rcp(denom.y)}); + } else { + gate = {gate.x / denom.x, gate.y / denom.y}; + } + const auto up = __bfloat1622float2(bf16_up); + swiglu_values[i * 2 + k] = __fmul2_rn(__fmul2_rn(gate, up), weights); + } + + // Amax reduction + amax_values[i].x = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].x), cute::abs(swiglu_values[i * 2 + 1].x)), + math::ReduceMax()); + amax_values[i].y = math::warp_reduce<4, true>( + cute::max(cute::abs(swiglu_values[i * 2 + 0].y), cute::abs(swiglu_values[i * 2 + 1].y)), + math::ReduceMax()); + if (lane_idx < 4) + smem_amax_reduction[epilogue_warp_idx * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx] = amax_values[i]; + __syncwarp(); + } + + // Wait shared memory release from previous TMA store + // And fence `smem_amax_reduction` + const uint32_t tma_stage_idx = s % kNumTMAStoreStages; + ptx::tma_store_wait(); + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Cast to FP8 E4M3 and store into shared memory + #pragma unroll + for (uint32_t i = 0; i < kNumAtomsPerStore; ++ i) { + // Reduce amax + const float2 wp_amax = + smem_amax_reduction[(epilogue_warp_idx ^ 1) * (STORE_BLOCK_M / 2) + i * (ATOM_M / 2) + lane_idx % 4]; + amax_values[i].x = cute::max(amax_values[i].x, wp_amax.x); + amax_values[i].y = cute::max(amax_values[i].y, wp_amax.y); + + // Calculate SF + float2 sf, sf_inv; + math::get_e4m3_sf_and_sf_inv(amax_values[i], sf, sf_inv); + + // Cast + const float2 upper = __fmul2_rn(swiglu_values[i * 2 + 0], sf_inv); + const float2 lower = __fmul2_rn(swiglu_values[i * 2 + 1], sf_inv); + const auto fp8x4_values = __nv_fp8x4_e4m3(make_float4(upper.x, upper.y, lower.x, lower.y)); + + // STSM + uint32_t row = lane_idx; + uint32_t col = warp_idx_in_wg; + const auto smem_ptr = smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N + + i * ATOM_M * L1_OUT_BLOCK_N + + row * L1_OUT_BLOCK_N + + (col ^ (row / 2)) * kNumBankGroupBytes; + ptx::SM100_U8x4_STSM_T<__nv_fp8x4_e4m3>::copy(fp8x4_values, smem_ptr); + + // Store SF to `l2_sf_buffer` as UE8M0 (MN-major layout) + // Only one warp per pair writes (both hold the same SF after cross-warp reduce) + // Each lane < 4 holds SF for 2 rows (sf.x and sf.y) + if (warp_idx_in_wg % 2 == 0 and lane_idx < 4) { + const uint32_t k_idx = n_block_idx * 2 + warp_idx_in_wg / 2; + const uint32_t k_uint_idx = k_idx / 4, byte_idx = k_idx % 4; + const uint32_t mn_stride = kNumPaddedSFPoolTokens * sizeof(uint32_t); + const auto sf_base_ptr = l2_sf_buffer.get_base_ptr(); + // NOTES: consecutive tokens (t, t + 1) are in the same 32-group, so `sf_idx` differs by 4 + // NOTES: originally there was: + // - `const uint32_t token_idx_in_expert = m_block_idx * BLOCK_M + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2 + // - `scheduler.get_current_pool_block_offset() * SF_BLOCK_M + transform_sf_token_idx(token_idx_in_expert)` + // We find out that + // 1. `m_block_idx * BLOCK_M` mod `BLOCK_M` is 0, and `epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M + lane_idx * 2` is always < `BLOCK_M`, so we can put `m_block_idx * BLOCK_M` outside + // 2. `lane_idx * 2` controls the lowest 3 bit of `token_idx_in_expert`, and `transform_sf_token_idx` is a bitwise-independent transformation if the input is less than `BLOCK_M`, so we can put `lane_idx * 2` outside + // This reduce the number of computation instructions. + const uint32_t token_base_idx = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + __builtin_assume(token_base_idx < BLOCK_M); + const auto sf_pool_token_idx = scheduler.get_current_pool_block_offset() * SF_BLOCK_M + + m_block_idx * SF_BLOCK_M + transform_sf_token_idx(token_base_idx) + (lane_idx * 2) * 4; + const auto sf_addr = k_uint_idx * mn_stride + sf_pool_token_idx * static_cast(sizeof(uint32_t)) + byte_idx; + sf_base_ptr[sf_addr] = + (*reinterpret_cast(&sf.x) >> 23); + sf_base_ptr[sf_addr + 4 * static_cast(sizeof(uint32_t))] = + (*reinterpret_cast(&sf.y) >> 23); + } + __syncwarp(); + } + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Issue TMA store after all atoms in this store block + if (warp_idx_in_wg == 0 and cute::elect_one_sync()) { + uint32_t out_n_idx = n_block_idx * L1_OUT_BLOCK_N; + cute::tma_store_fence(); + cute::SM90_TMA_STORE_2D::copy( + &tensor_map_l1_output, + smem_cd[tma_stage_idx] + epilogue_wg_idx * STORE_BLOCK_M * L1_OUT_BLOCK_N, + out_n_idx, + m_idx + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M); + cute::tma_store_arrive(); + } + __syncwarp(); + } + + // Notify L2 + // TODO: less epilogue sync scope + ptx::tma_store_wait<0>(); + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(L2_SHAPE_K <= 64 * L1_OUT_BLOCK_N, "L2 shape K is too large"); + ptx::red_or_rel_gpu( + workspace.get_l2_arrival_mask_ptr(pool_block_idx), + 1ull << n_block_idx + ); + } + __syncwarp(); + } else { + DG_STATIC_ASSERT(STORE_BLOCK_M % 8 == 0, "Invalid store M"); + constexpr uint32_t kNumRowsPerWarp = STORE_BLOCK_M / 8; + + // L2 BF16 epilogue: write GEMM output to remote combine buffer via NVLink + #pragma unroll + for (uint32_t s = 0; s < WG_BLOCK_M / STORE_BLOCK_M; ++ s) { + // Early break if the entire store block is beyond the valid token range + // TODO: check performance + if (epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M >= valid_m) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + break; + } + + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_M / ATOM_M; ++ i) { + // Load from TMEM using .16x256b shape to satisfy STSM layout requirements + // Start from lane index 0 and 16 + uint32_t tmem_addr = accum_stage_idx * UMMA_N + epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + i * ATOM_M; + uint32_t values[ATOM_M]; + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cute::SM100_TMEM_LOAD_16dp256b1x::copy(tmem_addr | 0x00100000, + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + + // Wait shared memory release from previous NVLink store + // NOTES: skip for the first store block since the prior full barrier already ensures completion + if (i == 0 and s > 0) + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Signal tensor memory consumed + if (s == WG_BLOCK_M / STORE_BLOCK_M - 1 and i == STORE_BLOCK_M / ATOM_M - 1) { + ptx::tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Store into shared memory + // NOTES: only use first 16 lanes for address + // NOTES: 2 warps share a BF16 swizzle atom + uint32_t row = lane_idx % 8; + uint32_t col = (epilogue_warp_idx % 2) * 4 + lane_idx / 8; + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (warp_idx_in_wg / 2) * STORE_BLOCK_M * kSwizzleCDMode + + i * ATOM_M * kSwizzleCDMode + + row * (kNumBankGroupBytes * 8) + + (col ^ row) * kNumBankGroupBytes; + ptx::SM90_U32x4_STSM_T::copy( + math::cast_into_bf16_and_pack(values[0], values[1]), + math::cast_into_bf16_and_pack(values[2], values[3]), + math::cast_into_bf16_and_pack(values[4], values[5]), + math::cast_into_bf16_and_pack(values[6], values[7]), + smem_ptr + ); + } + + // Wait shared memory ready + ptx::sync_aligned(128, kEpilogueWGBarrierStartIdx + epilogue_wg_idx); + + // Write into remote buffers + // One warp per row, now the layout is different from shared memory storing + const uint32_t row_in_atom = (warp_idx_in_wg * 2 + lane_idx / 16) % ATOM_M; + const uint32_t bank_group_idx = lane_idx % 8; + + #pragma unroll + for (uint32_t j = 0; j < kNumRowsPerWarp; ++ j) { + const uint32_t row_in_store = j * 8 + warp_idx_in_wg * 2 + lane_idx / 16; + const uint32_t m_idx_in_block = epilogue_wg_idx * WG_BLOCK_M + s * STORE_BLOCK_M + row_in_store; + + // Skip padding rows beyond the actual token count for this expert + if (m_idx_in_block >= valid_m) + break; + + const auto src_metadata = *workspace.get_token_src_metadata_ptr(m_idx + m_idx_in_block); + const uint32_t dst_rank_idx = src_metadata.rank_idx; + const uint32_t dst_token_idx = src_metadata.token_idx; + const uint32_t dst_topk_idx = src_metadata.topk_idx; + + // Read from shared memory + const auto smem_ptr = smem_cd_l2 + + epilogue_wg_idx * STORE_BLOCK_M * BLOCK_N * static_cast(sizeof(nv_bfloat16)) + + (lane_idx % 16 / 8) * STORE_BLOCK_M * kSwizzleCDMode + + row_in_store * kSwizzleCDMode + + (bank_group_idx ^ row_in_atom) * kNumBankGroupBytes; + const auto packed = ptx::ld_shared(reinterpret_cast(smem_ptr)); + + // Write into remote + const auto dst_token = combine_token_buffer.get_rank_buffer(dst_topk_idx) + .get_data_buffer(dst_token_idx); + const auto dst_ptr = math::advance_ptr( + dst_token.get_base_ptr(), + n_idx * static_cast(sizeof(nv_bfloat16)) + (lane_idx % 16) * static_cast(sizeof(float4))); + *sym_buffer.map(dst_ptr, dst_rank_idx) = packed; + } + } + + // Ensure the next epilogue safe to use shared memory + ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); + } + }); + + // Deallocate tensor memory + // NOTES: must be called by the same logical warp ID on both CTAs + if (epilogue_warp_idx == 0) + Allocator().free(0, kNumTmemCols); + + // NVLink barrier (grid sync + cross-rank signal + grid sync): ~4 us + comm::nvlink_barrier( + workspace, sym_buffer, sm_idx, epilogue_thread_idx, + [&]() { ptx::sync_aligned(kNumEpilogueThreads, kEpilogueFullBarrierIdx); } + ); + + // Barrier with dispatch warps, so that they can do clean workspace + ptx::sync_unaligned(kNumDispatchThreads + kNumEpilogueThreads, kDispatchWithEpilogueBarrierIdx); + + // Combine: reduce top-k results and write back + // NOTES: reuse shared memory from start up to the barriers + // 1 token, 1 topk latency: ~3 us + constexpr uint32_t kNumHiddenBytes = kHidden * sizeof(nv_bfloat16); + constexpr uint32_t kNumElemsPerUint4 = sizeof(uint4) / sizeof(nv_bfloat162); + + // 3 slots of chunk is needed: 2 load stages and 1 store + constexpr uint32_t kNumChunkSlots = 3; + constexpr uint32_t kNumMaxRegistersForBuffer = 128; + + // NOTES: either 1 or 2 chunks for simplicity + // NOTES: Restrict on both smem and register + constexpr uint32_t kNumChunks = + kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes <= SMEM_BEFORE_BARRIER_SIZE and kHidden <= 32 * kNumMaxRegistersForBuffer ? 1 : 2; + constexpr uint32_t kNumChunkBytes = kNumHiddenBytes / kNumChunks; + constexpr uint32_t kNumChunkUint4 = kNumChunkBytes / sizeof(uint4); + constexpr uint32_t kNumUint4PerLane = kNumChunkUint4 / 32; + DG_STATIC_ASSERT(kHidden % kNumChunks == 0, "Hidden must be divisible by number of chunks"); + DG_STATIC_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumHiddenBytes / kNumChunks <= SMEM_BEFORE_BARRIER_SIZE, "Hidden is too large"); + DG_STATIC_ASSERT(kNumChunkBytes % 16 == 0, "Combine chunk must be TMA-aligned (16 bytes)"); + DG_STATIC_ASSERT(kNumChunkBytes % sizeof(uint4) == 0, "Combine chunk must be divisible by 16 bytes"); + DG_STATIC_ASSERT(kNumChunkUint4 % 32 == 0, "Combine chunk must be a multiple of 32 16-byte elements (one per lane)"); + DG_STATIC_ASSERT(kNumTopk <= 32, "Top-k must fit in a single warp"); + + // Verify combined shared memory budget at runtime + DG_DEVICE_ASSERT(kNumChunkSlots * kNumEpilogueWarps * kNumChunkBytes <= static_cast( + reinterpret_cast(barrier_start_ptr) - smem_buffer)); + + // Per-warp buffer: 2 stage load buffers + 1 store buffer + const auto combine_load_buffer = utils::PatternVisitor([&](const uint32_t& i) { + return math::advance_ptr(smem_buffer, (epilogue_warp_idx + i * kNumEpilogueWarps) * kNumChunkBytes); + }); + const auto combine_store_buffer = math::advance_ptr(smem_buffer, (epilogue_warp_idx + kNumEpilogueWarps * 2) * kNumChunkBytes); + + // Per-warp barriers + auto combine_load_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return combine_barriers[i + epilogue_warp_idx * 2]; + }); + + // Iterate over all tokens + uint32_t combine_phase = 0; + uint32_t load_stage_idx = 0; + for (uint32_t token_idx = sm_idx * kNumEpilogueWarps + epilogue_warp_idx; + token_idx < num_tokens; + token_idx += kNumSMs * kNumEpilogueWarps) { + // Read top-k slot indices: each lane reads one slot, then broadcast via exchange + DG_STATIC_ASSERT(kNumTopk <= 32, "Invalid number of topk"); + const int stored_topk_slot_idx = lane_idx < kNumTopk ? + static_cast(__ldg(input_topk_idx_buffer.get_base_ptr() + token_idx * kNumTopk + lane_idx)) : -1; + const uint32_t total_mask = __ballot_sync(0xffffffff, stored_topk_slot_idx >= 0); + + // Iterate all chunks + for (uint32_t chunk = 0; chunk < kNumChunks; ++ chunk) { + const uint32_t chunk_byte_offset = chunk * kNumChunkBytes; + + // Move mask and load + uint32_t mask = total_mask; + const auto move_mask_and_load = [&](const uint32_t& i) { + if (mask) { + // Move + const uint32_t slot_idx = __ffs(mask) - 1; + mask ^= 1 << slot_idx; + + // Load + if (cute::elect_one_sync()) { + const auto src_ptr = math::advance_ptr( + combine_token_buffer.get_rank_buffer(slot_idx) + .get_data_buffer(token_idx).get_base_ptr(), + chunk_byte_offset); + ptx::tma_load_1d(combine_load_buffer[i], src_ptr, combine_load_barriers[i], kNumChunkBytes); + ptx::mbarrier_arrive_and_set_tx(combine_load_barriers[i], kNumChunkBytes); + } + __syncwarp(); + return true; + } + return false; + }; + + // Load the first selection + bool do_reduce = move_mask_and_load(load_stage_idx); + + // Accumulate all top-k contributions for this chunk in float registers + float2 reduced[kNumUint4PerLane * kNumElemsPerUint4] = {}; + while (do_reduce) { + // Prefetch next top-k into the buffer while current is being accumulated + do_reduce = move_mask_and_load(load_stage_idx ^ 1); + + // Accumulate + combine_load_barriers[load_stage_idx]->wait(combine_phase); + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + const auto uint4_values = combine_load_buffer[load_stage_idx][j * 32 + lane_idx]; + const auto bf16_values = reinterpret_cast(&uint4_values); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + ptx::accumulate(reduced[j * kNumElemsPerUint4 + l], bf16_values[l]); + } + combine_phase ^= load_stage_idx; + load_stage_idx ^= 1; + } + + // Cast + #pragma unroll + for (uint32_t j = 0; j < kNumUint4PerLane; ++ j) { + uint4 casted; + auto casted_bf16 = reinterpret_cast(&casted); + #pragma unroll + for (uint32_t l = 0; l < kNumElemsPerUint4; ++ l) + casted_bf16[l] = __float22bfloat162_rn(reduced[j * kNumElemsPerUint4 + l]); + + // Wait share memory release and write + if (j == 0) { + ptx::tma_store_wait<0>(); + __syncwarp(); + } + ptx::st_shared(combine_store_buffer + j * 32 + lane_idx, + casted.x, casted.y, casted.z, casted.w); + } + __syncwarp(); + + // TMA store the token chunk + if (cute::elect_one_sync()) { + cute::tma_store_fence(); + ptx::tma_store_1d( + math::advance_ptr(y, static_cast(token_idx) * kNumHiddenBytes + chunk_byte_offset), + combine_store_buffer, kNumChunkBytes); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..7ce008e5 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -0,0 +1,567 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4); + const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + if (kNumMulticast > 1) + cute::cluster_sync(); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad))); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx)); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32; + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Do SF copy at certain stages + // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + } + + // Deallocate tensor memory + kNumMulticast > 1 ? cute::cluster_sync() : __syncthreads(); + if (warp_idx == 0) + Allocator().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh new file mode 100644 index 00000000..e6744f59 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -0,0 +1,403 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint32_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + logits_dtype_t* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); + + // Types + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + + // Align to 512 bytes for swizzle-64B + extern __shared__ __align__(512) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + + // TMA configs + constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + + // Tensor memory allocation + auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); + + // Initialize barriers + DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 1) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block scheduler + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + // TMA alignment requirements for SF KV + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx == kSpecWarpStart) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Prefetch + const auto issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else if (warp_idx == kSpecWarpStart + 1) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue UMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); + ptx::tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = mma::sm100::make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = mma::sm100::make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + } + num_total_kv_blocks += num_kv_blocks; + + // UMMA warp must also arrive on empty_q to prevent running ahead + // of math warps in the Q pipeline + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto tmem_start = warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Local register buffers + float weights[BLOCK_Q][kNumHeads]; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); + ptx::tcgen05_after_thread_sync(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto kv_offset = kv_start + kv_block_idx * BLOCK_KV + math_thread_idx; + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + // Load accumulator from TMEM + float accum[kNumHeads]; + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Release TMEM empty + if (i == BLOCK_Q - 1) { + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + } + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(scale_kv * (sum.x + sum.y)); + + // Store into the global memory + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset and kv_offset < seq_k_end[i]) + logits[q_offset + kv_offset - seq_k_start[i]] = result; + } else { + logits[q_offset + kv_offset] = result; + } + __syncwarp(); + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh new file mode 100644 index 00000000..9a5bddbf --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -0,0 +1,439 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Utils + const auto sm_idx = blockIdx.x; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + constexpr uint32_t kSpecWarpStart = kNumMathWarpGroups * 4; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kSpecWarpStart) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + + // For non-varlen odd kNextN >= 3, pad to even using TMA OOB zero-fill. + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + static constexpr uint32_t kNumNextNAtoms = math::constexpr_ceil_div(kNextN, kNextNAtom); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextNAtom * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextNAtom * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); + }); + constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = utils::PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + + constexpr uint32_t kNumTmemCols = kNextNAtom * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Initialize barriers + if (warp_idx == kSpecWarpStart and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads + 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + cutlass::arch::fence_barrier_init(); + } + if (warp_idx == kSpecWarpStart + 1) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 56; + constexpr uint32_t kNumMathRegisters = 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + using Scheduler = sched::PagedMQALogitsScheduler; + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Q and KV pipeline + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = kNextNAtom * kNumHeads; + DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + + if (warp_idx == kSpecWarpStart) { + // TMA warp for loading data + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& tma_q_atom_idx) { + if (cute::elect_one_sync()) { + const auto q_token_idx = Scheduler::atom_to_token_idx(tma_q_atom_idx); + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_token_idx * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_token_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx, num_kv; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_atom_idx), q_iter_idx = 1; + + uint32_t kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when (q, atom) changes + const auto next_advance = scheduler.get_atom_advance(next_q_atom_idx, batch_size); + bool prefetch_q = (q_atom_idx != next_q_atom_idx) and scheduler.exist_q_atom_idx(next_q_atom_idx + next_advance); + + if (q_atom_idx != next_q_atom_idx) + kv_block_idx_ptr = 32; + + q_atom_idx = next_q_atom_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Read KV block index + // TODO(xuzhean): consider -1 + if (kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + const auto block_table_offset = Scheduler::atom_to_block_table_row(q_atom_idx) * static_cast(block_table_stride); + kv_block_idx_storage = (kv_idx + lane_idx < num_kv) + ? block_table[block_table_offset + kv_idx + lane_idx] : 0; + } + __syncwarp(); + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_atom_idx + next_advance); + } + + uint32_t kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumBlocksPerSplit; ++ i) { + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv); + } + } else if (warp_idx == kSpecWarpStart + 1) { + cutlass::arch::warpgroup_reg_dealloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ptx::ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 1; + + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + if (q_atom_idx != next_q_atom_idx) { + // Release previous Q empty (UMMA warp must participate to prevent + // running ahead of math warps in the Q pipeline) + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + } + + q_atom_idx = next_q_atom_idx; + kv_idx = next_kv_idx; + + // Wait KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(umma_phase); + ptx::tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = mma::sm100::make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = mma::sm100::make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + umma_phase ^= 1; + } + } else if (warp_idx == kSpecWarpStart + 2 or warp_idx == kSpecWarpStart + 3) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kSpecWarpStart) { + // Math warpgroups for reduce + cutlass::arch::warpgroup_reg_alloc(); + auto scheduler = Scheduler(sm_idx, batch_size, context_lens, schedule_meta, indices); + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + // Offsets + const auto math_warpgroup_idx = warpgroup_idx; + const auto tmem_start = math_warpgroup_idx * UMMA_N; + const auto math_thread_idx = warp_idx * 32 + lane_idx; + + // Helper lambda for loading tensor memory + auto tmem_load = [](auto num_elems_c, const uint32_t& tmem_addr, float* accum) { + constexpr int N = decltype(num_elems_c)::value; + DG_STATIC_ASSERT(N == 32 or N == 64, "Unsupported TMEM load size"); + using Loader = cute::conditional_t; + [&](cute::index_sequence) { + Loader::copy(tmem_addr, reinterpret_cast(accum)[Is]...); + }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + }; + + // Local register buffers + float weights[kNextNAtom][kNumHeads]; + + // Initialize outside valid range to indicate no previous task + uint32_t q_atom_idx = batch_size * kNumNextNAtoms, kv_idx; + uint32_t next_q_atom_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 0; + bool is_paired_atom = false; + + while (scheduler.fetch_next_task(next_q_atom_idx, next_kv_idx, next_num_kv)) { + // Q or atom changes + if (q_atom_idx != next_q_atom_idx) { + // Release last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextNAtom; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + + if constexpr (kIsVarlen) { + is_paired_atom = (scheduler.get_atom_advance(next_q_atom_idx, batch_size) == 2); + } + } + + // Get current task indices + q_atom_idx = next_q_atom_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = Scheduler::atom_to_token_idx(q_atom_idx) * static_cast(logits_stride) + kv_idx * BLOCK_KV; + + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + math_thread_idx); + + // Wait UMMA arrival + full_umma_barriers[math_warpgroup_idx]->wait(umma_phase); + ptx::tcgen05_after_thread_sync(); + umma_phase ^= 1; + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + + const auto reduce_and_store = [&](auto num_iters_c) { + constexpr uint32_t kNumIters = decltype(num_iters_c)::value; + float accum[kNumHeads]; + + #pragma unroll + for (uint32_t i = 0; i < kNumIters; ++ i) { + // Load accumulator from TMEM + tmem_load(cute::Int{}, tmem_start + i * kNumHeads, accum); + + // Accumulate weighted ReLU in parallel + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto transform = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (uint32_t j = 0; j < kNumHeads; j += 4) { + sum_0 = transform(j, sum_0); + sum_1 = transform(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + auto result = static_cast(scale_kv * (sum.x + sum.y)); + + // Store into the global memory + logits[kv_offset + i * static_cast(logits_stride) + math_thread_idx] = result; + __syncwarp(); + } + + // Release TMEM empty + ptx::tcgen05_before_thread_sync(); + empty_umma_barriers[math_warpgroup_idx]->arrive(); + }; + + if constexpr (kIsVarlen) { + if (is_paired_atom) + reduce_and_store(cute::Int{}); + else + reduce_and_store(cute::Int<1>{}); + } else if constexpr (kPadOddN) { + if (q_atom_idx % kNumNextNAtoms == kNumNextNAtoms - 1) + reduce_and_store(cute::Int<1>{}); + else + reduce_and_store(cute::Int{}); + } else { + reduce_and_store(cute::Int{}); + } + } + + // Free tensor memory + cutlass::arch::NamedBarrier(kNumMathThreads, 0).sync(); + if (warp_idx == 0) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..aaf7fd9a --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,350 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_DEVICE +uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { + // Calculate the index of the bank group to be written in the atom + const auto bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` + // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)` + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % (kSwizzleMode / kSwizzleBase); + + return row * 128 + col * kSwizzleBase; +} + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t kNumCastStages = 2; + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + constexpr auto kMajorA = cute::UMMA::Major::K; + constexpr auto kMajorB = cute::UMMA::Major::K; + DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages"); + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = utils::get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 4 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + full_cast_barriers[i]->init(kNumCastAndReduceThreads); + empty_barriers[i]->init(1); + empty_cast_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Dispatch warps into different roles + if (warp_idx < kNumMMAThreads / 32) { + // TMA load warp + if (warp_idx == 0 and cute::elect_one_sync()) { + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } + + // MMA issue warp + if (warp_idx == 1) { + // Make instruction descriptor + constexpr uint32_t UMMA_M = BLOCK_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(float); + constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float); + using umma_t = cute::SM100_MMA_TF32_TS; + auto instr_desc = cute::UMMA::make_instr_desc(); + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto b_desc = mma::sm100::make_umma_desc(smem_b[0], 0, 0); + const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Launch MMAs + // We can not unroll this part + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + const auto& cast_stage_idx = s % kNumCastStages; + full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); + ptx::tcgen05_after_thread_sync(); + + // Issue UMMA + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; + const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; + const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; + b_desc.lo = mma::sm100::advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); + } + + // Commit + cutlass::arch::umma_arrive(reinterpret_cast(empty_cast_barriers[cast_stage_idx])); + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + + // Commit to epilogue threads + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Only support layout F (M = 64) and D (M = 128) + DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + ptx::tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Source and destination memory address + uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd) + // Base pointer + warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset + get_swizzled_smem_offset(i, lane_idx); // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) + ptx::st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + if constexpr (BLOCK_M == 64) + __syncwarp(); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0); + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } else { + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32; + + // TODO: make even larger block K + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + + // Launch reductions + float2 sum[2] = {float2{0, 0}, float2{0, 0}}; + #pragma unroll kNumStages + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b) + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + const auto& smem_base_ptr = reinterpret_cast(smem_a[stage_idx]) + // Base pointer + sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset + + // 4 lanes shared a bank group + uint32_t uint32_values[2][kNumLoads]; + DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads"); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; i += 2) { + auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); + ptx::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); + } + + // Wait tensor memory empty + const auto& cast_stage_idx = s % kNumCastStages; + empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1); + + // Cast, reduce and store into tensor memory + float2 fp32x2_values[2][kNumLoads]; + const auto& upper_view = reinterpret_cast(&fp32x2_values[0]); + const auto& lower_view = reinterpret_cast(&fp32x2_values[1]); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast(&uint32_values[u][i])); + sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]); + } + + // Store upper and lower part at the same time + const auto idx_0 = i * 2, idx_1 = i * 2 + 1; + cute::SM100_TMEM_STORE_16dp256b1x::copy( + upper_view[idx_0], upper_view[idx_1], + lower_view[idx_0], lower_view[idx_1], + cast_stage_idx * BLOCK_K + i * 8); + } + cutlass::arch::fence_view_async_tmem_store(); + + // Arrive for issuing MMAs + ptx::tcgen05_before_thread_sync(); + full_cast_barriers[cast_stage_idx]->arrive(); + } + + // Intra-warp reduction and write back + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + const auto reduced_sum = math::warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + if (lane_idx % 4 == 0 and m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum; + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh new file mode 100644 index 00000000..84a149eb --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -0,0 +1,388 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead + constexpr uint32_t kDoMergeStages = + kNumStages_ >= 10 and + kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and + kNumMathThreads == 128; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 5; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + // Types + using WGMMA = typename mma::sm90::BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B/D must be aligned to 1024 bytes"); + + // D/A/B shared memory + auto smem_d = reinterpret_cast(smem_buffer); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 48; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), sched::IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), sched::IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma::copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma::copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma::copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = mma::sm90::make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = mma::sm90::make_gmma_desc(smem_b[0], 0, 0); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // TODO: remove some useless computation for unaligned Ms + const auto num_total_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + const uint32_t atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( + a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); + b_desc.reg32_[0] = mma::sm90::advance_gmma_desc_lo( + b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); + WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); + } + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + if constexpr (cute::is_same_v) { + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + ptx::SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + } else { + // Use `st.shared` if STSM is not available + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2); + auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + ptx::st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + ptx::st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + } + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + // Use TMA store to write back to global memory + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), sched::IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh new file mode 100644 index 00000000..7c344296 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -0,0 +1,183 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + float *d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename mma::sm90::BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Shared memory + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = ptx::get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); + DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); + DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + // Fill shared memory pointers + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block indices + const uint32_t num_n_blocks = math::ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * math::ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // Persistently schedule over blocks + #pragma unroll + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); + + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t k_idx = sk_idx % SHAPE_K; + const uint32_t s_idx = sk_idx / SHAPE_K; + + constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); + tma::copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); + tma::copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + float accum[WGMMA::kNumAccum] = {0}; + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrivals + const auto stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, 1); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + empty_barriers[stage_idx]->arrive(); + } + + const auto row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + if (col + i * 8 >= SHAPE_N) + break; + if (row < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 0) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 0], accum[i * 4 + 1])); + } + if (row + 8 < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 8) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 2], accum[i * 4 + 3])); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh new file mode 100644 index 00000000..32096250 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -0,0 +1,346 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, + int* grouped_layout, + cute::TmaDescriptor* tensor_map_buffer, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_b_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); + + // Types + using WGMMA = typename mma::sm90::FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 2 : 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = threadIdx.x % 32; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a_base); + cute::prefetch_tma_descriptor(&tensor_map_b_base); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Tensor maps on shared and global memory + auto smem_tensor_map_a = reinterpret_cast(smem_buffer); + auto smem_tensor_map_b = smem_tensor_map_a + 1; + auto gmem_tensor_map_a = tensor_map_buffer + blockIdx.x * 2; + auto gmem_tensor_map_b = gmem_tensor_map_a + 1; + + // Data on shared memory + auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); + }); + auto smem_sfb = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); + }); + + // Barriers on shared memory + constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); + }); + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); + }); + + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // Load tensormap A/B to shared memory + if constexpr (kGemmType == GemmType::KGroupedContiguous) { + *smem_tensor_map_a = tensor_map_a_base; + *smem_tensor_map_b = tensor_map_b_base; + } + + // Initialize barriers + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Pipeline unroll control + constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages); + + // Register reconfigurations (more math registers are needed with unrolling) + constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); + constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 232 : 240); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // TMA and MMA pipeline + const auto get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase + }; + uint32_t iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + uint32_t last_group_idx = kNumGroups; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const uint32_t num_k_blocks = math::ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t m_idx = m_block_idx * BLOCK_M; + const uint32_t n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous && last_group_idx != scheduler.current_group_idx) { + last_group_idx = scheduler.current_group_idx; + + // Directly update current tensor map + const uint64_t current_k_offset = scheduler.current_k_cumsum; + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_a, gmem_a_ptr + current_k_offset * shape_m); + ptx::tensor_map_replace_global_addr_in_smem(smem_tensor_map_b, gmem_b_ptr + current_k_offset * shape_n); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a, scheduler.current_shape_k, scheduler.current_shape_k); + ptx::tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b, scheduler.current_shape_k, scheduler.current_shape_k); + *(gmem_tensor_map_a) = *(smem_tensor_map_a); + *(gmem_tensor_map_b) = *(smem_tensor_map_b); + ptx::tensor_map_release_gpu(); + + // Immediately acquire current tensor map + ptx::tensor_map_acquire_gpu(gmem_tensor_map_a); + ptx::tensor_map_acquire_gpu(gmem_tensor_map_b); + } + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + const uint32_t sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + const auto tensor_map_a_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_a : &tensor_map_a_base); + const auto tensor_map_b_ptr = (kGemmType == GemmType::KGroupedContiguous ? gmem_tensor_map_b : &tensor_map_b_base); + tma::copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, sf_k_idx, num_tma_multicast_a); + tma::copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma::copy(tensor_map_a_ptr, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma::copy(tensor_map_b_ptr, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; + const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Accumulation for WGMMA or CUDA promotion + DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); + const uint32_t current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t num_k_blocks = math::ceil_div(current_shape_k, BLOCK_K); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + float2 scales_b[WGMMA::kNumAccum / 4]; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait TMA arrivals + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + full_barriers[stage_idx]->wait(phase); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ptx::ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ptx::ld_shared(smem_sfa[stage_idx] + r_1); + + // Read B scales + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + scales_b[i] = ptx::ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = mma::sm90::make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + + // Promote with scales + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float &scale_b_0 = scales_b[i].x; + const float &scale_b_1 = scales_b[i].y; + final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + // Flush previous stores + if (warp_idx % 4 == 0 and cute::elect_one_sync()) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Store to D shared memory + const auto smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + ptx::st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + ptx::st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Use TMA store to write back to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy( + &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, + current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh new file mode 100644 index 00000000..aa412484 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -0,0 +1,449 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_DEVICE void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { + if (num_former_iters == kNumFormerIters) { + func(cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + dispatch_num_former_iters(num_former_iters, func); +} + +template +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT( + math::constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or + (math::constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename mma::sm90::FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = math::constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = math::constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t shape_k_scales = math::ceil_div(shape_k, BLOCK_K); + const uint32_t shape_n_sfb = math::ceil_div(shape_n, BLOCK_K); + const uint32_t smem_sfb_size = math::align(shape_k_scales * (kMustUseUniformedScaleB ? 1 : 2) * sizeof(float), sizeof(Barrier)); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t num_total_k_blocks = math::ceil_div(shape_k, BLOCK_K); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + auto full_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = sched::Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA A + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + tma::copy(&tensor_map_a, &full_barrier, + smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a, batch_idx); + tma::copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + num_tma_multicast_a); + + // Issue TMA B + tma::copy(&tensor_map_b, &full_barrier, + smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b, batch_idx); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + auto a_desc = mma::sm90::make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = mma::sm90::make_smem_desc(smem_b[0], 1); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_sfb = shape_k_scales * (num_former_iters >= num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; + const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; + auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; + + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + ptx::st_shared(smem_sfb + i, i < shape_k_scales ? local_sfb[i * stride_k_sfb] : local_sfb[(i - shape_k_scales) * stride_k_sfb + stride_n_sfb]); + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&]() { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void(); + } + }; + + // Skip useless computations + if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { + // The compiler must know the dynamic variable `num_former_iters`'s real value + constexpr bool kShouldOptimize = BLOCK_K / math::constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = math::constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // Dispatch `num_former_iters` and launch MMAs + dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { + #pragma unroll 8 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Read B scales + float scale_b_0 = ptx::ld_shared(smem_sfb + k_block_idx), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ptx::ld_shared(smem_sfb + k_block_idx + shape_k_scales); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? ptx::ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; + b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; + WGMMA::wgmma(a_desc, b_desc, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(); + + // Skip promotion for the unfilled parts + if (not do_wgmma_store) + continue; + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + const bool predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + } + }); + } else { + #pragma unroll + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + empty_barrier_arrive(); + } + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + ptx::SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset); + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr, + n_idx, m_idx, scheduler.current_group_idx); + } else { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh new file mode 100644 index 00000000..225af441 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -0,0 +1,330 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint32_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + logits_dtype_t* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto num_q_blocks = math::ceil_div(seq_len, BLOCK_Q); + + // Types + using WGMMA = typename mma::sm90::FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Data on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + + // Initialize barriers + const bool is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 32; + constexpr uint32_t kNumMathRegisters = 112; + + // Block scheduler + const auto sm_idx = blockIdx.x; + uint32_t block_q_idx = sm_idx, q_iter_idx = 0; + const auto get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + kNumSMs, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = cu_seq_len_k_start[q_idx]; + seq_k_end[i] = cu_seq_len_k_end[q_idx]; + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + // TMA alignment requirements for SF KV + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, math::ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // Only the first warp remains + if (not is_tma_load_warp) + return; + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& thread_idx = threadIdx.x % kNumMathThreads; + const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = ptx::get_lane_idx(); + float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; + + const auto& warp_offset = warp_idx * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + const auto q_offset = (block_q_idx * BLOCK_Q + i) * static_cast(stride_logits); + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) + logits[q_offset + kv_offset + v_0_offset - seq_k_start[i]] = static_cast(v_0); + if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) + logits[q_offset + kv_offset + v_1_offset - seq_k_start[i]] = static_cast(v_1); + } else { + logits[q_offset + kv_offset + v_0_offset] = static_cast(v_0); + logits[q_offset + kv_offset + v_1_offset] = static_cast(v_1); + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh new file mode 100644 index 00000000..cc2592bb --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -0,0 +1,334 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint32_t logits_stride, const uint32_t block_table_stride, + const uint32_t* context_lens, logits_dtype_t* logits, + const uint32_t* block_table, const uint32_t* indices, + const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + DG_STATIC_ASSERT(not kIsVarlen, "Varlen is not supported for SM90 paged MQA logits"); + + // Types + using WGMMA = typename mma::sm90::FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto warpgroup_idx = warp_idx / 4; + const auto lane_idx = ptx::get_lane_idx(); + + // Prefetch TMA descriptors + static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = math::constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + + math::constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = math::constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + + math::constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q data and barriers on shared memory + auto smem_q = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = utils::PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + + // Separate math warpgroups and tma load warps into KV groups + // Each math warpgroup corresponds to a tma load warp + const auto kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? (threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + + // Per group KV data and barriers on shared memory + const auto smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = utils::PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + + // Initialize barriers + if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { + if (kv_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + } + if (kv_group_idx < kNumMathWarpGroups) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(128); + } + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 64; + constexpr uint32_t kNumMathRegisters = 104; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Scheduler + auto scheduler = sched::PagedMQALogitsScheduler( + blockIdx.x, batch_size, context_lens, schedule_meta, indices); + DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + + // Q and KV pipeline + const auto get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + if (kv_group_idx >= kNumMathWarpGroups) + return; + + const auto issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (kv_group_idx == 0 and cute::elect_one_sync()) { + tma::copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma::copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx * kNextN); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_atom_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? + block_table[q_idx * static_cast(block_table_stride) + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)] : 0); + } + const auto kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + if (cute::elect_one_sync()) { + tma::copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma::copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; + const auto sub_warp_offset = (warp_idx % 4) * 16; + const auto v_0_offset = lane_idx / 4 + 0; + const auto v_1_offset = lane_idx / 4 + 8; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ptx::ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * static_cast(logits_stride) + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = mma::sm90::make_smem_desc( + smem_kv[kv_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = mma::sm90::make_smem_desc( + smem_q[q_stage_idx] + k * WGMMA::K, + mma::sm90::to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + + // Read per-KV scales + float scale_kv_0 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ptx::ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + + // Wait WGMMA + ptx::warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * static_cast(logits_stride) + v_0_offset] = static_cast(v_0); + logits[kv_offset + i * static_cast(logits_stride) + v_1_offset] = static_cast(v_1); + } + } + } +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..93b14100 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,294 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_DEVICE +uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { + constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; + + const auto bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? (offset) : (bank_group_idx % kNumBankGroups); + col ^= row % kGroupsInSwizzleRange; + + return (row * kNumBankGroups + col) % kGroupsInSwizzleRange; +} + +template +CUTLASS_GLOBAL void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // kSwizzleAMode and kSwizzleBMode must be 128 for now + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode"); + + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = ptx::get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = utils::PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = utils::PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = math::constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 256; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // TMA load warp + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cutlass::arch::warpgroup_reg_dealloc(); + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma::copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma::copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + + for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { + const auto stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + } + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + constexpr uint32_t WGMMA_M = 64; + constexpr uint32_t WGMMA_N = BLOCK_N; + constexpr uint32_t WGMMA_K = 8; + + using WGMMA = typename mma::sm90::TF32MMASelector::type; + float accum[WGMMA::kNumAccum] = {0}; + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + float sqr_sum_acc_0 = 0; + float sqr_sum_acc_1 = 0; + + #pragma unroll kNumStages < 8 ? kNumStages : kNumStages / 2 + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128; + constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K; + + float a[kNumRegPerWgmma * kNumWgmmaPerBlockK]; + // Assume swizzle A mode is 128 + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + + // Load BF16 A fragment from shared memory into registers, and transpose to FP32 + uint32_t row = warp_idx * 16 + lane_idx / 4; + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a + uint32_t bank_group_idx = (row ^ i) % 8; + nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + + uint32_t elem_offset = lane_idx % 4; + nv_bfloat16 a_bf16[kNumRegPerWgmma]; + a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset]; + a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4]; + a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset]; + a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4]; + + auto a_bf16x2_ptr = reinterpret_cast(a_bf16); + auto a_float2_ptr = reinterpret_cast(a); + float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]); + float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]); + a_float2_ptr[i * 2 + 0] = a_float2_0; + a_float2_ptr[i * 2 + 1] = a_float2_1; + sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x; + sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; + } + + ptx::warpgroup_wait<0>(); + if (s > 0) + empty_barriers[(s - 1) % kNumStages]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + ptx::warpgroup_arrive(); + + constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); + constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; + DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K"); + + #pragma unroll + for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { + #pragma unroll + for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { + auto b_desc = mma::sm90::make_smem_desc( + smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); + } + } + ptx::warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + ptx::warpgroup_fence_operand(accum[i]); + } + + const auto& reduced_sum_0 = math::warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = math::warp_reduce_sum<4>(sqr_sum_acc_1); + + const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); + if (lane_idx % 4 == 0) { + if (m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum_0; + if (m_idx + 8 < shape_m) + sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; + } + ptx::warpgroup_wait<0>(); + empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); + + // Write accum to shared memory + // Every 2 threads (one pair) will write to the same bank group (16 bytes). + // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d + uint32_t is_odd_pair = lane_idx / 2 % 2; + + // Four threads per group; write the data to the same row. + uint32_t row_idx = lane_idx / 4; + + // Even/odd index pairs write to the same column, we need to reorder idx: + // group even pair indices consecutively, and likewise for odd ones. + uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx; + + auto shifted_smem_ptr = reinterpret_cast(smem_cd) + + (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows + lane_idx % 2 * 8; // One thread of a pair writes 8 bytes + + #pragma unroll + for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) { + // Get the swizzled bank group index (16 bytes per group) + uint32_t bank_group_idx = get_swizzled_bank_group_idx(i + is_odd_pair, reordered_pair_idx); + auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group + + // 0/1 write to the same row, 2/3 write to another row + auto values = reinterpret_cast(accum + i * 2); + ptx::st_shared(smem_ptr, values[0], values[1]); + ptx::st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, 1); + + // Issue TMA stores + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh new file mode 100644 index 00000000..2f66b980 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -0,0 +1,74 @@ +#pragma once + +#include +#include + +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL __launch_bounds__(kNumWarps * 32, 1) +void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, logits_dtype_t* logits) { + const uint32_t num_sms = gridDim.x; + const uint32_t sm_idx = blockIdx.x; + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + constexpr uint32_t kAlignment = 16 / sizeof(logits_dtype_t); + const logits_dtype_t neg_inf = -cute::numeric_limits::infinity(); + + // Allocate filled `-inf` shared memory + extern __shared__ __align__(1024) logits_dtype_t smem_buffer[]; + #pragma unroll + for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) + smem_buffer[i] = neg_inf; + cute::tma_store_fence(); + __syncthreads(); + + // Assign sequence to each warp + const auto assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto per = total / num, rem = total % num; + return {start + idx * per + cute::min(idx, rem), per + (idx < rem)}; + }; + CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); + CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + if (cute::elect_one_sync()) { + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; + + for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { + const auto right = cute::min(left + BLOCK_KV, static_cast(stride_logits)); + if (right <= ks or ke <= left) { + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(logits_dtype_t)); + } else { + if (left < aligned_ks) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(logits_dtype_t)); + if (aligned_ke < right) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(logits_dtype_t)); + } + } + } + } + __syncwarp(); + + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto ks = cu_seq_len_k_start == nullptr ? 0 : cu_seq_len_k_start[i / kNextN]; + const auto ke = cu_seq_len_k_end[i / kNextN] - kNextN + i % kNextN + 1; + const auto aligned_ks = ks / kAlignment * kAlignment, aligned_ke = (ke + kAlignment - 1) / kAlignment * kAlignment; + for (uint32_t j = aligned_ks; j < ks; ++ j) + logits[i * stride_logits + j] = neg_inf; + for (uint32_t j = ke; j < aligned_ke; ++ j) + logits[i * stride_logits + j] = neg_inf; + } +} + +} diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh new file mode 100644 index 00000000..a977c554 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/impls/smxx_layout.cuh @@ -0,0 +1,189 @@ +#pragma once + +#include +#include +#include +#include + +namespace deep_gemm { + +template +CUTLASS_GLOBAL void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename utils::Vectorized::vec_t in_vec_t; + constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); + constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; + + // Shapes and strides + extern __shared__ float smem_buffer[]; + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); + + // Shift into the block + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; + const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Load + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { + auto in_vec = local_sf[i]; + const auto& in_values = reinterpret_cast(&in_vec); + + const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerVec; ++ j) + smem_buffer[row * PADDED_SF_K + col + j] = in_values[j]; + } + __syncthreads(); + + // Store + #pragma unroll + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { + const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; + const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ptx::ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + } +} + +// NOTES: the two kernels below always pack the K dimension + +template +CUTLASS_GLOBAL void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { + extern __shared__ uint32_t smem_buffer[]; + + // Shapes and strides + constexpr auto kNumPackedSFK = math::constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = math::align(mn, kNumTMAAlignedElems); + + // Shift into the group + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Load FP32 SFs + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); + const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + const auto num_values = in_block_mn * SF_K; + const auto num_uint4 = num_values / 4; + #pragma unroll + for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { + const auto& [x, y, z, w] = reinterpret_cast(local_sf)[i]; + ptx::st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + } + + // Fill unaligned values as well + if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) + ptx::st_shared(smem_buffer + unaligned_idx, local_sf[unaligned_idx]); + __syncthreads(); + + // Pack into UE8M0 and store + #pragma unroll + for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) { + const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN; + + // Load shared memory + uint32_t values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + const auto sf_k_idx = sf_k_pack_idx * 4 + j; + values[j] = sf_k_idx < SF_K ? ptx::ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + } + + // Pack and store + uint32_t packed = 0; + packed |= (values[0] >> 23u); + packed |= (values[1] >> 15u); + packed |= (values[2] >> 7u); + packed |= (values[3] << 1u); + if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) + out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; + } +} + +template +CUTLASS_GLOBAL void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k, + const uint32_t gran_k) { + // Always packing the K dimension + // NOTES: should also assert `mn % 4 == 0` at launch + DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes"); + + // Shapes and strides + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto in_block_mn_uint4 = in_block_mn / 4; + const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K); + + // Shift into the right block along MN + sf += blockIdx.x * BLOCK_MN; + out += blockIdx.x * BLOCK_MN; + + // Each warp is responsible for a packed row + const auto warp_idx = threadIdx.x / 32; + const auto lane_idx = ptx::get_lane_idx(); + const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; + if (warp_idx >= in_block_packed_sf_k) + return; + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + // Make an offset on the input + uint32_t input_offset = 0; + if constexpr (kNumGroups > 1) { + // Load each group's size + DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups"); + uint32_t group_ks[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) { + const auto group_idx = lane_idx * 4 + i; + group_ks[i] = group_idx < kNumGroups ? ks[group_idx] : 0; + } + __syncwarp(); + + // Make the offset + sf_k = 0; + uint32_t sum_packed_sf_k = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumGroups; ++ i) { + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / gran_k, i / 4); + sf_k += sf_k_in_group; + sum_packed_sf_k += math::ceil_div(sf_k_in_group, 4u); + if (packed_sf_k_idx < sum_packed_sf_k) + break; + if (const auto remainder = sf_k_in_group % 4; remainder > 0) + input_offset += 4 - remainder; + } + } + + for (uint32_t mn_idx = ptx::get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + // Load + uint4 values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + values[j] = make_uint4(0, 0, 0, 0); + if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) + values[j] = reinterpret_cast(sf + sf_k_idx * mn)[mn_idx]; + } + + // Pack and store + uint4 packed; + packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); + packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); + packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); + packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; + } +} + +} // namespace deep_gemm diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/layout/mega_moe.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/layout/mega_moe.cuh new file mode 100644 index 00000000..13520c60 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/layout/mega_moe.cuh @@ -0,0 +1,260 @@ +#pragma once + +#include + +#include +#include + +namespace deep_gemm::layout { + +static constexpr int kNumCandidateBlockMs = 7; +static constexpr int kCandidateBlockM[kNumCandidateBlockMs] = {8, 16, 32, 64, 96, 128, 192}; +static constexpr int kMaxCandidateBlockM = 192; +static constexpr int kMinCandidateBlockM = 8; +static constexpr int kLCMCandidateBlockM = 384; + +// Pool capacity for shared expert token pool: worst-case total tokens + per-expert BLOCK_M alignment padding, among all possible BLOCK_M +template +CUTLASS_HOST_DEVICE constexpr T get_num_max_pool_tokens(T num_ranks, T num_max_tokens_per_rank, T num_topk, + T num_experts_per_rank) { + const auto num_max_recv_tokens = num_ranks * num_max_tokens_per_rank; + const auto num_max_experts_per_token = math::constexpr_min(num_topk, num_experts_per_rank); + return math::constexpr_align( + num_max_recv_tokens * num_max_experts_per_token + num_experts_per_rank * (static_cast(kMaxCandidateBlockM) - 1), + static_cast(kLCMCandidateBlockM)); +} + +// SF pool capacity: all experts share a contiguous SF region, sized by pool blocks × SF_BLOCK_M +template +CUTLASS_HOST_DEVICE constexpr T get_num_padded_sf_pool_tokens(T num_max_pool_tokens, T block_m) { + return (num_max_pool_tokens / block_m) * math::constexpr_align(block_m, static_cast(128)); +} + +// Per-token source metadata for combine write-back +struct TokenSrcMetadata { + uint32_t rank_idx; + uint32_t token_idx; + uint32_t topk_idx; +}; + +struct Workspace { + void* base; + uint32_t num_ranks, num_experts; + uint32_t num_experts_per_rank; + uint32_t num_max_tokens_per_rank; + uint32_t num_max_recv_tokens_per_expert; + + // Pool capacity: all local experts share a contiguous token pool + uint32_t num_max_pool_tokens; + uint32_t num_max_pool_blocks; + + // For both grid barrier and NVLink barrier + static constexpr uint64_t kNumBarrierSignalBytes = 32; + + CUTLASS_HOST_DEVICE + Workspace(void* base, + const uint32_t& num_ranks, + const uint32_t& num_experts, + const uint32_t& num_max_tokens_per_rank, + const uint32_t& num_topk): + base(base), + num_ranks(num_ranks), num_experts(num_experts), + num_max_tokens_per_rank(num_max_tokens_per_rank) { + num_experts_per_rank = num_experts / num_ranks; + num_max_recv_tokens_per_expert = num_ranks * num_max_tokens_per_rank; + num_max_pool_tokens = get_num_max_pool_tokens(num_ranks, num_max_tokens_per_rank, num_topk, num_experts_per_rank); + num_max_pool_blocks = num_max_pool_tokens / kMinCandidateBlockM; + } + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes() const { + uint64_t num_bytes = 0; + + // Barrier + num_bytes += kNumBarrierSignalBytes; + + // Expert send/recv count + num_bytes += num_experts * sizeof(uint64_t) * 2; + + // Expert recv count sum + num_bytes += num_experts_per_rank * sizeof(uint64_t); + + // L1 arrival count (padded to even entry count for `uint64_t` alignment of L2 mask) + num_bytes += math::align(num_max_pool_blocks, 2u) * sizeof(uint32_t); + + // L2 block arrival mask + num_bytes += num_max_pool_blocks * sizeof(uint64_t); + + // Dispatch pulling source token-topk + num_bytes += num_experts_per_rank * num_ranks * num_max_recv_tokens_per_expert * sizeof(int); + + // Combine push source indices + num_bytes += num_max_pool_tokens * sizeof(TokenSrcMetadata); + + // Align to TMA descriptor requirements + num_bytes = math::align(num_bytes, 16); + return num_bytes; + } + + CUTLASS_HOST_DEVICE + void* get_end_ptr() const { + return math::advance_ptr(base, get_num_bytes()); + } + + // Grid sync counters: `kNumBarrierSignalBytes` layout + // [ 0..15]: 4 x `uint32_t` grid sync counters + // [16..20]: `uint32_t` NVLink barrier counter + // [20..27]: 2 x `int` NVLink barrier signals (phase 0 and 1) + static constexpr uint32_t kNumMaxGridSyncCounters = 4; + + template + CUTLASS_DEVICE + uint32_t* get_grid_sync_count_ptr() const { + DG_STATIC_ASSERT(kIndex < kNumMaxGridSyncCounters, "Grid sync index out of bounds"); + return static_cast(base) + kIndex; + } + + CUTLASS_DEVICE + uint32_t* get_nvl_barrier_counter_ptr() const { + return static_cast(base) + kNumMaxGridSyncCounters; + } + + CUTLASS_DEVICE + int* get_nvl_barrier_signal_ptr(const uint32_t& phase) const { + // NOTES: the signal is signed, as we may minus + return math::advance_ptr(base, (kNumMaxGridSyncCounters + 1) * sizeof(uint32_t) + phase * sizeof(int)); + } + + CUTLASS_DEVICE + uint64_t* get_expert_send_count_ptr(const uint32_t& expert_idx = 0) const { + return math::advance_ptr(base, kNumBarrierSignalBytes) + expert_idx; + } + + CUTLASS_DEVICE + uint64_t* get_expert_recv_count_ptr( + const uint32_t& rank_idx = 0, const uint32_t& expert_idx = 0) const { + return get_expert_send_count_ptr(num_experts) + rank_idx * num_experts_per_rank + expert_idx; + } + + CUTLASS_DEVICE + uint64_t* get_expert_recv_count_sum_ptr(const uint32_t& expert_idx = 0) const { + return get_expert_send_count_ptr(num_experts * 2) + expert_idx; + } + + CUTLASS_DEVICE + uint32_t* get_l1_arrival_count_ptr(const uint32_t& pool_block_idx = 0) const { + const auto base = get_expert_recv_count_sum_ptr(num_experts_per_rank); + return reinterpret_cast(base) + pool_block_idx; + } + + CUTLASS_DEVICE + uint64_t* get_l2_arrival_mask_ptr(const uint32_t& pool_block_idx = 0) const { + // Pad L1 entry count to even so that the `l2_arrival_mask` is 8-byte aligned + const auto base = get_l1_arrival_count_ptr(math::align(num_max_pool_blocks, 2u)); + return reinterpret_cast(base) + pool_block_idx; + } + + // For dispatch pulling + CUTLASS_DEVICE + uint32_t* get_src_token_topk_idx_ptr( + const uint32_t& expert_idx = 0, const uint32_t& rank_idx = 0, const uint32_t& token_idx = 0) const { + const auto base = get_l2_arrival_mask_ptr(num_max_pool_blocks); + return reinterpret_cast(base) + + expert_idx * (num_ranks * num_max_recv_tokens_per_expert) + + rank_idx * num_max_recv_tokens_per_expert + token_idx; + } + + // For combine usages + CUTLASS_DEVICE + TokenSrcMetadata* get_token_src_metadata_ptr(const uint32_t& pool_token_idx = 0) const { + const auto base = reinterpret_cast(get_src_token_topk_idx_ptr(num_experts_per_rank)); + return base + pool_token_idx; + } +}; + +struct Data { + uint32_t num_bytes; + bool require_tma_alignment; + void* base; + + CUTLASS_HOST_DEVICE + constexpr explicit Data( + const uint32_t& num_bytes, + const bool& require_tma_alignment = true, + void* base = nullptr) : + num_bytes(num_bytes), require_tma_alignment(require_tma_alignment), base(base) { + DG_UNIFIED_ASSERT(num_bytes % 16 == 0 or not require_tma_alignment); + } + + template + CUTLASS_HOST_DEVICE constexpr dtype_t get_num_bytes() const { + return static_cast(num_bytes); + } + + template + CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { + return static_cast(base); + } + + CUTLASS_HOST_DEVICE void set_base_ptr(void* ptr) { + base = ptr; + } +}; + +struct Buffer { + Data data_layout; + uint32_t num_ranks; + uint32_t num_max_tokens_per_rank; + + void* base; + + CUTLASS_HOST_DEVICE + Buffer(const Data& data_layout, + const uint32_t& num_ranks, + const uint32_t& max_num_tokens_per_rank, + void* base = nullptr) : + data_layout(data_layout), + num_ranks(num_ranks), num_max_tokens_per_rank(max_num_tokens_per_rank), + base(base) {} + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes_per_rank() const { + return num_max_tokens_per_rank * data_layout.get_num_bytes(); + } + + CUTLASS_HOST_DEVICE + uint64_t get_num_bytes() const { + return get_num_bytes_per_rank() * num_ranks; + } + + template + CUTLASS_HOST_DEVICE dtype_t* get_base_ptr() const { + return static_cast(base); + } + + CUTLASS_HOST_DEVICE + void* get_end_ptr() const { + return math::advance_ptr(base, get_num_bytes()); + } + + CUTLASS_HOST_DEVICE + Buffer get_rank_buffer(const uint32_t& rank_idx) const { + return { + data_layout, + 1, num_max_tokens_per_rank, + math::advance_ptr(base, get_num_bytes_per_rank() * rank_idx) + }; + } + + CUTLASS_HOST_DEVICE + Data get_data_buffer(const uint32_t& token_idx, const bool& global = false) const { + DG_DEVICE_ASSERT(num_ranks == 1 or global); + return Data( + data_layout.num_bytes, + data_layout.require_tma_alignment, + math::advance_ptr(base, data_layout.get_num_bytes() * token_idx) + ); + } +}; + +} // namespace deep_gemm::layout diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh new file mode 100644 index 00000000..7f11aabc --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/layout/sym_buffer.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include + +namespace deep_gemm::layout { + +constexpr static uint32_t kNumMaxRanks = 72; + +template +struct SymBuffer { + int64_t base; + int64_t offsets[kNumMaxRanks]; + uint32_t rank_idx; + + DG_STATIC_ASSERT(kNumRanks <= kNumMaxRanks, "Too many ranks"); + + SymBuffer() = default; + + template + explicit SymBuffer(const Container& c, const uint32_t& rank_idx): rank_idx(rank_idx) { + const auto size = static_cast(c.size()); + base = c[rank_idx]; + for (uint32_t i = 0; i < kNumMaxRanks; ++ i) + offsets[i] = i < size ? (c[i] - base) : 0; + } + +#if defined(__CUDA_ARCH__) or defined(__CLION_IDE__) + template + CUTLASS_DEVICE ptr_t get_base_ptr() const { + return reinterpret_cast(base); + } + + template + CUTLASS_DEVICE ptr_t map(const ptr_t& ptr, const uint32_t& dst_rank_idx) const { + int64_t mapped_ptr = offsets[dst_rank_idx] + reinterpret_cast(ptr); + return *reinterpret_cast(&mapped_ptr); + } +#endif +}; + +} // namespace deep_gemm::layout diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/mma/sm100.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/mma/sm100.cuh new file mode 100644 index 00000000..0c554f4c --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/mma/sm100.cuh @@ -0,0 +1,151 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace deep_gemm::mma::sm100 { + +/// Shared memory descriptor +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + const uint32_t& stride_byte_offset, const uint32_t& leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +CUTLASS_DEVICE +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +CUTLASS_DEVICE +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 32 : 16; +} + +/// UMMA descriptors +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : tma::get_inner_block_atom_size(); +} + +template +CUTLASS_DEVICE +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto layout_type = to_umma_layout_type(); + const auto num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = tma::get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +CUTLASS_DEVICE uint64_t make_runtime_instr_desc_with_sf_id( + cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptorBlockScaled& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +CUTLASS_DEVICE void update_instr_desc_with_umma_n( + cute::UMMA::InstrDescriptor& desc, const uint32_t& umma_n) { + desc.n_dim_ = umma_n >> 3; +} + +} // namespace deep_gemm::mma::sm100 diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/mma/sm90.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/mma/sm90.cuh new file mode 100644 index 00000000..2c061940 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/mma/sm90.cuh @@ -0,0 +1,293 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace deep_gemm::mma::sm90 { + +/// MMA +template +struct FP8MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + template + CUTLASS_DEVICE static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + template + CUTLASS_DEVICE static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + CUTLASS_DEVICE static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +/// Shared memory descriptor +template +CUTLASS_DEVICE cute::GmmaDescriptor +make_smem_desc(PointerType smem_ptr, const int& layout_type, + const uint32_t& leading_byte_offset = 0, + const uint32_t& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +CUTLASS_DEVICE +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +CUTLASS_DEVICE +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +CUTLASS_DEVICE +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + math::swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +} // namespace deep_gemm::mma::sm90 diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/ld_st.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/ld_st.cuh new file mode 100644 index 00000000..c3e03bec --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/ld_st.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Compatibility: 256 bits LD/ST instructions +#if defined(CUDART_VERSION) and CUDART_VERSION >= 13000 +using longlong4_t = longlong4_32a; +#define make_longlong4_t make_longlong4_32a +#else +struct alignas(32) longlong4_t { long long x, y, z, w; }; +CUTLASS_HOST_DEVICE longlong4_t make_longlong4_t( + const long long& x, const long long& y, const long long& z, const long long& w) { + return {x, y, z, w}; +} +#endif + +/// LD/ST matrix +// TODO: remove `struct` +struct SM90_U32x2_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + CUTLASS_DEVICE static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +template +struct SM90_U32x2_STSM_N { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +template +struct SM90_U32x4_STSM_T { + CUTLASS_DEVICE static void + copy(dtype_t src_0, dtype_t src_1, dtype_t src_2, dtype_t src_3, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[4] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1), + *reinterpret_cast(&src_2), *reinterpret_cast(&src_3)}; + asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16.trans [%0], {%1, %2, %3, %4};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), + "r"(src[0]), "r"(src[1]), "r"(src[2]), "r"(src[3])); + } +}; + +template +struct SM100_U8x4_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src = *reinterpret_cast(&src_0); + asm volatile("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src)); + } +}; + +template +struct SM100_U8x8_STSM_T { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + DG_STATIC_ASSERT(sizeof(dtype_t) == sizeof(uint32_t), "Invalid dtype"); + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +/// Shared memory +CUTLASS_DEVICE uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +CUTLASS_DEVICE void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +CUTLASS_DEVICE void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +CUTLASS_DEVICE void st_shared(const uint32_t* ptr, uint32_t val) { + asm volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +CUTLASS_DEVICE void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +CUTLASS_DEVICE void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +CUTLASS_DEVICE void st_shared_bulk(void* smem_ptr, const uint32_t& num_bytes) { + // `size` must be 64-bit before PTX ISA 9.0 + asm volatile("st.bulk.weak.shared::cta [%0], %1, 0;" :: + "l"(__cvta_generic_to_shared(smem_ptr)), "l"(static_cast(num_bytes))); +} + +/// Global memory +CUTLASS_DEVICE uint64_t ld_volatile(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.volatile.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.gpu.global.b32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_sys(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.sys.global.b64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE void st_relaxed_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("st.L1::no_allocate.relaxed.sys.global.u64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +/// Atomics +CUTLASS_DEVICE uint64_t atomic_add(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint64_t atomic_add_sys(const uint64_t* ptr, const uint64_t& value) { + uint64_t ret; + asm volatile("atom.sys.global.add.u64 %0, [%1], %2;" : "=l"(ret) : "l"(ptr), "l"(value)); + return ret; +} + +CUTLASS_DEVICE uint32_t atomic_add_rel(const uint32_t* ptr, const uint32_t& value) { + uint32_t ret; + asm volatile("atom.release.gpu.global.add.u32 %0, [%1], %2;" : "=r"(ret) : "l"(ptr), "r"(value)); + return ret; +} + +CUTLASS_DEVICE void red_add(const int* ptr, const int& value) { + asm volatile("red.gpu.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_or_rel_sys(const uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.sys.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_or_rel_gpu(uint64_t* ptr, const uint64_t& value) { + asm volatile("red.release.gpu.global.or.b64 [%0], %1;" :: "l"(ptr), "l"(value)); +} + +CUTLASS_DEVICE void red_add_rel(const uint32_t* ptr, const uint32_t& value) { + asm volatile("red.release.gpu.global.add.u32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE void red_add_rel_sys(const int* ptr, const int& value) { + asm volatile("red.release.sys.global.add.s32 [%0], %1;" :: "l"(ptr), "r"(value)); +} + +CUTLASS_DEVICE int ld_acq_sys(const int* ptr) { + int ret; + asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint32_t ld_acq_sys(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" : "=r"(ret) : "l"(ptr)); + return ret; +} + +CUTLASS_DEVICE uint64_t ld_acq_gpu(const uint64_t* ptr) { + uint64_t ret; + asm volatile("ld.acquire.gpu.global.u64 %0, [%1];" : "=l"(ret) : "l"(ptr)); + return ret; +} + +/// Predicated loads +CUTLASS_DEVICE longlong4_t ld_gez_pred(const longlong4_t* ptr, const int& pred) { + longlong4_t ret = make_longlong4_t(0, 0, 0, 0); + asm volatile( + "{\n\t" + " .reg .pred p;\n\t" + " setp.ge.s32 p, %5, 0;\n\t" + " @p ld.global.L2::256B.v4.s64 {%0, %1, %2, %3}, [%4];\n\t" + "}" + : "+l"(ret.x), "+l"(ret.y), "+l"(ret.z), "+l"(ret.w) + : "l"(ptr), "r"(pred) + : "memory"); + return ret; +} + +/// Prefetch +CUTLASS_DEVICE void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +} // namespace deep_gemm::ptx diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh new file mode 100644 index 00000000..528b3dd1 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/tcgen05.cuh @@ -0,0 +1,168 @@ +#pragma once + +namespace deep_gemm::ptx { + +/// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F8F6F4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F8F6F4_2x1SM_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF4_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + CUTLASS_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +/// Tensor memory operations +CUTLASS_DEVICE void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +CUTLASS_DEVICE void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +} // namespace deep_gemm::ptx diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/tma.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/tma.cuh new file mode 100644 index 00000000..1530a3ed --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/tma.cuh @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +namespace deep_gemm::ptx { + +// Tensor-map instructions +CUTLASS_DEVICE void tensor_map_release_gpu() { + asm volatile ("fence.proxy.tensormap::generic.release.gpu;" ::: "memory"); +} + +CUTLASS_DEVICE void tensor_map_acquire_gpu(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +CUTLASS_DEVICE void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +CUTLASS_DEVICE void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +/// TMA instructions +CUTLASS_DEVICE void mbarrier_arrive( + cutlass::arch::ClusterTransactionBarrier* ptr) { + asm volatile("mbarrier.arrive.shared::cta.b64 _, [%0]; \n\t" :: + "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_arrive_and_set_tx( + cutlass::arch::ClusterTransactionBarrier* ptr, const uint32_t& num_bytes) { + asm volatile("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" :: + "r"(num_bytes), "r"(static_cast(__cvta_generic_to_shared(ptr)))); +} + +CUTLASS_DEVICE void mbarrier_wait_and_flip_phase( + cutlass::arch::ClusterTransactionBarrier* ptr, uint32_t& phase) { + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" :: + "r"(static_cast(__cvta_generic_to_shared(ptr))), + "r"(phase), "r"(0x989680)); + phase ^= 1; +} + +CUTLASS_DEVICE void tma_load_1d( + const void* dst_ptr, const void* src_ptr, + cutlass::arch::ClusterTransactionBarrier* mbarrier_ptr, + const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_FIRST) { + // NOTES: normally, the loaded part will be evicted soon + asm volatile( + "cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint [%0], [%1], %2, [%3], %4;\n" :: + "r"(static_cast(__cvta_generic_to_shared(dst_ptr))), + "l"(src_ptr), + "r"(num_bytes), + "r"(static_cast(__cvta_generic_to_shared(mbarrier_ptr))), + "l"(hint) + : "memory"); +} + +CUTLASS_DEVICE void tma_store_1d( + const void* dst_ptr, const void* src_ptr, const uint32_t& num_bytes, + const cute::TMA::CacheHintSm90& hint = cute::TMA::CacheHintSm90::EVICT_NORMAL) { + // NOTES: normally, the stored part will be used soon + asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::cache_hint [%0], [%1], %2, %3;\n" :: + "l"(dst_ptr), + "r"(static_cast(__cvta_generic_to_shared(src_ptr))), + "r"(num_bytes), + "l"(hint) + : "memory"); +} + +template +__forceinline__ __device__ void tma_store_wait() { + // NOTES: this function does not have `.read` + asm volatile("cp.async.bulk.wait_group %0;" ::"n"(kNumRemainingWaits) : "memory"); +} + +CUTLASS_DEVICE +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier& mbarrier, + void* smem_ptr, const uint32_t& col_idx, const int4& row_idxs, const uint64_t& cache_hint) { + const auto smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + const auto mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +} // namespace deep_gemm::ptx diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/utils.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/utils.cuh new file mode 100644 index 00000000..5c27166b --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/utils.cuh @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm::ptx { + +CUTLASS_DEVICE uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +CUTLASS_DEVICE uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %%laneid;" : "=r"(lane_id)); + return lane_id; +} + +CUTLASS_DEVICE void sync_aligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("bar.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +CUTLASS_DEVICE void sync_unaligned(const uint32_t& num_threads, const uint32_t& barrier_idx) { + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_idx), "r"(num_threads)); +} + +template +CUTLASS_DEVICE dtype_t exchange(dtype_t ptr, const uint32_t& src_lane_idx) { + DG_STATIC_ASSERT(sizeof(dtype_t) % sizeof(uint32_t) == 0, ""); + const auto send_int_values = reinterpret_cast(&ptr); + dtype_t recv_dtype; + auto recv_int_values = reinterpret_cast(&recv_dtype); + #pragma unroll + for (uint32_t i = 0; i < sizeof(dtype_t) / sizeof(uint32_t); ++ i) + recv_int_values[i] = __shfl_sync(0xffffffff, send_int_values[i], static_cast(src_lane_idx)); + return recv_dtype; +} + +CUTLASS_DEVICE void accumulate(float2& a, nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000) + // Use `add.rn.f32.bf16` instruction to perform fused (cast + add) operation on SM100 + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.x) : "h"(*reinterpret_cast(&b.x))); + asm("add.rn.f32.bf16 %0, %1, %0;\n" : "+f"(a.y) : "h"(*reinterpret_cast(&b.y))); +#else + const auto [x, y] = __bfloat1622float2(b); + a.x += x, a.y += y; +#endif +} + +} // namespace deep_gemm::ptx diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/wgmma.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/wgmma.cuh new file mode 100644 index 00000000..8912a157 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/ptx/wgmma.cuh @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace deep_gemm::ptx { + +CUTLASS_DEVICE void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +CUTLASS_DEVICE void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +CUTLASS_DEVICE void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +} // namespace deep_gemm::ptx diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/gemm.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/gemm.cuh new file mode 100644 index 00000000..5cd50c66 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/gemm.cuh @@ -0,0 +1,300 @@ +#pragma once + +#include +#include + +namespace deep_gemm::sched { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto candidate: {8u, 16u}) { + const auto usage = kIsMulticastOnA ? + candidate * BLOCK_N + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + math::constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for contiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + CUTLASS_DEVICE void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = grouped_layout[group_idx]; + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + CUTLASS_DEVICE explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, + const uint32_t& shape_k, int* grouped_layout = nullptr) { + num_m_blocks = math::ceil_div(shape_m, BLOCK_M); + num_n_blocks = math::ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = grouped_layout[0]; + num_m_blocks = math::ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + CUTLASS_DEVICE void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto secondary_num_blocks = kIsMulticastOnA ? num_m_blocks : num_n_blocks; + const auto num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + CUTLASS_DEVICE uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, grouped_layout[m_block_idx * BLOCK_M]) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + // For swap A/B and psum layout only + CUTLASS_DEVICE uint32_t get_aligned_effective_m_in_block(const uint32_t& m_block_idx) const { + constexpr uint32_t UMMA_STEP_N = 16; + DG_STATIC_ASSERT(BLOCK_M % UMMA_STEP_N == 0, "Invalid alignment"); + if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) + return math::align(m_block_idx == last_psum_m / BLOCK_M + num_m_blocks - 1 ? current_psum_m - m_block_idx * BLOCK_M : BLOCK_M, UMMA_STEP_N); + return BLOCK_M; + } + + CUTLASS_DEVICE bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = math::ceil_div(static_cast(grouped_layout[current_group_idx]), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = math::align(current_psum_m, BLOCK_M); + current_psum_m = grouped_layout[current_group_idx]; + current_m_block_cumsum += num_m_blocks; + num_m_blocks = math::ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with block M + m_block_idx += last_psum_m / BLOCK_M; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += math::ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + CUTLASS_DEVICE bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if (num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto group_idx = grouped_layout[m_block_idx * BLOCK_M]; + const auto peer_group_idx = grouped_layout[(m_block_idx ^ 1) * BLOCK_M]; + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + CUTLASS_DEVICE bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return grouped_layout[m_offset + m_block_idx * BLOCK_M] >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < grouped_layout[current_group_idx]; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + return m_offset + m_block_idx * BLOCK_M < current_psum_m; + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm::sched diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh new file mode 100644 index 00000000..cdbecccd --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/mega_moe.cuh @@ -0,0 +1,221 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace deep_gemm::sched { + +// Computation phase for the current block +enum class BlockPhase { + None = 0, + Linear1 = 1, + Linear2 = 2 +}; + +template +struct MegaMoEScheduler { + DG_STATIC_ASSERT(L1_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_N % BLOCK_N == 0, "Invalid shape"); + DG_STATIC_ASSERT(L1_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(L2_SHAPE_K % BLOCK_K == 0, "Invalid shape"); + DG_STATIC_ASSERT(kNumExpertsPerRank % kNumExpertsPerWave == 0, "Invalid wave config"); + + // NOTES: N block counts must be even so that 2 adjacent CTAs in a cluster + // always land on the same m_block_idx with n_block_idx differing by 1 + DG_STATIC_ASSERT(kNumSMs % 2 == 0, "Number of SMs must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL1BlockNs % 2 == 0, "L1 N block count must be even for 2-CTA cluster"); + DG_STATIC_ASSERT(kNumL2BlockNs % 2 == 0, "L2 N block count must be even for 2-CTA cluster"); + + // Arrival counts + const layout::Workspace& workspace; + + // Scheduler state + BlockPhase next_phase = BlockPhase::Linear1; + + // Current expert and block indices + uint32_t current_local_expert_idx = 0; + uint32_t current_num_tokens = 0; + uint32_t current_pool_block_offset = 0; + uint32_t block_idx = 0; + uint32_t m_block_idx = 0; + uint32_t n_block_idx = 0; + + // Pre-cached per-expert token counts (filled during `for_each_block` init) + // Layout: `stored_num_tokens_per_expert[i]` holds expert (i * 32 + lane_idx)'s count + uint32_t stored_num_tokens_per_expert[kNumExpertsPerLane] = {}; + + CUTLASS_DEVICE explicit MegaMoEScheduler(const layout::Workspace& workspace): workspace(workspace) { + block_idx = blockIdx.x; + } + + CUTLASS_DEVICE uint32_t get_wave_expert_end_idx() const { + return math::align(current_local_expert_idx + 1, kNumExpertsPerWave); + } + + CUTLASS_DEVICE uint32_t get_num_tokens(const uint32_t& expert_idx) const { + uint32_t valid_value; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + valid_value = (expert_idx == i * 32 + ptx::get_lane_idx()) ? + stored_num_tokens_per_expert[i] : valid_value; + } + return ptx::exchange(valid_value, expert_idx % 32); + } + + // Get pool block offset for a given expert index from a per-lane token count array + CUTLASS_DEVICE uint32_t get_pool_block_offset(const uint32_t& expert_idx) { + uint32_t num_blocks = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + if (i * 32 + ptx::get_lane_idx() < expert_idx) + num_blocks += math::ceil_div(stored_num_tokens_per_expert[i], BLOCK_M); + } + return __reduce_add_sync(0xffffffff, num_blocks); + } + + CUTLASS_DEVICE void advance_expert_idx() { + current_pool_block_offset += get_current_num_m_blocks(); + current_local_expert_idx += 1; + current_num_tokens = get_num_tokens(current_local_expert_idx); + } + + CUTLASS_DEVICE void set_expert_idx(const uint32_t& expert_idx) { + current_local_expert_idx = expert_idx; + current_num_tokens = get_num_tokens(expert_idx); + current_pool_block_offset = get_pool_block_offset(expert_idx); + } + + CUTLASS_DEVICE uint32_t get_current_pool_block_offset() const { + return current_pool_block_offset; + } + + CUTLASS_DEVICE uint32_t get_current_num_m_blocks() const { + return math::ceil_div(current_num_tokens, BLOCK_M); + } + + template + CUTLASS_DEVICE uint32_t get_valid_m() const { + const auto m = cute::min(current_num_tokens - m_block_idx * BLOCK_M, BLOCK_M); + return kDoUMMAAligned ? math::align(m, 16u) : m; + } + + CUTLASS_DEVICE bool fetch_next_l1_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + m_block_idx = block_idx / kNumL1BlockNs; + if (m_block_idx < num_m_blocks) + return true; + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL1BlockNs; + advance_expert_idx(); + } + return false; + } + + CUTLASS_DEVICE bool fetch_next_l2_block() { + const auto wave_end_expert_idx = get_wave_expert_end_idx(); + while (current_local_expert_idx < wave_end_expert_idx) { + const auto num_m_blocks = get_current_num_m_blocks(); + if (block_idx < num_m_blocks * kNumL2BlockNs) { + m_block_idx = block_idx / kNumL2BlockNs; + return true; + } + + // Current expert is fully assigned, move to the next + block_idx -= num_m_blocks * kNumL2BlockNs; + advance_expert_idx(); + } + return false; + } + + // Core state machine: assigns the next block + CUTLASS_DEVICE cute::tuple get_next_block() { + while (true) { + if (current_local_expert_idx >= kNumExpertsPerRank) + break; + + if (next_phase == BlockPhase::Linear1) { + if (fetch_next_l1_block()) { + // Found a new L1 block + n_block_idx = block_idx - m_block_idx * kNumL1BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear1, current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // L1 for the current wave is complete, transition to L2 + next_phase = BlockPhase::Linear2; + set_expert_idx(math::align(current_local_expert_idx - 1, kNumExpertsPerWave)); + } + } else { + if (fetch_next_l2_block()) { + // Found a new L2 block + n_block_idx = block_idx - m_block_idx * kNumL2BlockNs; + // Jump to next block + block_idx += kNumSMs; + return {BlockPhase::Linear2, current_local_expert_idx, m_block_idx, n_block_idx}; + } else { + // Move to L1 of the next wave + next_phase = BlockPhase::Linear1; + } + } + } + + // All waves and experts are fully processed + return {BlockPhase::None, 0, 0, 0}; + } + + CUTLASS_DEVICE void fetch_expert_recv_count() { + // NOTES: each lane caches experts at indices (i * 32 + lane_idx) + #pragma unroll + for (uint32_t i = 0; i < kNumExpertsPerLane; ++ i) { + const auto expert_idx = i * 32 + ptx::get_lane_idx(); + uint64_t value = 0; + if (expert_idx < kNumExpertsPerRank) { + do { + value = ptx::ld_volatile(workspace.get_expert_recv_count_sum_ptr(expert_idx)); + } while (static_cast(value >> 32) != kNumSMs * kNumRanks); + } + stored_num_tokens_per_expert[i] = static_cast(value); + } + __syncwarp(); + } + + template + CUTLASS_DEVICE void for_each_block(Func&& func) { + // Wait for all expert counters to be finalized + fetch_expert_recv_count(); + + // Initialize current expert with 0 + set_expert_idx(0); + + // Iterate over all blocks + // TODO: add swizzle within expert waves for better L2 cache utilization + while (true) { + CUTE_TIE_DECL(get_next_block(), block_phase, current_local_expert_idx, m_block_idx, n_block_idx); + if (block_phase == BlockPhase::None) + break; + + func(block_phase, current_local_expert_idx, + block_phase == BlockPhase::Linear2 ? kNumL2BlockKs : kNumL1BlockKs, + m_block_idx, n_block_idx); + } + } +}; + +} // namespace deep_gemm::sched diff --git a/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh new file mode 100644 index 00000000..548bbbc6 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/include/deep_gemm/scheduler/paged_mqa_logits.cuh @@ -0,0 +1,239 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm::sched { + +template +CUTLASS_GLOBAL __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, const uint32_t* indices, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = ptx::get_lane_idx(); + + // Wait for primary kernel completion + cudaGridDependencySynchronize(); + + __shared__ uint32_t varlen_atom_token_start[kAlignedBatchSize]; + __shared__ uint32_t varlen_atom_context_len[kAlignedBatchSize]; + __shared__ uint32_t varlen_num_atoms_shared; + uint32_t num_items; + + if constexpr (kIsVarlen) { + if (lane_idx == 0) { + uint32_t t = 0, atom_count = 0; + while (t < batch_size) { + varlen_atom_token_start[atom_count] = t; + const bool is_paired = (t + 1 < batch_size and indices[t] == indices[t + 1]); + varlen_atom_context_len[atom_count] = is_paired ? context_lens[t + 1] : context_lens[t]; + t += is_paired ? 2 : 1; + ++ atom_count; + } + varlen_num_atoms_shared = atom_count; + } + __syncwarp(); + num_items = varlen_num_atoms_shared; + } else { + num_items = batch_size; + } + + // Compute num_segs and prefix sum + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + uint32_t context_len; + if constexpr (kIsVarlen) { + context_len = (q_idx < num_items ? varlen_atom_context_len[q_idx] : 0); + } else { + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + context_len = (q_idx < batch_size ? context_lens[lens_idx] : 0); + } + num_segs[k] = math::ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + // SM work distribution + if constexpr (kIsVarlen) { + const uint32_t total = sum; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = num_items; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t atom_idx = lo; + const uint32_t kv_split_idx = (atom_idx == 0 ? seg_starts : seg_starts - prefix_sum[atom_idx - 1]); + const uint32_t q_atom_idx = (atom_idx < num_items ? varlen_atom_token_start[atom_idx] : batch_size); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } else { + const uint32_t next_n_atom = (next_n >= 2) ? 2 : 1; + const uint32_t num_next_n_atoms = math::ceil_div(next_n, next_n_atom); + const uint32_t total = sum * num_next_n_atoms; + const uint32_t q = total / kNumSMs, r = total % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t lo = 0, hi = batch_size; + while (lo < hi) { + const uint32_t mid = (lo + hi) / 2; + const bool pred = prefix_sum[mid] * num_next_n_atoms <= seg_starts; + lo = pred ? mid + 1 : lo; + hi = pred ? hi : mid; + } + const uint32_t q_idx = lo; + const uint32_t offset_in_q = (q_idx == 0 ? seg_starts : seg_starts - prefix_sum[q_idx - 1] * num_next_n_atoms); + const uint32_t num_segs_q = (q_idx == 0 ? prefix_sum[0] : prefix_sum[q_idx] - prefix_sum[q_idx - 1]); + const uint32_t atom_idx = num_segs_q > 0 ? offset_in_q / num_segs_q : 0; + const uint32_t kv_split_idx = num_segs_q > 0 ? offset_in_q % num_segs_q : 0; + const uint32_t q_atom_idx = q_idx * num_next_n_atoms + atom_idx; + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_atom_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } + } +} + +// Conditional storage for varlen indices pointer (EBO: zero cost when unused) +template +struct IndicesStorage { + const uint32_t* indices; +}; + +template <> +struct IndicesStorage {}; + +template +struct PagedMQALogitsScheduler : IndicesStorage { + const uint32_t* context_lens; + uint32_t batch_size; + + uint32_t current_q_atom_idx, current_kv_idx; + uint32_t end_q_atom_idx, end_kv_idx; + uint32_t current_num_kv; + + CUTLASS_DEVICE static uint32_t atom_to_token_idx(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + static constexpr bool kPadOddN = (not kIsVarlen) and (kNextN % 2 == 1) and (kNextN >= 3); + static constexpr uint32_t kNextNAtom = (kIsVarlen or kNextN >= 2) ? 2 : 1; + if constexpr (kPadOddN) { + return q_atom_idx / kNumNextNAtoms * kNextN + q_atom_idx % kNumNextNAtoms * kNextNAtom; + } else { + return q_atom_idx * kNextNAtom; + } + } + } + + CUTLASS_DEVICE static uint32_t atom_to_block_table_row(const uint32_t& q_atom_idx) { + if constexpr (kIsVarlen) { + return q_atom_idx; + } else { + return q_atom_idx / kNumNextNAtoms; + } + } + + CUTLASS_DEVICE uint32_t get_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + const bool is_paired = (q_atom_idx + 1 < batch_size and + this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]); + const uint32_t ctx_len = is_paired ? context_lens[q_atom_idx + 1] : context_lens[q_atom_idx]; + return math::ceil_div(ctx_len, BLOCK_KV); + } else { + const uint32_t q_idx = q_atom_idx / kNumNextNAtoms; + const auto lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return math::ceil_div(context_lens[lens_idx], BLOCK_KV); + } + } + + CUTLASS_DEVICE explicit PagedMQALogitsScheduler(const uint32_t& sm_idx, const uint32_t& batch_size, + const uint32_t* context_lens, + const uint32_t* schedule_meta, const uint32_t* indices) { + this->context_lens = context_lens; + this->batch_size = batch_size; + if constexpr (kIsVarlen) { + this->indices = indices; + } + + const auto current_pack = reinterpret_cast(schedule_meta)[sm_idx]; + const auto end_pack = reinterpret_cast(schedule_meta)[sm_idx + 1]; + current_q_atom_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_atom_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; + + current_num_kv = get_num_kv(current_q_atom_idx); + } + + // Advance step in q_atom_idx space when moving to the next atom. + // Varlen: 1 or 2 depending on whether consecutive tokens share the same sequence. + // Non-varlen: always 1 (one atom unit). + CUTLASS_DEVICE uint32_t get_atom_advance(const uint32_t& q_atom_idx, const uint32_t& bound) const { + if constexpr (kIsVarlen) { + return (q_atom_idx + 1 < bound and this->indices[q_atom_idx] == this->indices[q_atom_idx + 1]) ? 2 : 1; + } else { + return 1; + } + } + + // Whether num_kv should be refreshed after advancing to q_atom_idx. + // Varlen: always refresh (each atom may have a different context_len). + // Non-varlen: only at atom-group boundaries (atoms within a group share context_len). + CUTLASS_DEVICE bool should_refresh_num_kv(const uint32_t& q_atom_idx) const { + if constexpr (kIsVarlen) { + return true; + } else { + return q_atom_idx % kNumNextNAtoms == 0; + } + } + + CUTLASS_DEVICE bool fetch_next_task(uint32_t &q_atom_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_atom_idx = current_q_atom_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (current_q_atom_idx == end_q_atom_idx and current_kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + current_kv_idx = 0; + current_q_atom_idx += get_atom_advance(current_q_atom_idx, end_q_atom_idx); + if (should_refresh_num_kv(current_q_atom_idx) and exist_q_atom_idx(current_q_atom_idx)) { + current_num_kv = get_num_kv(current_q_atom_idx); + } + } + return true; + } + + CUTLASS_DEVICE bool exist_q_atom_idx(const uint32_t& q_atom_idx) const { + return q_atom_idx < end_q_atom_idx or (q_atom_idx == end_q_atom_idx and 0 < end_kv_idx); + } +}; + +} // namespace deep_gemm::sched diff --git a/third_party/DeepGEMM/deep_gemm/legacy/__init__.py b/third_party/DeepGEMM/deep_gemm/legacy/__init__.py new file mode 100644 index 00000000..cce39ec7 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/legacy/__init__.py @@ -0,0 +1,5 @@ +# All kernels may be deprecated in the future (or rewrite in TileLang) +from .m_grouped_gemm import * +from .a_fused_m_grouped_gemm import * +from .a_fused_k_grouped_gemm import * +from .b_fused_k_grouped_gemm import * diff --git a/third_party/DeepGEMM/deep_gemm/legacy/a_fused_k_grouped_gemm.py b/third_party/DeepGEMM/deep_gemm/legacy/a_fused_k_grouped_gemm.py new file mode 100644 index 00000000..7b42f152 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/legacy/a_fused_k_grouped_gemm.py @@ -0,0 +1,88 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def a_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k + tl.arange(0, BLOCK_SIZE_K) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + rows[None, :] * M + + b_ptrs = b_ptr + k_range[:, None].to(tl.int64) * N + n_range[None, :] + a = tl.load(a_ptrs, mask=(rows >= 0)[None, :] & m_mask, other=0) + b = tl.load(b_ptrs, mask=n_mask, other=0) + acc = tl.dot(a, b, acc) + + # Write back + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def a_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == b.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert b.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K_, M = a.shape + K, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + a_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/third_party/DeepGEMM/deep_gemm/legacy/a_fused_m_grouped_gemm.py b/third_party/DeepGEMM/deep_gemm/legacy/a_fused_m_grouped_gemm.py new file mode 100644 index 00000000..41b35d53 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/legacy/a_fused_m_grouped_gemm.py @@ -0,0 +1,92 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def a_fused_m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, m_row_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # b block + rows = tl.load(m_row_indices_ptr + m_range).to(tl.int64) + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + k_mask = k_range < K + a_ptrs = a_ptr + rows[:, None] * K + k_range[None, :] + b_ptrs = b_ptr + batch_id * K * N + k_range[:, None] * (1 if IS_B_K_MAJOR else N) + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + a = tl.load(a_ptrs, mask=(rows >= 0)[:, None] & k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + d = acc.to(d_ptr.dtype.element_ty) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, d, mask=n_mask) + + +def a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + m_indices, m_row_indices = mappings + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous() or b.mT.is_contiguous()) and d.is_contiguous() + assert m_indices.is_contiguous() and m_row_indices.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 and d.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and m_row_indices.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and d.size(0) == m_indices.numel() and d.size(1) == r1 + assert m_indices.numel() == m_row_indices.numel() + assert m_indices.numel() % get_mk_alignment_for_contiguous_layout() == 0 + + if d.size(0) == 0: + return d + + M_, K = a.shape + B, K, N = r0, r2, r1 + M = m_indices.numel() + + grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE_M']) * triton.cdiv(N, meta['BLOCK_SIZE_N']), ) + a_fused_m_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, m_indices, m_row_indices, + M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def a_fused_m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + mappings: Tuple[torch.Tensor, torch.Tensor]): + a_fused_m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, mappings) diff --git a/third_party/DeepGEMM/deep_gemm/legacy/b_fused_k_grouped_gemm.py b/third_party/DeepGEMM/deep_gemm/legacy/b_fused_k_grouped_gemm.py new file mode 100644 index 00000000..7df8741f --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/legacy/b_fused_k_grouped_gemm.py @@ -0,0 +1,86 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_k_grouped_gemm_configs(), key=[], restore_value=['d_ptr']) +@triton.jit +def b_fused_k_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + k_indices_ptr, k_start_ptr, k_end_ptr, + M: tl.constexpr, + N: tl.constexpr, + K, + ACC: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + pid_b = (pid // (num_pid_m * num_pid_n)).to(tl.int64) + pid = pid % (num_pid_m * num_pid_n) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + k_start = tl.load(k_start_ptr + pid_b) + k_end = tl.load(k_end_ptr + pid_b) + + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_range = tl.max_contiguous(tl.multiple_of(m_range, BLOCK_SIZE_M), BLOCK_SIZE_M) + n_range = tl.max_contiguous(tl.multiple_of(n_range, BLOCK_SIZE_N), BLOCK_SIZE_N) + m_mask = (m_range < M)[:, None] + n_mask = (n_range < N)[None, :] + + if k_start >= k_end: + if not ACC: + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=m_mask & n_mask) + return + + # Compute + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(k_start, k_end, BLOCK_SIZE_K): + k_range = k.to(tl.int64) + tl.arange(0, BLOCK_SIZE_K).to(tl.int64) + rows = tl.load(k_indices_ptr + k_range).to(tl.int64) + a_ptrs = a_ptr + m_range[:, None] + k_range[None, :] * M + b_ptrs = b_ptr + rows[:, None] * N + n_range[None, :] + a = tl.load(a_ptrs, mask=m_mask, other=0.0) + b = tl.load(b_ptrs, mask=(rows >= 0)[:, None] & n_mask, other=0.0) + acc = tl.dot(a, b, acc) + + d_ptrs = d_ptr + pid_b * M * N + m_range[:, None].to(tl.int64) * N + n_range[None, :] + if ACC: + acc += tl.load(d_ptrs, mask=m_mask & n_mask) + acc = acc.to(d_ptr.dtype.element_ty) + tl.store(d_ptrs, acc, mask=m_mask & n_mask) + + +def b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + handle: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], acc: bool): + k_indices, k_start, k_end = handle + + assert a.is_contiguous() and b.is_contiguous() and d.is_contiguous() + assert k_indices.is_contiguous() and k_start.is_contiguous() and k_end.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert k_indices.dtype == torch.int32 and k_start.dtype == torch.int32 and k_end.dtype == torch.int32 + assert a.dim() == 2 and b.dim() == 2 and d.dim() == 3 + assert k_start.numel() == k_end.numel() and k_indices.size(0) == a.size(0) + assert d.size(0) == k_start.numel() and d.size(1) == a.size(1) and d.size(2) == b.size(1) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + + K, M = a.shape + K_, N = b.shape + B = k_start.numel() + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']) * B,) + b_fused_k_grouped_bf16_gemm_contiguous_tl_impl[grid](a, b, d, k_indices, k_start, k_end, M, N, K, ACC=acc) diff --git a/third_party/DeepGEMM/deep_gemm/legacy/m_grouped_gemm.py b/third_party/DeepGEMM/deep_gemm/legacy/m_grouped_gemm.py new file mode 100644 index 00000000..e685a9ab --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/legacy/m_grouped_gemm.py @@ -0,0 +1,84 @@ +import torch +import triton +import triton.language as tl +from typing import Tuple + +from .tune_options import * +from .._C import get_mk_alignment_for_contiguous_layout + + +@triton.autotune(configs=get_m_grouped_gemm_configs(), key=[]) +@triton.jit +def m_grouped_bf16_gemm_contiguous_tl_impl(a_ptr, b_ptr, d_ptr, + m_indices_ptr, + M, + N: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + IS_B_K_MAJOR: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + m_range = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + n_range = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + n_mask = (n_range < N)[None, :] + + # Empty tokens + batch_id = tl.load(m_indices_ptr + pid_m * BLOCK_SIZE_M).to(tl.int64) + if batch_id < 0: + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=d_ptr.dtype.element_ty), mask=n_mask) + return + + # Compute + a_ptrs = a_ptr + m_range[:, None].to(tl.int64) * K + tl.arange(0, BLOCK_SIZE_K)[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + b_ptrs = b_ptr + batch_id * K * N + \ + tl.arange(0, BLOCK_SIZE_K)[:, None].to(tl.int64) * (1 if IS_B_K_MAJOR else N) + \ + n_range[None, :].to(tl.int64) * (K if IS_B_K_MAJOR else 1) + for k in range(0, K, BLOCK_SIZE_K): + k_mask = (k + tl.arange(0, BLOCK_SIZE_K)) < K + a = tl.load(a_ptrs, mask=k_mask[None, :], other=0.0) + b = tl.load(b_ptrs, mask=k_mask[:, None] & n_mask, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * (1 if IS_B_K_MAJOR else N) + + # Write back + d_ptrs = d_ptr + m_range[:, None].to(tl.int64) * N + n_range[None, :] + tl.store(d_ptrs, accumulator.to(d_ptr.dtype.element_ty), mask=n_mask) + + +def m_grouped_bf16_gemm_nt_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + r0, r1, r2 = b.shape + + assert a.is_contiguous() and (b.is_contiguous or b.mT.is_contiguous()) + assert m_indices.is_contiguous() and d.is_contiguous() + assert a.dtype == torch.bfloat16 and b.dtype == torch.bfloat16 + assert m_indices.dtype == torch.int32 and d.dtype == torch.bfloat16 + assert a.dim() == 2 and b.dim() == 3 and d.dim() == 2 + assert a.size(1) == r2 and a.size(0) == d.size(0) and r1 == d.size(1) + assert m_indices.numel() == a.size(0) + assert a.size(0) % get_mk_alignment_for_contiguous_layout() == 0 + M, K = a.shape + B, N, K_ = r0, r1, r2 + + # For Triton 2.0, persistent kernel will lead to errors + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + m_grouped_bf16_gemm_contiguous_tl_impl[grid]( + a, b, d, m_indices, M, N, K, IS_B_K_MAJOR=b.is_contiguous()) + + +def m_grouped_bf16_gemm_nn_contiguous_tl(a: torch.Tensor, b: torch.Tensor, d: torch.Tensor, + m_indices: torch.Tensor): + m_grouped_bf16_gemm_nt_contiguous_tl(a, b.mT, d, m_indices) diff --git a/third_party/DeepGEMM/deep_gemm/legacy/tune_options.py b/third_party/DeepGEMM/deep_gemm/legacy/tune_options.py new file mode 100644 index 00000000..ed6a7f77 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/legacy/tune_options.py @@ -0,0 +1,28 @@ +from triton import Config +from .._C import get_mk_alignment_for_contiguous_layout + + +def get_config_smem_size(config: Config, elem_bytes: int = 2): + # NOTES: FP8 kernels will not use Triton, so by default we assume BF16 kernels + return (config.kwargs['BLOCK_SIZE_M'] + config.kwargs['BLOCK_SIZE_N']) * config.kwargs['BLOCK_SIZE_K'] * elem_bytes * config.num_stages + + +_gemm_configs = [ + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), + Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=8), + Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), + Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=4), +] + +# NOTES: we only consider A100 shared memory sizes here, as legacy kernels are only used for Ampere +_gemm_configs = list(filter(lambda x: get_config_smem_size(x) <= 166912, _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +_gemm_configs = list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) + +get_m_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_M'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) +get_k_grouped_gemm_configs = lambda: list(filter(lambda x: x.kwargs['BLOCK_SIZE_K'] <= get_mk_alignment_for_contiguous_layout(), _gemm_configs)) diff --git a/third_party/DeepGEMM/deep_gemm/mega/__init__.py b/third_party/DeepGEMM/deep_gemm/mega/__init__.py new file mode 100644 index 00000000..e624ecf2 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/mega/__init__.py @@ -0,0 +1,128 @@ +import torch +from typing import Tuple, Optional +from ..utils.math import align + +# noinspection PyBroadException +try: + # noinspection PyProtectedMember + import torch.distributed._symmetric_memory as symm_mem + import torch.distributed as dist +except Exception as exception: + print(f'Failed to load mega kernels, please check your PyTorch version: {exception}') + +from .. import _C + + +class SymmBuffer: + def __init__(self, group: dist.ProcessGroup, + # MoE arguments + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu'): + self.group = group + self.num_experts = num_experts + self.num_max_tokens_per_rank = num_max_tokens_per_rank + self.num_topk = num_topk + self.hidden = hidden + self.intermediate_hidden = intermediate_hidden + + # Allocate a symmetric buffer + num_bytes, slice_input_buffers = _C.get_symm_buffer_size_for_mega_moe( + group.size(), num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + self.buffer = symm_mem.empty(num_bytes, dtype=torch.int8, device='cuda') + self.handle = symm_mem.rendezvous(self.buffer, group=group) + self.buffer.zero_() + self.group.barrier() + torch.cuda.synchronize() + + # Create input buffer views + (self.x, self.x_sf, + self.topk_idx, self.topk_weights, + self.l1_acts, self.l1_acts_sf, + self.l2_acts, self.l2_acts_sf) = slice_input_buffers(self.buffer) + + def destroy(self): + self.handle = None + self.buffer = None + self.group = None + self.x = None + self.x_sf = None + + +def get_symm_buffer_for_mega_moe(group: dist.ProcessGroup, + num_experts: int, + num_max_tokens_per_rank: int, num_topk: int, + hidden: int, intermediate_hidden: int, + use_fp8_dispatch: bool = True, + activation: str = 'swiglu') -> SymmBuffer: + # Token count must be aligned to block sizes + num_max_tokens_per_rank = align(num_max_tokens_per_rank, _C.get_token_alignment_for_mega_moe()) + + return SymmBuffer( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden, + use_fp8_dispatch, activation + ) + + +def _interleave_l1_weights(l1_weights: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + # [gate: 0..7, up: 0..7, gate: 8..15, up: 8..15, ...] instead of [gate | up] + def interleave(t, gran: int = 8) -> torch.Tensor: + g, n, *rest = t.shape + half = n // 2 + gate = t[:, :half].reshape(g, half // gran, gran, *rest) + up = t[:, half:].reshape(g, half // gran, gran, *rest) + return torch.empty_like(t).copy_(torch.stack([gate, up], dim=2).reshape(g, n, *rest)) + + return interleave(l1_weights[0]), interleave(l1_weights[1]) + + +def _transpose_sf_for_utccp(sf: torch.Tensor) -> torch.Tensor: + num_groups, mn, packed_sf_k = sf.shape + assert sf.dtype == torch.int and mn % 128 == 0 + result = (sf.reshape(num_groups, -1, 4, 32, packed_sf_k) + .transpose(2, 3) + .reshape(num_groups, mn, packed_sf_k)) + return torch.empty_like(sf).copy_(result) + + +def transform_weights_for_mega_moe( + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor] +) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: + # L1: interleave gate/up, then transpose SF for UTCCP + l1_interleaved = _interleave_l1_weights(l1_weights) + l1_weights = (l1_interleaved[0], _transpose_sf_for_utccp(l1_interleaved[1])) + # L2: only transpose SF for UTCCP + l2_weights = (l2_weights[0], _transpose_sf_for_utccp(l2_weights[1])) + return l1_weights, l2_weights + + +def fp8_fp4_mega_moe(y: torch.Tensor, + l1_weights: Tuple[torch.Tensor, torch.Tensor], + l2_weights: Tuple[torch.Tensor, torch.Tensor], + sym_buffer: SymmBuffer, + cumulative_local_expert_recv_stats: Optional[torch.Tensor] = None, + recipe: Tuple[int, int, int] = (1, 1, 32), + activation: str = 'swiglu', + activation_clamp: Optional[float] = None, + fast_math: bool = True): + _C.fp8_fp4_mega_moe( + y, + l1_weights, l2_weights, + cumulative_local_expert_recv_stats, + sym_buffer.buffer, + sym_buffer.handle.buffer_ptrs, sym_buffer.group.rank(), + sym_buffer.num_max_tokens_per_rank, + sym_buffer.num_experts, sym_buffer.num_topk, + recipe, + activation, activation_clamp, + fast_math + ) diff --git a/third_party/DeepGEMM/deep_gemm/testing/__init__.py b/third_party/DeepGEMM/deep_gemm/testing/__init__.py new file mode 100644 index 00000000..13a9d78d --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/testing/__init__.py @@ -0,0 +1,4 @@ +from . import bench, numeric, utils +from .bench import * +from .numeric import * +from .utils import * diff --git a/third_party/DeepGEMM/deep_gemm/testing/bench.py b/third_party/DeepGEMM/deep_gemm/testing/bench.py new file mode 100644 index 00000000..552b9aa1 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/testing/bench.py @@ -0,0 +1,146 @@ +import os +import sys +import torch +from typing import Callable, Optional + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests / 1e3 + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False, + barrier: Optional[Callable] = None): + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) + profiler = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule, acc_events=True) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + if barrier is not None: + # NOTES: use a large kernel and a barrier to eliminate the unbalanced CPU launch overhead + # noinspection PyProtectedMember + torch.cuda._sleep(int(2e7)) # ~10ms + barrier() + fn() + torch.cuda.synchronize() + profiler.step() + + # Parse the profiling table + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table {prof_lines}' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num if total_num > 0 else 0) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/third_party/DeepGEMM/deep_gemm/testing/numeric.py b/third_party/DeepGEMM/deep_gemm/testing/numeric.py new file mode 100644 index 00000000..a42c4318 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/testing/numeric.py @@ -0,0 +1,21 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/third_party/DeepGEMM/deep_gemm/testing/utils.py b/third_party/DeepGEMM/deep_gemm/testing/utils.py new file mode 100644 index 00000000..2d202d41 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + +def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/third_party/DeepGEMM/deep_gemm/utils/__init__.py b/third_party/DeepGEMM/deep_gemm/utils/__init__.py new file mode 100644 index 00000000..a0dc6f78 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/utils/__init__.py @@ -0,0 +1,4 @@ +from . import math, layout +from .layout import * +from .math import * +from .dist import init_dist, uneven_all_gather diff --git a/third_party/DeepGEMM/deep_gemm/utils/dist.py b/third_party/DeepGEMM/deep_gemm/utils/dist.py new file mode 100644 index 00000000..426c3967 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/utils/dist.py @@ -0,0 +1,74 @@ +import inspect +import os +import torch +import torch.distributed as dist +from typing import Tuple + +_local_rank = None + + +def init_dist(local_rank: int, num_local_ranks: int) -> Tuple[int, int, dist.ProcessGroup]: + # NOTES: you may rewrite this function with your own cluster settings + ip = os.getenv('MASTER_ADDR', '127.0.0.1') + port = int(os.getenv('MASTER_PORT', '8361')) + num_nodes = int(os.getenv('WORLD_SIZE', 1)) + node_rank = int(os.getenv('RANK', 0)) + + # Set local rank + global _local_rank + _local_rank = local_rank + + sig = inspect.signature(dist.init_process_group) + params = { + 'backend': 'nccl', + 'init_method': f'tcp://{ip}:{port}', + 'world_size': num_nodes * num_local_ranks, + 'rank': node_rank * num_local_ranks + local_rank, + } + if 'device_id' in sig.parameters: + # noinspection PyTypeChecker + params['device_id'] = torch.device(f'cuda:{local_rank}') + dist.init_process_group(**params) + torch.set_default_device('cuda') + torch.cuda.set_device(local_rank) + + return dist.get_rank(), dist.get_world_size(), dist.new_group(list(range(num_local_ranks * num_nodes))) + + +def uneven_all_gather(tensor: torch.Tensor, dim: int = 0, group: dist.ProcessGroup = None) -> torch.Tensor: + world_size = dist.get_world_size(group) + + # Exchange sizes + local_dim_size = torch.tensor([tensor.shape[dim]], device=tensor.device, dtype=torch.long) + all_dim_sizes = [torch.zeros_like(local_dim_size) for _ in range(world_size)] + dist.all_gather(all_dim_sizes, local_dim_size, group=group) + all_dim_sizes = [s.item() for s in all_dim_sizes] + max_dim_size = max(all_dim_sizes) + + # Pad + if tensor.shape[dim] < max_dim_size: + pad_shape = list(tensor.shape) + pad_shape[dim] = max_dim_size - tensor.shape[dim] + padding = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + tensor_padded = torch.cat([tensor, padding], dim=dim) + else: + tensor_padded = tensor.contiguous() + + # All-gather + gathered = [torch.zeros_like(tensor_padded) for _ in range(world_size)] + dist.all_gather(gathered, tensor_padded, group=group) + + # Remove padding + trimmed = [ + torch.narrow(gathered[i], dim, 0, all_dim_sizes[i]) + for i in range(world_size) + ] + return torch.cat(trimmed, dim=dim) + + +def dist_print(s: str = '', once_in_node: bool = False) -> None: + global _local_rank + assert _local_rank is not None + if not once_in_node or _local_rank == 0: + print(s, flush=True) + dist.barrier() diff --git a/third_party/DeepGEMM/deep_gemm/utils/layout.py b/third_party/DeepGEMM/deep_gemm/utils/layout.py new file mode 100644 index 00000000..6512c5ab --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/utils/layout.py @@ -0,0 +1,21 @@ +try: + from .._C import ( + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor + ) +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass + +# Valid for all CUDA versions +from .._C import ( + set_mk_alignment_for_contiguous_layout, + get_mk_alignment_for_contiguous_layout, + get_theoretical_mk_alignment_for_contiguous_layout, +) + +# Some alias +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/third_party/DeepGEMM/deep_gemm/utils/math.py b/third_party/DeepGEMM/deep_gemm/utils/math.py new file mode 100644 index 00000000..f1582ed5 --- /dev/null +++ b/third_party/DeepGEMM/deep_gemm/utils/math.py @@ -0,0 +1,143 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + bits = x.abs().float().view(torch.int) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).bool().int() + return (exp.clamp(1, 254) << 23).view(torch.float) + + +def pack_ue8m0_to_int(x: torch.Tensor): + assert x.dtype == torch.float and x.size(-1) % 4 == 0 + assert (x.view(torch.int) & ((1 << 23) - 1) == 0).all() + return (x.view(torch.int) >> 23).to(torch.uint8).view(torch.int) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + padded_n = align(n, gran_k) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, padded_n // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, padded_n // gran_k).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous() + return x_fp8, pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 + m, n = x.shape + x_view = x.view(-1, gran_k, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code.view(torch.int8) + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + m, n = x.shape + assert n % 2 == 0 + assert not use_packed_ue8m0 or use_ue8m0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # int8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # int8 + return packed[:, :n // 2].contiguous(), pack_ue8m0_to_int(sf) if use_packed_ue8m0 else sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.int8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.int8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() + + +def _dequantize_from_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + fp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=x.device, dtype=torch.float) + sign, value_idx = (x & 0x08) != 0, (x & 0x07).to(torch.int) + value = fp4_values[value_idx] + return torch.where(sign & (value_idx != 0), -value, value) + + +def unpack_ue8m0_from_int(packed_sf: torch.Tensor) -> torch.Tensor: + return (packed_sf.view(torch.uint8).to(torch.int) << 23).view(torch.float) + + +def cast_back_from_fp4(packed: torch.Tensor, sf: torch.Tensor, gran_k: int = 128, + use_packed_ue8m0: bool = False) -> torch.Tensor: + m, n2 = packed.shape + n = n2 * 2 + if use_packed_ue8m0: + sf = unpack_ue8m0_from_int(sf) + unpacked = torch.zeros((m, n), dtype=torch.int8, device=packed.device) + unpacked[:, ::2] = packed & 0x0F + unpacked[:, 1::2] = (packed >> 4) & 0x0F + x_dequantized = _dequantize_from_fp4_e2m1(unpacked) + group_idx = torch.arange(n, device=packed.device) // gran_k + x_restored = x_dequantized * sf[:, group_idx] + return x_restored \ No newline at end of file diff --git a/third_party/DeepGEMM/develop.sh b/third_party/DeepGEMM/develop.sh new file mode 100755 index 00000000..e784347a --- /dev/null +++ b/third_party/DeepGEMM/develop.sh @@ -0,0 +1,25 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Link CUTLASS includes +ln -sf $script_dir/third-party/cutlass/include/cutlass deep_gemm/include +ln -sf $script_dir/third-party/cutlass/include/cute deep_gemm/include + +# Remove old dist file, build files, and build +rm -rf build dist +rm -rf *.egg-info +python setup.py build + +# Find the .so file in build directory and create symlink in current directory +so_file=$(find build -name "*.so" -type f | head -n 1) +if [ -n "$so_file" ]; then + ln -sf "../$so_file" deep_gemm/ +else + echo "Error: No SO file found in build directory" >&2 + exit 1 +fi + +# Open users' original directory +cd "$original_dir" diff --git a/third_party/DeepGEMM/install.sh b/third_party/DeepGEMM/install.sh new file mode 100755 index 00000000..5c7021c6 --- /dev/null +++ b/third_party/DeepGEMM/install.sh @@ -0,0 +1,13 @@ +# Change current directory into project root +original_dir=$(pwd) +script_dir=$(realpath "$(dirname "$0")") +cd "$script_dir" + +# Remove old dist file, build files, and install +rm -rf build dist +rm -rf *.egg-info +python setup.py bdist_wheel +pip install dist/*.whl --force-reinstall + +# Open users' original directory +cd "$original_dir" diff --git a/third_party/DeepGEMM/scripts/generate_pyi.py b/third_party/DeepGEMM/scripts/generate_pyi.py new file mode 100644 index 00000000..df7490d4 --- /dev/null +++ b/third_party/DeepGEMM/scripts/generate_pyi.py @@ -0,0 +1,890 @@ +import re +from pathlib import Path + + +def build_cpp_function_index(root_path): + func_index = {} + extensions = {'.cpp', '.cc', '.cxx', '.c', '.hpp', '.h'} + + pattern = re.compile( + r'([\w:\s*<&>,\[\]\(\)]+?)' + r'\s+' + r'([a-zA-Z_][a-zA-Z0-9_:]*)' + r'\s*\(', + ) + + for file_path in Path(root_path).rglob('*'): + if file_path.suffix.lower() not in extensions: + continue + if not file_path.is_file(): + continue + + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + except Exception as e: + print(f'Failed to read file {file_path}: {e}') + continue + + # Remove the compile directives and comments + lines = content.split('\n') + clean_lines = [line for line in lines if not line.strip().startswith(('#', '//'))] + content = '\n'.join(clean_lines) + + for match in pattern.finditer(content): + return_type_part = match.group(1).strip() + full_func_name = match.group(2).strip() + + if not return_type_part or not re.match(r'^[a-zA-Z_]', return_type_part): + continue + + first_token = return_type_part.split()[0] + if first_token in {'return', 'if', 'else', 'for', 'while', 'switch', 'case', 'throw', 'catch', 'auto'}: + continue + + # Extract base name + if '::' in full_func_name: + base_name = full_func_name.split('::')[-1] + else: + base_name = full_func_name + + if not re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', base_name): + continue + + # Find matching ')' + paren_start = match.end() - 1 + paren_count = 0 + pos = paren_start + while pos < len(content): + ch = content[pos] + if ch == '(': + paren_count += 1 + elif ch == ')': + paren_count -= 1 + if paren_count == 0: + break + elif paren_count < 0: + pos = -1 + break + pos += 1 + else: + continue + + if pos == -1: + continue + + # Check context before match: should be at statement boundary + match_start = match.start() + context_before = content[max(0, match_start - 50):match_start] + if context_before and re.search(r'[a-zA-Z0-9_]$', context_before.rstrip()): + continue + + # Check for definition or header declaration + is_header = file_path.suffix.lower() in {'.h', '.hpp', '.cuh'} + after_paren = content[pos+1:pos+500] + has_brace = '{' in after_paren + has_semicolon = ';' in after_paren.split('{')[0] + + if has_brace or (is_header and has_semicolon): + sig_start = match.start(1) + full_signature = content[sig_start:pos+1].strip() + if base_name not in func_index: + func_index[base_name] = full_signature + + return func_index + + +class BracketTracker: + """ + Tracks nesting levels of various brackets in C++ code: + - () → paren + - [] → bracket + - {} → brace + - <> → angle (treated as template brackets only at top level) + Provides is_top_level() to check if currently outside all brackets. + """ + def __init__(self): + self.paren = 0 # () + self.bracket = 0 # [] + self.brace = 0 # {} + self.angle = 0 # <> + + def update(self, char: str): + """ + Update internal counters based on the given character. + """ + if char == '(': + self.paren += 1 + elif char == ')': + self.paren -= 1 + elif char == '[': + self.bracket += 1 + elif char == ']': + self.bracket -= 1 + elif char == '{': + self.brace += 1 + elif char == '}': + self.brace -= 1 + # Angle brackets < > are only treated as template delimiters + # when not inside (), [], or {} + elif char == '<' and self._in_top_level_of_other_brackets(): + self.angle += 1 + elif char == '>' and self.angle > 0 and self._in_top_level_of_other_brackets(): + self.angle -= 1 + + def _in_top_level_of_other_brackets(self): + """ + Check if not inside parentheses, square brackets, or braces (for correct template bracket recognition). + """ + return self.paren == 0 and self.bracket == 0 and self.brace == 0 + + def is_top_level(self): + """ + Check if completely at top level (all bracket counters are zero). + """ + return (self.paren == 0 and + self.bracket == 0 and + self.brace == 0 and + self.angle == 0) + + +def extract_m_def_statements(root_path): + """ + Scan all c files under root_path and extract all m.def(...) statements. + """ + results = [] + extensions = {'.hpp', '.cpp', '.h', '.cc'} + + # Regex: match m.def( ... ), supports multi-line + pattern = re.compile(r'm\.def\s*\(') + + for file_path in Path(root_path).rglob('*'): + if file_path.suffix.lower() not in extensions: + continue + if not file_path.is_file(): + continue + + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + except Exception as e: + print(f'Failed to read file {file_path}: {e}') + continue + + m_def_list = [] + lines = content.splitlines(keepends=True) + i = 0 + while i < len(lines): + line = lines[i] + if 'm.def(' in line: + # Found a potential starting line + start_i = i + # Check if it's a comment + stripped = line.lstrip() + if stripped.startswith('//') or stripped.startswith('/*'): + i += 1 + continue + + # Try to match the complete m.def(...) call + paren_count = 0 + j = i + found_start = False + while j < len(lines): + current_line = lines[j] + for k, char in enumerate(current_line): + if char == '(': + if not found_start and re.search(r'm\.def\s*\(', current_line[:k+1]): + found_start = True + if found_start: + paren_count += 1 + elif char == ')': + if found_start: + paren_count -= 1 + if paren_count == 0: + # Found complete statement + full_stmt = ''.join(lines[i:j+1]).rstrip() + m_def_list.append(full_stmt) + i = j + break + if paren_count <= 0 and found_start: + break + j += 1 + else: + pass + i += 1 + + if m_def_list: + results.append({ + 'file': str(file_path), + 'm_def_statements': m_def_list + }) + + return results + + +def parse_m_def_statement(m_def_str): + result = { + 'python_function_name': None, + 'num_args': 0, + 'default_args': {}, + 'is_lambda': False, + } + + # Extract top-level arguments + start = m_def_str.find('m.def(') + if start == -1: + raise ValueError(f'[{m_def_str}] Could not find m.def start position') + + paren_count = 0 + content_start = start + len('m.def(') + content_end = -1 + for i in range(content_start, len(m_def_str)): + ch = m_def_str[i] + if ch == '(': + paren_count += 1 + elif ch == ')': + if paren_count == 0: + content_end = i + break + else: + paren_count -= 1 + if content_end == -1: + raise ValueError(f'[{m_def_str}] m.def parentheses not closed') + + args_content = m_def_str[content_start:content_end] + + # Split arguments using BracketTracker + args_list = [] + current = [] + tracker = BracketTracker() + + for ch in args_content: + if ch in '()[]{}<>': + tracker.update(ch) + if ch == ',' and tracker.is_top_level(): + args_list.append(''.join(current).strip()) + current = [] + else: + current.append(ch) + + if current: + args_list.append(''.join(current).strip()) + + if len(args_list) < 2: + raise ValueError(f'[{m_def_str}] m.def has insufficient arguments') + + # Extract Python function name + first = args_list[0].strip() + str_match = re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"', first) + if str_match: + result['python_function_name'] = str_match.group(1) + else: + raise ValueError(f'[{m_def_str}] m.def first argument should be a string literal') + + cpp_func_part = args_list[1].strip() + if cpp_func_part.startswith('&'): + cpp_func_part = cpp_func_part[1:].strip() + + if cpp_func_part.startswith('['): + result['is_lambda'] = True + result['cpp_function_name'] = None + else: + if '::' in cpp_func_part: + cpp_func_name = cpp_func_part.split('::')[-1] + else: + cpp_func_name = cpp_func_part + + match = re.match(r'^([a-zA-Z_][a-zA-Z0-9_]*)', cpp_func_name) + if match: + result['cpp_function_name'] = match.group(1) + else: + result['cpp_function_name'] = cpp_func_name + + # Parse py::arg arguments + py_args = args_list[2:] + result['num_args'] = len(py_args) + + for idx, arg_expr in enumerate(py_args): + expr = arg_expr.strip() + # Find top-level '=' + eq_pos = -1 + p_depth = b_depth = br_depth = angle_depth = 0 + i = 0 + while i < len(expr): + ch = expr[i] + if ch == '(': + p_depth += 1 + elif ch == ')': + p_depth -= 1 + elif ch == '[': + b_depth += 1 + elif ch == ']': + b_depth -= 1 + elif ch == '{': + br_depth += 1 + elif ch == '}': + br_depth -= 1 + elif ch == '<' and p_depth == 0 and b_depth == 0 and br_depth == 0: + angle_depth += 1 + elif ch == '>' and angle_depth > 0 and p_depth == 0 and b_depth == 0 and br_depth == 0: + angle_depth -= 1 + elif ch == '=' and all(d == 0 for d in [p_depth, b_depth, br_depth, angle_depth]): + eq_pos = i + break + i += 1 + + if eq_pos != -1: + default_val = expr[eq_pos + 1:].strip() + if not default_val: + raise ValueError(f'[{expr}] Default value is empty (arg {idx})') + result['default_args'][idx] = default_val + + return result + + +def extract_cpp_signature_from_content(cpp_func_name, content): + """ + Search for the C++ function signature of cpp_func_name in the given file content. + """ + if not cpp_func_name: + return None + + # Build regex: match function starting with cpp_func_name (after word boundary) + # Note: function name may be preceded by return type (with templates, namespaces, etc.), followed by '(' + pattern = re.compile( + r'^\s*' # leading whitespace + r'([\w:\s*<&>,\[\]\(\)]+?)' # return type (non-greedy, allows templates, pointers, etc.) + r'\s+' # at least one space + r'\b' + re.escape(cpp_func_name) + r'\b' # function name (word boundary) + r'\s*\(', # optional whitespace + start of param list + re.MULTILINE + ) + + for match in pattern.finditer(content): + # Find '(' position after function name + paren_start = match.end() - 1 + if content[paren_start] != '(': + paren_start = content.find('(', match.end(0) - 1) + if paren_start == -1: + continue + + # From '(', match to corresponding ')' + paren_count = 0 + pos = paren_start + while pos < len(content): + ch = content[pos] + if ch == '(': + paren_count += 1 + elif ch == ')': + paren_count -= 1 + if paren_count == 0: + start_sig = match.start(1) + full_signature = content[start_sig:pos+1].strip() + return full_signature + pos += 1 + + return None + + +def parse_mdef_and_attach_cpp_signatures(item, func_index): + """ + Enhance item by parsing m.def and extracting C++ function signature from global index + """ + statements_with_parsed_signatures = [] + + for stmt in item['m_def_statements']: + parsed = parse_m_def_statement(stmt,) + cpp_func_name = parsed.get('cpp_function_name') + + cpp_sig = None + if cpp_func_name and cpp_func_name in func_index: + cpp_sig = func_index[cpp_func_name] + else: + if not parsed['is_lambda']: + print(f'Warning: C++ function "{cpp_func_name}" not found in any .cpp file') + + parsed['cpp_signature'] = cpp_sig + statements_with_parsed_signatures.append({ + 'raw': stmt, + 'parsed': parsed + }) + + return { + 'm_def_statements': statements_with_parsed_signatures + } + + +def parse_cpp_signature(cpp_sig): + """ + Parse a C++ function signature and extract return type, parameter types, and names. + """ + if not cpp_sig or not cpp_sig.strip(): + return None + + # Find function name: last identifier before '(' + paren_pos = cpp_sig.find('(') + if paren_pos == -1: + return None + + before_paren = cpp_sig[:paren_pos].strip() + if not before_paren: + return None + + # Function name is the last word in before_paren (may include templates like func) + tokens = before_paren.split() + if len(tokens) < 2: + return None + + # Heuristic: function name is usually the last token (may include <>) + func_name_part = tokens[-1] + return_type = ' '.join(tokens[:-1]).strip() + + # Now extract parameter list content + param_list_str = cpp_sig[paren_pos+1:cpp_sig.rfind(')')].strip() + parameters = [] + + if param_list_str and param_list_str != 'void': # 'void' means no parameters + # Split parameters (handle commas not inside templates/brackets) + param_decls = split_cpp_parameters(param_list_str) + for decl in param_decls: + decl = decl.strip() + if not decl: + continue + # Try to split type and name from right to left + param_info = parse_parameter_declaration(decl) + if param_info: + parameters.append(param_info) + + return { + 'return_type': return_type, + 'parameters': parameters, + 'num_parameters': len(parameters) + } + + +def split_cpp_parameters(param_str: str): + """ + Split a C++ parameter list string by top-level commas, + e.g., 'int a, std::vector b' → ['int a', 'std::vector b'] + """ + if not param_str.strip() or param_str == 'void': + return [] + params = [] + current = [] + tracker = BracketTracker() + + for ch in param_str: + if ch in '()[]{}<>': + tracker.update(ch) + if ch == ',' and tracker.is_top_level(): + param = ''.join(current).strip() + if param: # Only add non-empty parameters + params.append(param) + current = [] + else: + current.append(ch) + + if current: + final_param = ''.join(current).strip() + if final_param: # Only add non-empty parameters + params.append(final_param) + return params + + +def parse_parameter_declaration(decl: str): + """ + Parse a single parameter declaration, e.g., 'const std::string& name' → {'type': 'const std::string&', 'name': 'name'} + Improved version that better handles template types. + """ + decl = decl.strip() + if not decl: + return None + + # Remove possible default value (starting from top-level '=') + tracker = BracketTracker() + eq_pos = -1 + for i, ch in enumerate(decl): + if ch in '()[]{}<>': + tracker.update(ch) + elif ch == '=' and tracker.is_top_level(): + eq_pos = i + break + + if eq_pos != -1: + decl = decl[:eq_pos].strip() + + # Now decl is 'type name' or just 'type' + # Instead of simple splitting, we'll use a more robust approach + # to find the parameter name + + # First, let's handle the case where there's no explicit parameter name + # (this sometimes happens in function declarations) + if not re.search(r'[a-zA-Z_][a-zA-Z0-9_]*$', decl): + # No parameter name found, just return the type + return { + 'type': decl, + 'name': None + } + + # Use bracket tracking to find where the type ends and name begins + tracker = BracketTracker() + name_start = -1 + + # Scan from the end to find the start of the parameter name + # We look for the first identifier that's outside all brackets + i = len(decl) - 1 + while i >= 0: + ch = decl[i] + + if ch in '()[]{}<>': + tracker.update(ch) + + # If we're at top level and find an identifier character + if tracker.is_top_level() and re.match(r'[a-zA-Z0-9_]', ch): + # Track back to find the start of this identifier + name_start = i + while name_start > 0 and re.match(r'[a-zA-Z0-9_]', decl[name_start - 1]): + name_start -= 1 + + # Check if this might be part of a type keyword (like 'int', 'bool', etc.) + potential_name = decl[name_start:i+1] + type_keywords = {'int', 'long', 'short', 'char', 'bool', 'float', 'double', + 'void', 'auto', 'const', 'static', 'volatile', 'mutable', + 'unsigned', 'signed'} + + # If it's not a type keyword and looks like a parameter name, use it + if (potential_name not in type_keywords and + re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', potential_name)): + break + + i -= 1 + + if name_start != -1 and i >= 0: + param_name = decl[name_start:i+1] + param_type = decl[:name_start].strip() + + # Clean up the type - remove trailing &, * and whitespace + param_type = param_type.rstrip('&* \t') + + return { + 'type': param_type, + 'name': param_name + } + + # Fallback: if we can't find a clear parameter name, just return the type + return { + 'type': decl, + 'name': None + } + + +def extract_cpp_signature_details(item): + """ + For each m.def entry in item, parse cpp_signature to extract return type and parameter details. + """ + statements_with_parsed_signatures = [] + for stmt_info in item['m_def_statements']: + parsed = stmt_info['parsed'] + cpp_sig = parsed.get('cpp_signature') + + cpp_params_info = None + if cpp_sig: + try: + cpp_params_info = parse_cpp_signature(cpp_sig) + except Exception as e: + print(f'Failed to parse C++ signature: {e}') + + parsed['cpp_parsed_signature'] = cpp_params_info + statements_with_parsed_signatures.append({ + 'raw': stmt_info['raw'], + 'parsed': parsed + }) + + return { + 'm_def_statements': statements_with_parsed_signatures + } + + +def cpp_type_to_python_type(cpp_type: str) -> str: + if not cpp_type: + return 'Any' + + original = cpp_type.strip() + if not original: + return 'Any' + + # Remove C++ specifiers that don't affect Python type + cleaned = re.sub(r'\b(static|inline|constexpr|thread_local|extern|mutable|const|volatile|endif)\b', '', original) + cleaned = cleaned.replace('&', '').replace('*', '').strip() + cleaned = re.sub(r'\s+', ' ', cleaned).strip() + + # Handle void + if cleaned == 'void': + return 'None' + + # Handle template types — ORDER MATTERS! Must come before internal type checks. + + # std::pair + if cleaned.startswith('std::pair<'): + inner = cleaned[10:-1].strip() # len('std::pair<') == 10 + args = split_template_args(inner) + if len(args) == 2: + t1 = cpp_type_to_python_type(args[0]) + t2 = cpp_type_to_python_type(args[1]) + return f'tuple[{t1}, {t2}]' + else: + print(f'Warning: std::pair with unexpected number of args: {cleaned}') + return 'Any' + + # std::tuple + if cleaned.startswith('std::tuple<'): + inner = cleaned[11:-1].strip() # len('std::tuple<') == 11 + args = split_template_args(inner) + py_types = [cpp_type_to_python_type(arg) for arg in args] + return f"tuple[{', '.join(py_types)}]" + + # std::vector + if cleaned.startswith('std::vector<'): + inner = cleaned[12:-1].strip() # len('std::vector<') == 12 + args = split_template_args(inner) + if len(args) == 1: + inner_py = cpp_type_to_python_type(args[0]) + return f'list[{inner_py}]' + else: + print(f'Warning: std::vector with unexpected args: {cleaned}') + return 'Any' + + # std::optional + if cleaned.startswith('std::optional<'): + inner = cleaned[14:-1].strip() # len('std::optional<') == 14 + args = split_template_args(inner) + if len(args) == 1: + inner_py = cpp_type_to_python_type(args[0]) + return f'Optional[{inner_py}]' + else: + print(f'Warning: std::optional with unexpected args: {cleaned}') + return 'Any' + + # std::string + if re.search(r'\bstd::string\b', original): + return 'str' + + # C-style strings: char*, const char*, char[], etc. + if re.search(r'\b(?:const\s+)?char\s*[\*\[]', original): + return 'str' + + # Boolean + if re.search(r'\bbool\b', cleaned): + return 'bool' + + # Integer types (including fixed-width and common aliases) + if re.search(r'\b(int|long|short|size_t|ssize_t|ptrdiff_t|' + r'int8_t|int16_t|int32_t|int64_t|' + r'uint8_t|uint16_t|uint32_t|uint64_t)\b', cleaned): + return 'int' + + # Floating-point + if re.search(r'\b(float|double|long\s+double)\b', cleaned): + return 'float' + + # torch::Tensor + if re.search(r'\btorch::Tensor\b', original): + return 'torch.Tensor' + + # Unrecognized type + print(f'Warning: Unrecognized C++ type: {original}') + return 'Any' + + +def split_template_args(template_args: str): + """ + Split template arguments, e.g., 'int, std::vector' → ['int', 'std::vector'] + """ + if not template_args.strip(): + return [] + args = [] + current = [] + tracker = BracketTracker() + + for ch in template_args: + if ch in '()[]{}<>': + tracker.update(ch) + if ch == ',' and tracker.is_top_level(): + args.append(''.join(current).strip()) + current = [] + else: + current.append(ch) + + if current: + args.append(''.join(current).strip()) + return args + + +def cpp_default_to_python_default(cpp_default: str): + """ + Convert C++ default value string to valid Python expression string. + """ + if not cpp_default: + return 'None' + + s = cpp_default.strip() + + # Handle string literals: 'bf16' → 'bf16' + # Match: starts and ends with unescaped double quotes + string_match = re.match(r'^"([^"\\]*(?:\\.[^"\\]*)*)"$', s) + if string_match: + return s + + # Handle boolean literals + if s == 'false': + return 'False' + if s == 'true': + return 'True' + + # Handle null-like values: nullptr, nullopt, NULL, etc. + if s in ('nullptr', 'NULL') or 'nullopt' in s: + return 'None' + + # Handle std::tuple({128, 128}) → (128, 128) + tuple_match = re.match(r'std::tuple\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s) + if tuple_match: + inner = tuple_match.group(1) # {128, 128} + inner_py = inner.replace('{', '(').replace('}', ')') + return inner_py + + # Handle std::make_tuple(1, 2, 3) → (1, 2, 3) + make_tuple_match = re.match(r'std::make_tuple\s*\(\s*(.*?)\s*\)', s) + if make_tuple_match: + inner = make_tuple_match.group(1) + # Ensure it's a valid tuple even with one element: add comma if needed? + # But in C++ default args, it's usually multi-element, so we assume valid. + return f'({inner})' + + # Handle std::vector({1,2,3}) → [1, 2, 3] + vector_match = re.match(r'std::vector\s*<[^>]*>\s*\(\s*({.*?})\s*\)', s) + if vector_match: + inner = vector_match.group(1) + inner_py = inner.replace('{', '[').replace('}', ']') + return inner_py + + # Handle numeric literals: integers and floats + if re.match(r'^[+-]?\d+$', s): # integer + return s + if re.match(r'^[+-]?\d*\.\d+([eE][+-]?\d+)?$', s): # float + return s + + # Fallback: unrecognized → warn and return None + print(f'Warning: Unrecognized default value: {s}') + return 'None' + + +def generate_pyi_function(item_entry): + parsed = item_entry['parsed'] + py_name = parsed['python_function_name'] + + if parsed.get('is_lambda'): + return f'def {py_name}(*args, **kwargs) -> Any: ...' + + sig_info = parsed.get('cpp_parsed_signature') + default_args = parsed.get('default_args', {}) + + if not sig_info: + return f'def {py_name}(*args, **kwargs) -> Any: ...' + + return_type = cpp_type_to_python_type(sig_info['return_type']) + params = sig_info['parameters'] + num_params = len(params) + + # Build parameter list + param_lines = [] + for i in range(num_params): + param_info = params[i] if i < len(params) else {'type': 'Any', 'name': f'arg{i}'} + param_type = cpp_type_to_python_type(param_info['type']) + param_name = param_info['name'] or f'arg{i}' + + # Replace invalid Python identifiers (e.g., keywords) + if param_name in {'def', 'class', 'from', 'import', 'None', 'True', 'False'}: + param_name = f'{param_name}_' + + # Check for default value + if i in default_args: + cpp_default = default_args[i] + py_default = cpp_default_to_python_default(cpp_default) + param_str = f' {param_name}: {param_type} = {py_default}' + else: + param_str = f' {param_name}: {param_type}' + + param_lines.append(param_str) + + if param_lines: + params_block = ',\n'.join(param_lines) + func_def = f'def {py_name}(\n{params_block}\n) -> {return_type}: ...' + else: + func_def = f'def {py_name}() -> {return_type}: ...' + + return func_def + + +def generate_pyi_file_content(enhanced_results, module_name: str = 'my_module'): + function_decls = [] + has_optional = False + has_torch = False + has_numpy = False + + for item in enhanced_results: + for stmt in item['m_def_statements']: + try: + decl = generate_pyi_function(stmt) + function_decls.append(decl) + + if 'Optional[' in decl: + has_optional = True + if 'torch.Tensor' in decl: + has_torch = True + if 'numpy.ndarray' in decl or 'py::array' in str(stmt): + has_numpy = True + except Exception as e: + func_name = stmt['parsed'].get('python_function_name', 'unknown') + function_decls.append(f'# ERROR: failed to generate stub for {func_name}: {e}') + + imports = ['from typing import Any'] + if has_optional: + imports[0] += ', Optional' + + if has_torch: + imports.append('import torch') + if has_numpy: + imports.append('import numpy') + + lines = [f'# Stubs for module: {module_name}', ''] + lines.extend(imports) + lines.append('') + lines.append('') + + for decl in function_decls: + lines.append(decl) + lines.append('') + lines.append('') + + return '\n'.join(lines) + + +def generate_pyi_file(name, root, output_dir='.'): + func_index = build_cpp_function_index(root) + results = extract_m_def_statements(root) + + cpp_results = [] + for item in results: + enhanced_item = parse_mdef_and_attach_cpp_signatures(item, func_index) + cpp_item = extract_cpp_signature_details(enhanced_item) + cpp_results.append(cpp_item) + + pyi_content = generate_pyi_file_content(cpp_results, module_name=name) + + output_path = Path(output_dir) / f'{name}.pyi' + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(pyi_content) + + print(f'.pyi file generated: {output_path}') diff --git a/third_party/DeepGEMM/scripts/quick_plot_pm.py b/third_party/DeepGEMM/scripts/quick_plot_pm.py new file mode 100644 index 00000000..3aee8b86 --- /dev/null +++ b/third_party/DeepGEMM/scripts/quick_plot_pm.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +"""Plot a curated set of NCU PM metrics from an .ncu-rep report. + +Usage: + python scripts/quick_plot_pm.py [report.ncu-rep] + +By default the script saves a PNG next to the report. +With --interactive, it opens a Qt window instead. +""" + +import argparse +import csv +import io +import subprocess +from dataclasses import dataclass + +import matplotlib +import numpy as np + + +@dataclass(frozen=True) +class MetricSpec: + name: str + metric: str + kind: str + category: str + aliases: tuple[str, ...] = () + + +@dataclass(frozen=True) +class ResolvedMetricSpec: + name: str + metric: str + kind: str + category: str + + +@dataclass(frozen=True) +class MetricSeries: + name: str + metric: str + category: str + unit: str + values: tuple[float, ...] + + +CATEGORY_ORDER = [ + "Overview", + "SM", + "L1", + "L2", + "DRAM", + "Interconnect", +] + + +KIND_SUFFIXES = { + "pct_peak": [".avg.pct_of_peak_sustained_elapsed"], + "pct": [".pct", ".avg.pct_of_peak_sustained_elapsed"], + "avg": [".avg"], + "sum": [".sum"], + "avg_per_second": [".avg.per_second"], + "sum_per_second": [".sum.per_second"], + "avg_per_cycle_active": [".avg.per_cycle_active"], + "avg_per_cycle_elapsed": [".avg.per_cycle_elapsed"], + "sum_per_cycle_elapsed": [".sum.per_cycle_elapsed"], +} + + +# Curated from scripts/ncu-metrics.txt, with a few corrections against +# `ncu --query-metrics --chip gb100`: +# - Blocks launched uses `gr__ctas_launched_realtime` +# - SM active cycles uses `sm__cycles_active` +# - L2 throughput for GCC requests uses `lts__t_sector_throughput_srcunit_gcc` +# - C2C throughput uses `ctc__throughput` +# - NVLink RX metrics use the `NVLRX` domain +CURATED_METRICS = [ + MetricSpec("Blocks Launched", "FE_B.TriageCompute.gr__ctas_launched_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("Average Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "avg_per_cycle_elapsed", "Overview"), + MetricSpec("Total Blocks Active", "TPC.TriageCompute.tpc__ctas_active_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("Average CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "avg_per_cycle_elapsed", "Overview"), + MetricSpec("Total CGAs Active", "GPC_B.TriageCompute.gpc__cgas_active_realtime", "sum_per_cycle_elapsed", "Overview"), + MetricSpec("SM Active Cycles", "TPC.TriageCompute.sm__cycles_active", "avg", "SM"), + MetricSpec("Executed IPC Active", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_active", "SM"), + MetricSpec("Executed IPC Elapsed", "TPC.TriageCompute.sm__inst_executed_realtime", "avg_per_cycle_elapsed", "SM"), + MetricSpec("SM Throughput", "TPC.TriageCompute.sm__inst_executed_realtime", "pct_peak", "SM"), + MetricSpec("SM ALU Pipe Throughput", "TPC.TriageCompute.sm__inst_executed_pipe_alu_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Pipe Throughput", "TPC.TriageCompute.sm__pipe_fma_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Heavy Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmaheavy_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM FMA Light Pipe Throughput", "TPC.TriageCompute.sm__pipe_fmalite_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM Tensor Pipe Throughput", "TPC.TriageCompute.sm__pipe_tensor_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM TMEM Pipe Throughput", "SM_A.TriageCompute.sm__mem_tensor_cycles_active_realtime", "pct_peak", "SM"), + MetricSpec("SM Uniform Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_uniform_realtime", "pct_peak", "SM"), + MetricSpec("SM XU Pipe Throughput", "SM_A.TriageCompute.sm__inst_executed_pipe_xu_realtime", "pct_peak", "SM"), + MetricSpec("L1 Throughput", "SM_A.TriageCompute.l1tex__throughput", "pct_peak", "L1"), + MetricSpec("L1 Sectors", "SM_B.TriageCompute.l1tex__t_sectors", "sum", "L1"), + MetricSpec("L1 Hit Rate", "SM_B.TriageCompute.l1tex__t_sector_hit_rate", "pct", "L1"), + MetricSpec("L1 Lookup Hit", "SM_B.TriageCompute.l1tex__t_sectors_lookup_hit", "sum", "L1"), + MetricSpec("L1 Lookup Miss", "SM_B.TriageCompute.l1tex__t_sectors_lookup_miss", "sum", "L1"), + MetricSpec("L1 Wavefronts (Data)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts", "avg", "L1"), + MetricSpec("L1 Wavefronts (LGDS)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_lgds", "avg", "L1"), + MetricSpec("L1 Wavefronts (Shared)", "SM_A.TriageCompute.l1tex__data_pipe_lsu_wavefronts_mem_shared", "avg", "L1"), + MetricSpec("L2 Throughput", "LTS.TriageCompute.lts__throughput", "pct_peak", "L2"), + MetricSpec("L2 Throughput for L1 Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_tex", "pct_peak", "L2"), + MetricSpec("L2 Throughput for GCC Requests", "LTS.TriageCompute.lts__t_sector_throughput_srcunit_gcc", "pct_peak", "L2"), + MetricSpec("L2 Throughput to DRAM", "LTS.TriageCompute.lts__t_sector_throughput_srcnode_fbp", "pct_peak", "L2"), + MetricSpec("SysL2 Throughput to Peer Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_peer", "pct_peak", "L2"), + MetricSpec("SysL2 Throughput to System Memory", "SYSLTS.TriageCompute.syslts__t_sector_throughput_aperture_sysmem", "pct_peak", "L2"), + MetricSpec("L2 Hit Rate", "LTS.TriageCompute.lts__average_t_sector_hit_rate_realtime", "pct", "L2"), + MetricSpec("L2 Hit Rate From L1", "LTS.TriageCompute.lts__average_t_sector_hit_rate_srcunit_tex_realtime", "pct", "L2"), + MetricSpec("DRAM Frequency", "FBSP.TriageCompute.dram__cycles_elapsed", "avg_per_second", "DRAM"), + MetricSpec("DRAM Throughput", "FBSP.TriageCompute.dram__throughput", "pct_peak", "DRAM"), + MetricSpec("DRAM Read Throughput", "FBSP.TriageCompute.dram__read_throughput", "pct_peak", "DRAM"), + MetricSpec("DRAM Write Throughput", "FBSP.TriageCompute.dram__write_throughput", "pct_peak", "DRAM"), + MetricSpec("C2C Throughput", "TriageCompute.ctc__throughput", "pct_peak", "Interconnect", aliases=("TriageCompute.ctx__throughput",)), + MetricSpec("NVLink Transmitted Throughput", "NVLTX.TriageCompute.nvltx__bytes", "pct_peak", "Interconnect"), + MetricSpec("NVLink Received Throughput", "NVLRX.TriageCompute.nvlrx__bytes", "pct_peak", "Interconnect"), + MetricSpec("NVLink Transmitted Bandwidth", "NVLTX.TriageCompute.nvltx__bytes", "sum_per_second", "Interconnect"), + MetricSpec("NVLink Received Bandwidth", "NVLRX.TriageCompute.nvlrx__bytes", "sum_per_second", "Interconnect"), + MetricSpec("PCIe Throughput", "PCI.TriageCompute.pcie__throughput", "pct_peak", "Interconnect"), + MetricSpec("PCIe Read Bandwidth", "PCI.TriageCompute.pcie__read_bytes", "sum_per_second", "Interconnect"), + MetricSpec("PCIe Write Bandwidth", "PCI.TriageCompute.pcie__write_bytes", "sum_per_second", "Interconnect"), +] + + +def _run_csv_command(command, timeout): + result = subprocess.run(command, capture_output=True, text=True, timeout=timeout) + if result.returncode != 0 and not result.stdout: + return None + reader = csv.reader(io.StringIO(result.stdout)) + return list(reader) + + +def _query_available_metrics(chip): + result = subprocess.run( + ["ncu", "--query-metrics", "--chip", chip], + capture_output=True, + text=True, + timeout=30, + ) + if result.returncode != 0: + raise RuntimeError(result.stderr.strip() or f"failed to query metrics for chip {chip}") + + metrics = set() + for line in result.stdout.splitlines(): + parts = line.split() + if not parts: + continue + token = parts[0] + if "__" not in token: + continue + metrics.add(token) + return metrics + + +def _metric_candidates(metric): + candidates = [metric] + marker = ".TriageCompute." + if marker in metric: + candidates.append(metric.split(marker, 1)[1]) + return candidates + + +def resolve_metric_specs(chip): + available = _query_available_metrics(chip) + resolved = [] + missing = [] + + for spec in CURATED_METRICS: + candidates = [] + for metric in (spec.metric, *spec.aliases): + candidates.extend(_metric_candidates(metric)) + + actual_metric = next((metric for metric in candidates if metric in available), None) + if actual_metric is None: + missing.append(spec) + continue + + resolved.append(ResolvedMetricSpec(spec.name, actual_metric, spec.kind, spec.category)) + + return resolved, missing + + +def _parse_metric_values(raw): + if not raw or raw == "no data": + return () + + try: + if raw.startswith("(") and raw.endswith(")"): + rest = raw[1:-1] + return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip()) + if " (" in raw: + _agg, rest = raw.split(" (", 1) + rest = rest.rstrip(")") + return tuple(float(v.strip().replace(",", "")) for v in rest.split(";") if v.strip()) + return (float(raw.replace(",", "")),) + except ValueError: + return () + + +def _probe_metric_series(report, metric_name): + rows = _run_csv_command( + [ + "ncu", + "--import", + report, + "--page", + "raw", + "--csv", + "--metrics", + metric_name, + "--print-metric-instances", + "values", + ], + timeout=60, + ) + if not rows or len(rows) < 3 or len(rows[0]) <= 11: + return None + + header, units, row = rows[0], rows[1], rows[2] + unit = units[11] if len(units) > 11 else "" + raw = row[11] if len(row) > 11 else "" + values = _parse_metric_values(raw) + return header[11], unit, values + + +def collect_metric_series(report, resolved_specs): + collected = [] + skipped = [] + + for spec in resolved_specs: + series = None + for suffix in KIND_SUFFIXES[spec.kind]: + probe = _probe_metric_series(report, f"{spec.metric}{suffix}") + if probe is None: + continue + full_metric, unit, values = probe + if len(values) > 1: + series = MetricSeries(spec.name, full_metric, spec.category, unit, values) + break + + if series is None: + skipped.append(spec) + continue + + collected.append(series) + + return collected, skipped + + +def _format_value(value): + if value == 0: + return "0" + abs_value = abs(value) + if abs_value >= 1e12: + return f"{value / 1e12:.2f} T" + if abs_value >= 1e9: + return f"{value / 1e9:.2f} G" + if abs_value >= 1e6: + return f"{value / 1e6:.2f} M" + if abs_value >= 1e3: + return f"{value / 1e3:.2f} K" + if abs_value >= 1: + return f"{value:.1f}" + return f"{value:.2f}" + + +def _format_with_unit(value, unit): + if not unit: + return _format_value(value) + return f"{_format_value(value)} {unit}" + + +def plot_pm(report, metrics, save=False): + """Plot curated PM metrics as shared-x subplots in a light theme.""" + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + + if not metrics: + print("No curated metrics had time-series data in the report.") + return + + bg_fig = "#ffffff" + bg_row = "#f6f8fb" + text_primary = "#1f2937" + text_secondary = "#6b7280" + text_header = "#111827" + grid_color = "#d7deea" + border = "#c7d0dd" + + wave_colors = { + "Overview": "#7c8aa5", + "SM": "#4f87c2", + "L1": "#2f9d8f", + "L2": "#dd8452", + "DRAM": "#c95d63", + "Interconnect": "#8c6bb1", + } + + category_rank = {category: index for index, category in enumerate(CATEGORY_ORDER)} + metrics = sorted(metrics, key=lambda item: (category_rank.get(item.category, 99), item.name)) + + row_h = 0.55 + label_w = 3.6 + plot_w = 14.0 + fig_w = label_w + plot_w + fig_h = row_h * len(metrics) + 0.6 + + fig = plt.figure(figsize=(fig_w, fig_h), facecolor=bg_fig) + gs = GridSpec( + len(metrics), + 1, + figure=fig, + left=label_w / fig_w, + right=0.97, + top=1 - 0.45 / fig_h, + bottom=0.35 / fig_h, + hspace=0.18, + ) + axes = [fig.add_subplot(gs[i, 0]) for i in range(len(metrics))] + + prev_category = None + for idx, metric in enumerate(metrics): + ax = axes[idx] + values = np.array(metric.values) + x = np.arange(len(values)) + wave_color = wave_colors.get(metric.category, "#5b9bd5") + + ax.set_facecolor(bg_row) + ax.fill_between(x, values, alpha=0.35, color=wave_color, linewidth=0) + ax.plot(x, values, linewidth=0.8, color=wave_color) + + ax.set_xlim(0, len(values) - 1) + if metric.unit == "%": + ax.set_ylim(0, 100) + else: + ymax = np.max(values) + ax.set_ylim(0, ymax * 1.15 if ymax > 0 else 1) + + ax.grid(True, axis="both", color=grid_color, linewidth=0.5, alpha=0.85) + ax.tick_params(axis="both", colors=text_secondary, labelsize=6, length=0) + + if idx < len(metrics) - 1: + ax.tick_params(axis="x", labelbottom=False) + else: + ax.set_xlabel("Sample Index", fontsize=8, color=text_secondary) + + ymin_v, ymax_v = ax.get_ylim() + ax.set_yticks([ymin_v, ymax_v]) + ax.set_yticklabels([_format_value(ymin_v), _format_value(ymax_v)], fontsize=6, color=text_secondary) + + peak = np.max(values) + ax.text( + 1.005, + 0.5, + _format_with_unit(peak, metric.unit), + transform=ax.transAxes, + fontsize=7, + color=text_secondary, + va="center", + ha="left", + family="monospace", + ) + + for spine in ax.spines.values(): + spine.set_color(border) + spine.set_linewidth(0.5) + + if metric.category != prev_category: + cat_y = ax.get_position().y1 + 0.008 + fig.text( + 0.005, + cat_y, + f" {metric.category}", + fontsize=8.5, + fontweight="bold", + color=text_header, + va="bottom", + family="sans-serif", + transform=fig.transFigure, + bbox=dict(boxstyle="square,pad=0.15", facecolor="#e9eef5", edgecolor="none"), + ) + prev_category = metric.category + + label_y = (ax.get_position().y0 + ax.get_position().y1) / 2 + fig.text( + label_w / fig_w - 0.012, + label_y, + metric.name, + fontsize=7.5, + color=text_primary, + va="center", + ha="right", + family="sans-serif", + transform=fig.transFigure, + ) + + fig.text( + 0.5, + 1 - 0.15 / fig_h, + f"PM Sampling - {report}", + fontsize=11, + fontweight="bold", + color=text_header, + ha="center", + va="top", + family="sans-serif", + transform=fig.transFigure, + ) + + if save: + out_path = report.replace(".ncu-rep", ".pm_sampling.png") + fig.savefig(out_path, dpi=150, facecolor=bg_fig, bbox_inches="tight", pad_inches=0.2) + print(f"Saved: {out_path}") + plt.close(fig) + else: + plt.show() + + +def main(): + parser = argparse.ArgumentParser(description="NCU PM Sampling plotter") + parser.add_argument("report", nargs="?", default="mega-moe-kk.3.ncu-rep", help="Path to .ncu-rep file") + parser.add_argument("--chip", default="gb100", help="Chip name used for `ncu --query-metrics`") + parser.add_argument("--interactive", action="store_true", help="Open an interactive Qt window instead of saving a PNG") + args = parser.parse_args() + + if args.interactive: + matplotlib.use("QtAgg") + else: + matplotlib.use("Agg") + + resolved_specs, missing_specs = resolve_metric_specs(args.chip) + if missing_specs: + print(f"Skipped {len(missing_specs)} curated metrics not available on {args.chip}.") + for spec in missing_specs: + print(f" missing: {spec.name} -> {spec.metric}") + + metric_series, skipped_specs = collect_metric_series(args.report, resolved_specs) + if skipped_specs: + print(f"Skipped {len(skipped_specs)} curated metrics with no time-series data in {args.report}.") + for spec in skipped_specs: + print(f" no series: {spec.name} -> {spec.metric}") + + plot_pm(args.report, metric_series, save=not args.interactive) + + +if __name__ == "__main__": + main() diff --git a/third_party/DeepGEMM/scripts/run_ncu_mega_moe.sh b/third_party/DeepGEMM/scripts/run_ncu_mega_moe.sh new file mode 100755 index 00000000..4324575c --- /dev/null +++ b/third_party/DeepGEMM/scripts/run_ncu_mega_moe.sh @@ -0,0 +1,89 @@ +#!/bin/bash + +set -e + +# parse num-processes, output_dir and separate python args +num_processes=8 +output_dir=work +python_args=() +for ((arg_idx = 1; arg_idx <= $#; ++arg_idx)); do + arg="${!arg_idx}" + case "$arg" in + --num-processes) + python_args+=("$arg") + if ((arg_idx < $#)); then + ((arg_idx++)) + num_processes="${!arg_idx}" + python_args+=("$num_processes") + fi + ;; + -h|--help) + echo "Usage: $0 [--num-processes N] [--output DIR] [python args...]" + exit 0 + ;; + --num-processes=*) + num_processes="${arg#*=}" + python_args+=("$arg") + ;; + -o|--output) + if ((arg_idx < $#)); then + ((arg_idx++)) + output_dir="${!arg_idx}" + fi + ;; + --output=*) + output_dir="${arg#*=}" + ;; + *) + python_args+=("$arg") + ;; + esac +done + +echo "Python Args: ${python_args[*]}" +echo "Num Processes: $num_processes" +echo "Output Dir: $output_dir" +mkdir -p $output_dir + +export DG_JIT_WITH_LINEINFO=1 # for source counters + +echo "Warm up JIT cache" +python tests/test_mega_moe.py --ncu-profile-only "${python_args[@]}" + +sleep 2 + +ncu_args=( + --config-file off + --force-overwrite + --kernel-name sm100_fp8_fp4_mega_moe_impl + --import-source yes + --replay-mode application + --section PmSampling + --section SourceCounters + --rule LocalMemoryUsage + --launch-skip 0 + --launch-count 1 + --lockstep-kernel-launch + --communicator tcp + --clock-control none + --pm-sampling-interval 1000 + --pm-sampling-max-passes 1 + --disable-pm-warp-sampling + --communicator-tcp-num-peers "$num_processes" + --kill yes + --app-replay-buffer memory +) + +echo "Run Job" + +for ((i = 0; i < num_processes; ++i)); do + ncu ${ncu_args[@]} -o "${output_dir%/}/mega-moe.$i" \ + python tests/test_mega_moe.py \ + --local-rank-idx=$i \ + --ncu-profile-only \ + "${python_args[@]}" & +done + +echo "Waiting" +wait +echo "Done" diff --git a/third_party/DeepGEMM/setup.py b/third_party/DeepGEMM/setup.py new file mode 100644 index 00000000..c4d74ae9 --- /dev/null +++ b/third_party/DeepGEMM/setup.py @@ -0,0 +1,214 @@ +import ast +import os +import re +import shutil +import setuptools +import subprocess +import sys +import torch +import platform +import urllib +import urllib.error +import urllib.request +from setuptools import find_packages +from setuptools.command.build_py import build_py +from packaging.version import parse +from pathlib import Path +from torch.utils.cpp_extension import CUDAExtension, CUDA_HOME +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from scripts.generate_pyi import generate_pyi_file + + +DG_SKIP_CUDA_BUILD = int(os.getenv('DG_SKIP_CUDA_BUILD', '0')) == 1 +DG_FORCE_BUILD = int(os.getenv('DG_FORCE_BUILD', '0')) == 1 +DG_USE_LOCAL_VERSION = int(os.getenv('DG_USE_LOCAL_VERSION', '1')) == 1 +DG_JIT_USE_RUNTIME_API = int(os.environ.get('DG_JIT_USE_RUNTIME_API', '0')) == 1 + +# Compiler flags +cxx_flags = ['-std=c++17', '-O3', '-fPIC', '-Wno-psabi', '-Wno-deprecated-declarations', + f'-D_GLIBCXX_USE_CXX11_ABI={int(torch.compiled_with_cxx11_abi())}'] +if DG_JIT_USE_RUNTIME_API: + cxx_flags.append('-DDG_JIT_USE_RUNTIME_API') + +# Sources +current_dir = os.path.dirname(os.path.realpath(__file__)) +sources = ['csrc/python_api.cpp'] +build_include_dirs = [ + f'{CUDA_HOME}/include', + f'{CUDA_HOME}/include/cccl', + 'deep_gemm/include', + 'third-party/cutlass/include', + 'third-party/fmt/include', +] +build_libraries = ['cudart', 'nvrtc'] +build_library_dirs = [f'{CUDA_HOME}/lib64'] +third_party_include_dirs = [ + 'third-party/cutlass/include/cute', + 'third-party/cutlass/include/cutlass', +] + +# Release +base_wheel_url = 'https://github.com/DeepSeek-AI/DeepGEMM/releases/download/{tag_name}/{wheel_name}' + + +def get_package_version(): + with open(Path(current_dir) / 'deep_gemm' / '__init__.py', 'r') as f: + version_match = re.search(r'^__version__\s*=\s*(.*)$', f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + + revision = '' + if DG_USE_LOCAL_VERSION: + # noinspection PyBroadException + try: + status_cmd = ['git', 'status', '--porcelain'] + status_output = subprocess.check_output(status_cmd).decode('ascii').strip() + if status_output: + print(f'Warning: Git working directory is not clean. Uncommitted changes:\n{status_output}') + assert False, 'Git working directory is not clean' + + cmd = ['git', 'rev-parse', '--short', 'HEAD'] + revision = '+' + subprocess.check_output(cmd).decode('ascii').rstrip() + except Exception: + revision = '+local' + return f'{public_version}{revision}' + + +def get_platform(): + if sys.platform.startswith('linux'): + return f'linux_{platform.uname().machine}' + else: + raise ValueError('Unsupported platform: {}'.format(sys.platform)) + + +def get_wheel_url(): + torch_version = parse(torch.__version__) + torch_version = f'{torch_version.major}.{torch_version.minor}' + python_version = f'cp{sys.version_info.major}{sys.version_info.minor}' + platform_name = get_platform() + deep_gemm_version = get_package_version() + cxx11_abi = int(torch._C._GLIBCXX_USE_CXX11_ABI) + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + cuda_version = parse(torch.version.cuda) + cuda_version = f'{cuda_version.major}' + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f'deep_gemm-{deep_gemm_version}+cu{cuda_version}-torch{torch_version}-cxx11abi{cxx11_abi}-{python_version}-{platform_name}.whl' + wheel_url = base_wheel_url.format(tag_name=f'v{deep_gemm_version}', wheel_name=wheel_filename) + return wheel_url, wheel_filename + + +def get_ext_modules(): + if DG_SKIP_CUDA_BUILD: + return [] + + return [CUDAExtension(name='deep_gemm._C', + sources=sources, + include_dirs=build_include_dirs, + libraries=build_libraries, + library_dirs=build_library_dirs, + extra_compile_args=cxx_flags)] + + +class CustomBuildPy(build_py): + def run(self): + # First, prepare the include directories + self.prepare_includes() + + # Second, make clusters' cache setting default into `envs.py` + self.generate_default_envs() + + # Third, generate and copy .pyi file to build root directory + self.generate_pyi_file() + + # Finally, run the regular build + build_py.run(self) + + def generate_pyi_file(self): + generate_pyi_file(name='_C', root='./csrc', output_dir='./stubs') + pyi_source = os.path.join(current_dir, 'stubs', '_C.pyi') + pyi_target = os.path.join(self.build_lib, 'deep_gemm', '_C.pyi') + + if os.path.exists(pyi_source): + print(f"Copying .pyi file from {pyi_source} to {pyi_target}") + os.makedirs(os.path.dirname(pyi_target), exist_ok=True) + shutil.copy2(pyi_source, pyi_target) + else: + print(f"Warning: .pyi file not found at {pyi_source}") + + def generate_default_envs(self): + code = '# Pre-installed environment variables\n' + code += 'persistent_envs = dict()\n' + for name in ('DG_JIT_CACHE_DIR', 'DG_JIT_PRINT_COMPILER_COMMAND', 'DG_JIT_CPP_STANDARD'): + code += f"persistent_envs['{name}'] = '{os.environ[name]}'\n" if name in os.environ else '' + + with open(os.path.join(self.build_lib, 'deep_gemm', 'envs.py'), 'w') as f: + f.write(code) + + def prepare_includes(self): + # Create temporary build directory instead of modifying package directory + build_include_dir = os.path.join(self.build_lib, 'deep_gemm/include') + os.makedirs(build_include_dir, exist_ok=True) + + # Copy third-party includes to the build directory + for d in third_party_include_dirs: + dirname = d.split('/')[-1] + src_dir = os.path.join(current_dir, d) + dst_dir = os.path.join(build_include_dir, dirname) + + # Remove existing directory if it exists + if os.path.exists(dst_dir): + shutil.rmtree(dst_dir) + + # Copy the directory + shutil.copytree(src_dir, dst_dir) + + +class CachedWheelsCommand(_bdist_wheel): + def run(self): + if DG_FORCE_BUILD or DG_USE_LOCAL_VERSION: + return super().run() + + wheel_url, wheel_filename = get_wheel_url() + print(f'Try to download wheel from URL: {wheel_url}') + # noinspection PyBroadException + try: + with urllib.request.urlopen(wheel_url, timeout=1) as response: + with open(wheel_filename, 'wb') as out_file: + data = response.read() + out_file.write(data) + + # Make the archive + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f'{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}' + wheel_path = os.path.join(self.dist_dir, archive_basename + '.whl') + os.rename(wheel_filename, wheel_path) + except (urllib.error.HTTPError, urllib.error.URLError): + print('Precompiled wheel not found. Building from source...') + # If the wheel could not be downloaded, build from source + super().run() + + +if __name__ == '__main__': + # noinspection PyTypeChecker + setuptools.setup( + name='deep_gemm', + version=get_package_version(), + packages=find_packages('.'), + package_data={ + 'deep_gemm': [ + 'include/deep_gemm/**/*', + 'include/cute/**/*', + 'include/cutlass/**/*', + ] + }, + ext_modules=get_ext_modules(), + zip_safe=False, + cmdclass={ + 'build_py': CustomBuildPy, + 'bdist_wheel': CachedWheelsCommand, + }, + ) diff --git a/third_party/DeepGEMM/tests/generators.py b/third_party/DeepGEMM/tests/generators.py new file mode 100644 index 00000000..989e984e --- /dev/null +++ b/third_party/DeepGEMM/tests/generators.py @@ -0,0 +1,407 @@ +import enum +import random +import torch +from typing import Generator, List, Optional, Tuple + +from deep_gemm.testing import get_arch_major +from deep_gemm.utils import ( + align, ceil_div, + per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, + per_token_cast_to_fp4, transpose_packed_fp4, + get_mk_alignment_for_contiguous_layout, + set_mk_alignment_for_contiguous_layout +) + + +class KernelType(enum.Enum): + Kernel1D1D = 0 + Kernel1D2D = 1 + KernelNoSF = 2 + + def is_1d1d(self): + return self.value == 0 + + def is_1d2d(self): + return self.value == 1 + + def is_nosf(self): + return self.value == 2 + + +class MajorTypeAB(enum.Enum): + KMajor = 0 + MNMajor = 1 + + def is_k_major(self): + return self.value == 0 + + def is_mn_major(self): + return self.value == 1 + + +class QuantConfig: + _legacy_quant_config = (128, 128, False, False) + + def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config): + self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value + + def print(self): + print(f' > Testing with gran_k_a={self.gran_k_a}, gran_k_b={self.gran_k_b}, ' + f'is_fp4_a={self.is_fp4_a}, is_fp4_b={self.is_fp4_b}') + + def is_legacy(self) -> bool: + return (self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b) == self._legacy_quant_config + + def get_recipes(self, is_wgrad: bool = False) -> Tuple[Tuple, Tuple, Tuple]: + recipe, recipe_a, recipe_b = None, None, None + if self.is_legacy(): + recipe = (1, 1, 128) if is_wgrad else None + else: + recipe_a = (1, self.gran_k_a) + recipe_b = (1, self.gran_k_b) if self.is_fp4_b or is_wgrad else (self.gran_k_b, self.gran_k_b) + return recipe, recipe_a, recipe_b + + def max_diff(self) -> float: + if self.is_fp4_a and self.is_fp4_b: + return 0.02 + if self.is_fp4_a or self.is_fp4_b: + return 0.01 + return 0.001 + + @staticmethod + def get_list_from_dtype(dtype: torch.dtype) -> List: + if dtype == torch.bfloat16: + return [None] + quant_config_list = [QuantConfig()] + if get_arch_major() == 10: + quant_config_list.append(QuantConfig((128, 32, False, True))) + return quant_config_list + + +def reset_seed(seed: int = 0): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def get_ue8m0_usage(kernel_type: KernelType) -> bool: + if get_arch_major() == 9: + return False + return kernel_type.is_1d1d() + + +def get_kernel_types(dtype: torch.dtype) -> tuple: + if dtype == torch.bfloat16: + return (KernelType.KernelNoSF, ) + + return (KernelType.Kernel1D2D, ) if get_arch_major() == 9 else (KernelType.Kernel1D1D, ) + + +def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator: + for major_a in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor): + for major_b in (MajorTypeAB.KMajor, MajorTypeAB.MNMajor): + if major_a.is_mn_major() and not allow_a_mn_major: + continue + if major_b.is_mn_major() and not allow_b_mn_major: + continue + yield major_a, major_b + + +def get_psum_layout_usage() -> tuple: + return True, False + + +def enumerate_normal(dtype: torch.dtype) -> Generator: + assert dtype in (torch.float8_e4m3fn, torch.bfloat16) + + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + fp32_output_nk = [(256, 7168), (129280, 7168)] + bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)] + m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ] + nk_list = list(bf16_output_nk) + + # Only BF16 GEMM needs FP32 outputs + if dtype == torch.bfloat16: + nk_list += fp32_output_nk + + for kernel_type in get_kernel_types(dtype): + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + reset_seed() + + # Forward + for m in m_fwd_list: + for i in range(len(nk_list)): + n, k = nk_list[i] + out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float + yield kernel_type, quant_config, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype + + # Backward + for m in m_bwd_list: + for n, k in nk_list: + override_major = MajorTypeAB.MNMajor + override_kernel_type = kernel_type + if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: + override_major = MajorTypeAB.KMajor + override_kernel_type = KernelType.Kernel1D1D + yield kernel_type, quant_config, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, True, torch.float # Wgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad + + +def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + m_group_list = [(4, 8192), (8, 4096)] + n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] + for kernel_type in get_kernel_types(dtype): + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + reset_seed() + for num_groups, expected_m_per_group in m_group_list: + for n, k in n_k_list: + for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn): + yield kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout + + +def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + max_m = 4096 + m_group_list = [(32, 192), (6, 1024), (32, 20), (6, 20)] + n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] + for kernel_type in get_kernel_types(dtype): + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + reset_seed() + for num_groups, m in m_group_list: + for n, k in n_k_list: + yield kernel_type, quant_config, num_groups, max_m, m, n, k, use_psum_layout + + +def enumerate_k_grouped_contiguous(dtype: torch.dtype): + gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128) + # Only K-major is supported for SM90 FP8 + major_a, major_b = (MajorTypeAB.KMajor, MajorTypeAB.KMajor) if get_arch_major() == 9 and dtype == torch.float8_e4m3fn \ + else (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + # Must with FP32 accumulation and 1D1D kernels + for num_groups, m, n, expected_k_per_group in (( 4, 4096, 7168, 8192), ( 4, 7168, 2048, 8192), # EP64 + ( 8, 4096, 7168, 4096), ( 8, 7168, 2048, 4096), # EP32 + (16, 4096, 7168, 2048), (16, 7168, 2048, 2048)): # EP16 + if dtype == torch.bfloat16: + ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), get_mk_alignment_for_contiguous_layout()) for _ in range(num_groups)] + yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group + else: + for gran_k in gran_k_list: + set_mk_alignment_for_contiguous_layout(gran_k) + ks = [align(int(expected_k_per_group * random.uniform(0.7, 1.3)), gran_k) for _ in range(num_groups)] + yield num_groups, m, n, major_a, major_b, ks, expected_k_per_group, gran_k + + +def enumerate_sf_layout(): + gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128) + for use_ue8m0 in (False, True): + for with_transpose in (True, False): + for mn in (4096, 4097, 8192): + for k in (128, 7168, 7296): + for num_groups in (1, 2, 4): + for gran_k in gran_k_list: + set_mk_alignment_for_contiguous_layout(gran_k) + yield mn, k, with_transpose, use_ue8m0, num_groups, gran_k + + +def enumerate_k_grouped_sf_layout(): + gran_k_list = (128, ) if get_arch_major() == 9 else (32, 128) + for mn in (4096, 7168): + for num_groups, avg_k in ((16, 2048), (8, 4096), (72, 384), (128, 256)): + for gran_k in gran_k_list: + set_mk_alignment_for_contiguous_layout(gran_k) + ks = [align(int(random.uniform(0.7, 1.3) * avg_k), gran_k) for _ in range(num_groups)] + yield mn, ks, num_groups, gran_k + + +def enumerate_transpose(): + for mn in (64, 4096, 16384): + for delta in (0, 101, 202, 303): + for k in (128, 1024, 4096, 9984, 16384): + yield mn + delta, k + + +def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + if is_fp4: + x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + return x_fp4 if major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1]) + else: + x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + return x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1]) + + +def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + num_groups, mn, k = x.size() + if is_fp4: + x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.int8) if major.is_k_major() else \ + torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.int8), + torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) + x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1]) + return x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1]) + else: + x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8 \ + else torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) + return x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1]) + + +def generate_normal(m: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + accumulate: bool, out_dtype: torch.dtype, + kernel_type: KernelType, + use_ue8m0: bool = False, use_bf16: bool = False, + quant_config: Optional[QuantConfig] = None): + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) + d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ + torch.empty((m, n), device='cuda', dtype=out_dtype) + c = d if accumulate else None + ref_d = (a.float() @ b.float().t() + (c if accumulate else 0)).to(out_dtype) + + if use_bf16: + a = a if major_a.is_k_major() else a.T.contiguous().T + b = b if major_b.is_k_major() else b.T.contiguous().T + return a, b, c, d, ref_d + + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, + use_block_cast_for_fp8=not (kernel_type.is_1d1d() and accumulate)) + + return a, b, c, d, ref_d + + +def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, + major_a: MajorTypeAB, major_b: MajorTypeAB, + use_ue8m0: bool = False, use_bf16: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): + actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] + aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] + m = sum(aligned_ms) + + a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + grouped_layout = torch.empty(num_groups, device='cuda', dtype=torch.int32) if use_psum_layout \ + else torch.empty(m, device='cuda', dtype=torch.int32) + d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) + + start = 0 + for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)): + actual_end = start + actual_m + aligned_end = start + aligned_m + if use_psum_layout: + grouped_layout[i] = actual_end + else: + grouped_layout[start: actual_end] = i + grouped_layout[actual_end: aligned_end] = -1 + a[actual_end: aligned_end] = 0 + ref_d[start: aligned_end] = a[start: aligned_end] @ b[i].t() + start = aligned_end + + if use_bf16: + b = b if major_b.is_k_major() else b.mT.contiguous().mT + return m, a, b, grouped_layout, d, ref_d + + assert major_a.is_k_major() + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) + + return m, a, b, grouped_layout, d, ref_d + + +def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor): + num_groups, max_m, _ = x.size() + x_psum = torch.empty_like(x).view(num_groups * max_m, -1) + last_psum_m = 0 + for i in range(num_groups): + x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m] + last_psum_m = align(psum_m[i], get_mk_alignment_for_contiguous_layout()) + return x_psum + + +def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, + use_ue8m0: bool = False, use_bf16: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): + a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) + b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) + d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) + ref_d = torch.einsum('gmk,gnk->gmn', a, b) + + masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + for j in range(num_groups): + masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout())) + masked_m[j] + assert masked_m.amax().item() <= max_m + + if use_bf16: + return a, b, masked_m, psum_m, d, ref_d + + quant_config = QuantConfig() if quant_config is None else quant_config + a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) + + return a, b, masked_m, psum_m, d, ref_d + + +def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], + use_ue8m0: bool = False, use_bf16: bool = False, gran_k = 128): + assert get_mk_alignment_for_contiguous_layout() % gran_k == 0 + k = sum(ks) + + a = torch.randn((k, m), device='cuda', dtype=torch.bfloat16) + b = torch.randn((k, n), device='cuda', dtype=torch.bfloat16) + c = torch.randn((num_groups, m, n), device='cuda', dtype=torch.float) * 32 + d = c + ref_d = torch.empty_like(c) + + start = 0 + for i, group_k in enumerate(ks): + end = start + group_k + ref_d[i] = c[i] + (a[start:end].T @ b[start:end]) + start = end + + if use_bf16: + assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + return k, a, b, c, d, ref_d + + a_fp8 = per_channel_cast_to_fp8(a, use_ue8m0=use_ue8m0, gran_k=gran_k) + b_fp8 = per_channel_cast_to_fp8(b, use_ue8m0=use_ue8m0, gran_k=gran_k) + + # Transpose for K Major A/B + if (major_a, major_b) == (MajorTypeAB.KMajor, MajorTypeAB.KMajor): + a, sfa = a_fp8 + b, sfb = b_fp8 + new_a = torch.empty((sum(ks) * m, ), dtype=a.dtype, device=a.device) + new_b = torch.empty((sum(ks) * n, ), dtype=b.dtype, device=b.device) + prefix = 0 + for K in ks: + new_a[prefix * m : (prefix + K) * m] = a[prefix : prefix + K, ].T.flatten() + new_b[prefix * n : (prefix + K) * n] = b[prefix : prefix + K, ].T.flatten() + prefix += K + a_fp8, b_fp8 = (new_a, sfa.T), (new_b, sfb.T) + else: + assert (major_a, major_b) == (MajorTypeAB.MNMajor, MajorTypeAB.MNMajor) + + return k, a_fp8, b_fp8, c, d, ref_d diff --git a/third_party/DeepGEMM/tests/test_attention.py b/third_party/DeepGEMM/tests/test_attention.py new file mode 100644 index 00000000..479da5b5 --- /dev/null +++ b/third_party/DeepGEMM/tests/test_attention.py @@ -0,0 +1,397 @@ +import dataclasses +import random +import torch +from typing import Tuple, List + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + ignore_env, get_arch_major, + test_filter +) +from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8, per_token_cast_to_fp4, cast_back_from_fp4 + +from generators import get_arch_major, generate_normal, get_ue8m0_usage, get_kernel_types, reset_seed, MajorTypeAB + + +def apply_skip_head_mid(d: torch.Tensor, head_splits: Tuple[int, int, int]): + left, mid, right = head_splits + m, n = d.shape + assert n % (left + right) == 0 + num_heads = n // (left + right) + + # Split and insert padding tensor + d = d.view(m, num_heads, -1) + d_left = d[:, :, :left] + d_right = d[:, :, -right:] + + d_mid = torch.zeros((m, num_heads, mid), dtype=d.dtype, device=d.device) + return torch.cat([d_left, d_mid, d_right], dim=2).view(m, -1) + + +def test_gemm_skip_head_mid() -> None: + print('Testing GEMM skip head mid:') + head_splits = (128, 64, 128) + + major_a, major_b = MajorTypeAB.KMajor, MajorTypeAB.KMajor + out_dtype, accumulate = torch.bfloat16, False + + for kernel_type in get_kernel_types(dtype=torch.float8_e4m3fn): + for m in (128, 4096): + for n, k in [(32768, 512), (8192, 512)]: + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + + a, b, _, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) + d = apply_skip_head_mid(d, head_splits) + ref_d = apply_skip_head_mid(ref_d, head_splits) + + deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {kernel_opt}, {diff:.5f}' + + t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt_skip_head_mid(a, b, d, head_splits, disable_ue8m0_cast=disable_ue8m0_cast), + 'gemm_', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, {kernel_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d)) / 1e9 / t:4.0f} GB/s') + print() + + +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor, cost_only: bool = False): + seq_len_kv = kv.shape[0] + + if cost_only: + start = cu_seqlen_ks.clamp(min=0, max=seq_len_kv) + end = cu_seqlen_ke.clamp(min=0, max=seq_len_kv) + count_ones_per_row = (end - start).clamp(min=0) + return count_ones_per_row.sum() + + k = kv + q = q.float() + k = k.float() + + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + + +def test_mqa_logits(): + + # Helper functions + def generate_ks_ke_tests(seq_len: int, seq_len_kv: int, disable_cp: bool): + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.arange(seq_len, dtype=torch.int, device='cuda') + (seq_len_kv - seq_len) + return ks, ke + assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 + chunk_size = seq_len // 2 + cp_size = seq_len_kv // seq_len + # Select an arbitrary CP rank + cp_id = cp_size // 3 + ks = torch.zeros(seq_len, dtype=torch.int, device='cuda') + ke = torch.zeros(seq_len, dtype=torch.int, device='cuda') + for i in range(chunk_size): + ke[i] = cp_id * chunk_size + i + ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i + return ks, ke + + def enumerate_mqa_logits(): + for is_fp4 in ((True, False) if get_arch_major() == 10 else (False, )): + for logits_dtype in (torch.float, torch.bfloat16): + for compressed_logits, clean_logits in [(False, True), (True, False)]: + for seq_len in (2048, 4096): + for seq_len_kv in (4096, 8192): + for num_heads, head_dim in [(64, 128)]: + for disable_cp in (False, True): + yield is_fp4, logits_dtype, compressed_logits, clean_logits, seq_len, seq_len_kv, num_heads, head_dim, disable_cp + + print('Testing FP8 MQA Logits:') + for is_fp4, logits_dtype, compressed_logits, clean_logits, seq_len, seq_len_kv, num_heads, head_dim, disable_cp in enumerate_mqa_logits(): + # Generate random inputs + q = torch.randn(seq_len, num_heads, head_dim, device='cuda', dtype=torch.bfloat16) + kv = torch.randn(seq_len_kv, head_dim, device='cuda', dtype=torch.bfloat16) + weights = torch.randn(seq_len, num_heads, device='cuda', dtype=torch.float32) + ks, ke = generate_ks_ke_tests(seq_len, seq_len_kv, disable_cp) + + # Calculate reference logits + ref_logits, ref_cost = ref_fp8_mqa_logits(q, kv, weights, ks, ke) + + # Quantize Q and KV to FP4 / FP8 + if is_fp4: + q_fp4 = per_token_cast_to_fp4(q.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + q_in = (q_fp4[0].view(seq_len, num_heads, head_dim // 2), q_fp4[1].view(seq_len, num_heads)) + q_simulated = cast_back_from_fp4(q_fp4[0], q_fp4[1], gran_k=32, use_packed_ue8m0=True).view(seq_len, num_heads, head_dim).to(torch.bfloat16) + + kv_fp4 = per_token_cast_to_fp4(kv.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + kv_in = (kv_fp4[0].view(seq_len_kv, head_dim // 2), kv_fp4[1].view(seq_len_kv)) + kv_simulated = cast_back_from_fp4(kv_fp4[0], kv_fp4[1], gran_k=32, use_packed_ue8m0=True).view(seq_len_kv, head_dim).to(torch.bfloat16) + else: + q_in = q.to(torch.float8_e4m3fn), None + q_simulated = q_in[0].to(torch.bfloat16) + kv_in = per_custom_dims_cast_to_fp8(kv, (0, ), False) + kv_simulated = (kv_in[0].float() * kv_in[1].unsqueeze(1)).to(torch.bfloat16) + + # Calculate reference logits + simulated_logits, _ = ref_fp8_mqa_logits(q_simulated, kv_simulated, weights, ks, ke) + + # Prepare kwargs + kernel_kwargs = dict( + q=q_in, kv=kv_in, weights=weights, + cu_seq_len_k_start=ks, cu_seq_len_k_end=ke, + clean_logits=clean_logits, max_seqlen_k=0, + logits_dtype=logits_dtype + ) + if compressed_logits: + max_seqlen_k = (ke - ks).max().item() + kernel_kwargs['max_seqlen_k'] = max_seqlen_k + + # Run kernel + logits = deep_gemm.fp8_fp4_mqa_logits(**kernel_kwargs) + + # Post process for compressed logits + if compressed_logits: + assert logits.size() == (seq_len, max_seqlen_k) + tmp = torch.full((seq_len, seq_len_kv), float('-inf'), device='cuda') + for i in range(seq_len): + tmp[i, ks[i] : ke[i]] = logits[i, : ke[i] - ks[i]] + logits = tmp + + # Validation + ref_neginf_mask = (ref_logits == float('-inf')) + neginf_mask = (logits == float('-inf')) + assert torch.equal(neginf_mask, ref_neginf_mask) + + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + simulated_logits = simulated_logits.masked_fill(ref_neginf_mask, 0) + logits = logits.masked_fill(ref_neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + simulated_diff = calc_diff(logits, simulated_logits) + assert diff < 0.02 if is_fp4 else 1e-3, f"Diff: {diff}" + assert simulated_diff < 5e-6, f"Simulated Diff: {simulated_diff}" + + # Profiling + tflops = 2 * ref_cost * num_heads * head_dim / 1e12 + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_mqa_logits(**kernel_kwargs), ('mqa_logits', 'clean_logits')) + clean_bytes = (seq_len * seq_len_kv - ref_cost) * 4 + count_bytes(ks, ke) + + print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, S={seq_len:4}, SKV={seq_len_kv:6}, H={num_heads:3}, D={head_dim:3}, CP={0 if disable_cp else 1}: ' + f'{tflops / t:4.0f} TFLOPS, {t * 1e6:4.0f} us, ' + f'{(count_bytes(q_in, kv_in, weights, ks, ke) + ref_cost * 4) / t / 1e9:4.0f} GB/s', end='') + print(f' | clean: {clean_t * 1e6:3.0f} us, {clean_bytes / clean_t / 1e9:4.0f} GB/s' if clean_logits else '') + print() + + +def ref_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, + weights: torch.Tensor, context_lens: torch.Tensor, block_tables: torch.Tensor, + max_model_len: int, use_2d_context_lens: bool): + batch_size, next_n, num_heads, dim = q.size() + num_block, block_size, _, dim = kv_cache.size() + logits = torch.full([batch_size * next_n, max_model_len], float('-inf'), device=q.device, dtype=torch.float32) + context_lens = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens[i] + q_offsets = torch.full((next_n, ), context_len, device='cuda', dtype=torch.int32) if use_2d_context_lens \ + else torch.arange(context_len - next_n, context_len, device='cuda') + weight_slice = weights[i * next_n:(i + 1) * next_n, :].transpose(0, 1).contiguous() + + num_blocks = (context_len + block_size - 1) // block_size + block_idxs = block_tables[i][:num_blocks] + kv_slice = kv_cache[block_idxs] # [num_blocks, block_size, kv_heads, dim] + kx = kv_slice.permute(2, 3, 0, 1).reshape(kv_slice.size(2), dim, -1) # [kv_heads, dim, total_tokens] + qx = q[i].transpose(0, 1) # q[i]: [next_n, num_heads, dim] -> [num_heads, next_n, dim] + s = torch.matmul(qx, kx).to(logits.dtype) # [num_heads, next_n, dim] @ [1, dim, total_tokens] -> [num_heads, next_n, total_tokens] + + total_len = num_blocks * block_size + k_offsets = torch.arange(0, total_len, device=q.device) + mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :] <= q_offsets[:, None]) + s = torch.where(mask[None, :, :], s, float('-inf')) # mask shape: [1, next_n, total_tokens] + s = torch.relu(s) * weight_slice[..., None] # weight_slice: [num_heads, next_n] -> [num_heads, next_n, 1] + s = s.sum(dim=0) # [next_n, total_tokens] + logits[i * next_n:(i + 1) * next_n, :total_len] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s, float('-inf')) + + return logits + + +def test_paged_mqa_logits(): + + # Helper functions + def kv_cache_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_cast_back = x_scaled.float() * sf + + x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), device=x.device, dtype=torch.uint8) + x_fp8[ :, : block_size * head_dim] = x_scaled.view(num_blocks, block_size * head_dim).view(torch.uint8) + x_fp8[ :, block_size * head_dim :] = sf.view(num_blocks, block_size).view(torch.uint8) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4), x_cast_back.to(x.dtype) + + def kv_cache_cast_to_fp4(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 and head_dim == 128 + x_scaled, sf = per_token_cast_to_fp4(x.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + x_cast_back = cast_back_from_fp4(x_scaled, sf, gran_k=32, use_packed_ue8m0=True).view(num_blocks, block_size, 1, head_dim) + + x_fp4 = torch.empty((num_blocks, block_size * (head_dim // 2 + 4)), device=x.device, dtype=torch.uint8) + x_fp4[ :, : block_size * head_dim // 2] = x_scaled.view(num_blocks, block_size * head_dim // 2).view(torch.uint8) + x_fp4[ :, block_size * head_dim // 2 :] = sf.view(num_blocks, block_size).view(torch.uint8) + return x_fp4.view(num_blocks, block_size, num_heads, head_dim // 2 + 4), x_cast_back.to(x.dtype) + + def enumerate_paged_mqa_logits(): + arch_major = get_arch_major() + for is_varlen in ((True, False) if arch_major == 10 else (False, )): + for is_fp4 in ((True, False) if arch_major == 10 else (False, )): + for logits_dtype in (torch.float, torch.bfloat16): + for block_kv in ((32, 64) if arch_major == 10 else (64, )): + for use_2d_context_lens, clean_logits in [(True, False)]: + for batch_size in (256, ): + for next_n in ((1, ) if is_varlen else ((1, 2, 4, 5, 6) if arch_major == 10 else (1, 2))): + for max_tokens_per_batch in ((1, 4, 10) if is_varlen else (1, )): + for num_heads, head_dim in [(64, 128)]: + for avg_kv in (8192, 32768): + yield is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv + + + print('Testing FP8/FP4 Paged MQA Logits:') + max_model_len = 111 * 1024 + num_total_blocks = max_model_len * 5 + + for is_varlen, is_fp4, logits_dtype, block_kv, use_2d_context_lens, clean_logits, batch_size, next_n, max_tokens_per_batch, num_heads, head_dim, avg_kv in enumerate_paged_mqa_logits(): + # Varlen: flatten raw_batch_size sequences with variable tokens into (batch_size, 1, ...) + raw_batch_size, raw_next_n = batch_size, next_n + if is_varlen: + tokens_per_seq = torch.randint(1, max_tokens_per_batch + 1, (raw_batch_size,), device='cuda', dtype=torch.int) + indices = torch.arange(raw_batch_size, device='cuda', dtype=torch.int).repeat_interleave(tokens_per_seq) + batch_size, next_n = tokens_per_seq.sum().item(), 1 + else: + tokens_per_seq, indices = None, None + + # Generate random inputs + q = torch.randn((batch_size, next_n, num_heads, head_dim), device='cuda', dtype=torch.bfloat16) + kv_cache = torch.randn((num_total_blocks, block_kv, 1, head_dim), device='cuda', dtype=torch.bfloat16) + weights = torch.randn((batch_size * next_n, num_heads), device='cuda', dtype=torch.float) + context_lens = torch.randint(int(0.7 * avg_kv), int(1.3 * avg_kv), (raw_batch_size,), device='cuda', dtype=torch.int) + + if is_varlen: + max_ctx_len_per_seq = context_lens + (tokens_per_seq - 1) + else: + max_ctx_len_per_seq = context_lens + + # Assign block tables (per-sequence, sized by the largest ctx_len within the sequence) + seq_sum_lens = context_lens.sum().item() + num_blocks_per_query = ceil_div(max_ctx_len_per_seq, block_kv) + block_table = torch.empty((raw_batch_size, num_blocks_per_query.max().item()), device='cuda', dtype=torch.int) + block_idx_pool = torch.randperm(num_total_blocks, device='cuda', dtype=torch.int) + offset = 0 + for i, num_blocks in enumerate(num_blocks_per_query.tolist()): + block_table[i, :num_blocks] = block_idx_pool[offset : offset + num_blocks] + offset += num_blocks + if is_varlen: + context_lens = context_lens.repeat_interleave(tokens_per_seq) + offsets_within_seq = torch.cat([ + torch.arange(n.item(), device='cuda', dtype=torch.int) + for n in tokens_per_seq + ]) + context_lens = context_lens + offsets_within_seq + block_table = block_table.repeat_interleave(tokens_per_seq, dim=0) + + # Calculate reference logits + ref_logits = ref_paged_mqa_logits(q, kv_cache, weights, context_lens, block_table, max_model_len, use_2d_context_lens) + + # Quantize Q and KV cache to FP4 / FP8 + if is_fp4: + q_fp4 = per_token_cast_to_fp4(q.view(-1, head_dim), use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + q_in = (q_fp4[0].view(batch_size, next_n, num_heads, head_dim // 2), q_fp4[1].view(batch_size, next_n, num_heads)) + q_simulated = cast_back_from_fp4(q_fp4[0], q_fp4[1], gran_k=32, use_packed_ue8m0=True).view(batch_size, next_n, num_heads, head_dim).to(torch.bfloat16) + kv_in, kv_simulated = kv_cache_cast_to_fp4(kv_cache) + else: + q_in = q.to(torch.float8_e4m3fn), None + q_simulated = q_in[0].to(torch.bfloat16) + kv_in, kv_simulated = kv_cache_cast_to_fp8(kv_cache) + + # Calculate simulated reference logits + simulated_logits = ref_paged_mqa_logits(q_simulated, kv_simulated, weights, context_lens, block_table, max_model_len, use_2d_context_lens) + + # Prepare masks and context lengths with NextN + positions = torch.arange(max_model_len, device='cuda').unsqueeze(0).expand(batch_size * next_n, -1) + if use_2d_context_lens: + if is_varlen: + # Varlen: context_lens is already per-token (shape [total_tokens]); + # just reshape to (total_tokens, 1) so each token keeps its own ctx_len. + context_lens_nextn = context_lens.view(-1, 1) + else: + context_lens_nextn = ((context_lens.unsqueeze(1) + 1) * torch.rand(batch_size, next_n, device='cuda')).int() + # Ensure last token matches actual length + context_lens_nextn[:, -1] = context_lens + ref_neginf_mask = ~(positions < context_lens_nextn.view(-1, 1)) + else: + context_lens_nextn = context_lens + offsets = torch.arange(batch_size * next_n, device='cuda') + limits = (context_lens[offsets // next_n] - next_n + offsets % next_n).unsqueeze(1) + ref_neginf_mask = ~(positions <= limits) + + # Run Kernel + kernel_kwargs = dict( + q=q_in, kv_cache=kv_in, weights=weights, + context_lens=context_lens_nextn, block_table=block_table, + schedule_meta=deep_gemm.get_paged_mqa_logits_metadata(context_lens_nextn, block_kv, deep_gemm.get_num_sms(), indices=indices), + max_context_len=max_model_len, clean_logits=clean_logits, logits_dtype=logits_dtype, + indices=indices, + ) + logits = deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs) + + # Validation + assert logits.dtype == logits_dtype + logits = logits.to(torch.float) + + if clean_logits: + assert torch.equal(logits == float('-inf'), ref_neginf_mask), "Mask mismatch" + + logits_masked = logits.masked_fill(ref_neginf_mask, 0) + ref_masked = ref_logits.masked_fill(ref_neginf_mask, 0) + simulated_masked = simulated_logits.masked_fill(ref_neginf_mask, 0) + diff = calc_diff(logits_masked, ref_masked) + simulated_diff = calc_diff(logits_masked, simulated_masked) + assert diff < 0.02 if is_fp4 else 1e-3, f"Diff: {diff}" + assert simulated_diff < 5e-6, f"Simulated Diff: {simulated_diff}" + + # Profiling + sum_lens = context_lens.sum().item() + tflops_calc = 2 * sum_lens * next_n * num_heads * head_dim / 1e12 + kv_bytes_per_token = head_dim / (2 if is_fp4 else 1) + 4 + # KV is read once per sequence; for varlen sum_lens overcounts (per-token), so use seq_sum_lens + kv_sum_lens = seq_sum_lens if is_varlen else sum_lens + total_bytes = count_bytes(q, weights) + kv_sum_lens * kv_bytes_per_token + (sum_lens * next_n * logits_dtype.itemsize) + + t, clean_t = bench_kineto(lambda: deep_gemm.fp8_fp4_paged_mqa_logits(**kernel_kwargs), ('paged_mqa_logits', 'clean_logits')) + print(f' > FP4={is_fp4}, BF16={logits_dtype == torch.bfloat16}, BLOCK_KV={block_kv}, BSZ={raw_batch_size:3}, NextN={raw_next_n:1}, H={num_heads:2}, D={head_dim:2}, L={avg_kv:6}: ' + f'{tflops_calc / t:4.0f} TFLOPS, {t * 1e6:3.0f} us, {total_bytes / t / 1e9:4.0f} GB/s', end='') + if is_varlen: + print(f' | Varlen, MaxTPB={max_tokens_per_batch}, NumTokens={batch_size}', end='') + print(f' | clean: {clean_t*1e6:3.0f} us' if clean_logits else '') + print() + + + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + test_gemm_skip_head_mid() + test_mqa_logits() + test_paged_mqa_logits() diff --git a/third_party/DeepGEMM/tests/test_bf16.py b/third_party/DeepGEMM/tests/test_bf16.py new file mode 100644 index 00000000..fb3acf3d --- /dev/null +++ b/third_party/DeepGEMM/tests/test_bf16.py @@ -0,0 +1,223 @@ +import copy +import numpy as np +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes +) +from generators import ( + get_arch_major, layout_masked_to_psum, align, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous, + get_mk_alignment_for_contiguous_layout +) + + +def test_gemm() -> None: + print('Testing GEMM:') + scores = [] + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + func_name = f'bf16_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else a.T + b = b if major_b.is_k_major() else b.T + assert a.is_contiguous() and b.is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c) + diff = calc_diff(d, ref_d) + assert diff < 1e-5, (f'{m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + + t = bench_kineto(lambda: deep_gemm.bf16_gemm_nt(a, b, d, c=c), 'bf16_gemm', suppress_kineto_output=True) + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:7.1f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') + if cublas_t > 0: + scores.append((cublas_t + split_k_t) / t) + print(f"Average speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + # Select best alignment + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + + for test_alias in (False, True): + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) + func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else b.mT + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) + if use_psum_layout: + for j in range(num_groups): + start = 0 if j == 0 else align(grouped_layout[j - 1], get_mk_alignment_for_contiguous_layout()) + end = grouped_layout[j] + diff = calc_diff(d[start : end], ref_d[start : end]) + assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' + else: + diff = calc_diff(d, ref_d) + assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) + + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, psum={use_psum_layout}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for _, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16): + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + # Select best alignment + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout(int(expected_m_per_group * 1.2)) + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_bf16=True, use_psum_layout=use_psum_layout) + if use_psum_layout: + a_psum = layout_masked_to_psum(a, psum_m) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, + use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group) + else: + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + + test_func() + for j in range(num_groups): + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout()): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) + assert diff < 1e-5, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' + + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') + print() + + +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + # TODO: Support arbitrary alignment + deep_gemm.set_mk_alignment_for_contiguous_layout(128) + + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16): + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups and len(ks) > 1: + new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_bf16=True) + new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') + deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, new_ks, new_ks_tensor, c) + + diff = calc_diff(d, ref_d) + assert diff < 1e-5, f'{m=}, {n=}, {k=}, {ks=}, {diff:.7f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_bf16=True) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks, ks_tensor, c) + + t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_cublaslt_gemm() -> None: + print('Testing cuBLASLt GEMM:') + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_bf16=True) + deep_gemm.cublaslt_gemm_nt(a, b, d, c=c) + diff = calc_diff(d, ref_d) + assert diff < 6e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})' + + t_nvjet, t_gemv, t_gemm = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'gemv', 'gemm'), suppress_kineto_output=True) + t = t_nvjet + t_gemv + t_gemm + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:5.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + if get_arch_major() >= 9: + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() + + test_cublaslt_gemm() diff --git a/third_party/DeepGEMM/tests/test_einsum.py b/third_party/DeepGEMM/tests/test_einsum.py new file mode 100644 index 00000000..57f54592 --- /dev/null +++ b/third_party/DeepGEMM/tests/test_einsum.py @@ -0,0 +1,181 @@ +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + get_arch_major, test_filter +) +from deep_gemm.utils.math import ( + ceil_div, + per_block_cast_to_fp8, per_channel_cast_to_fp8, per_token_cast_to_fp8 +) + + +def test_bmk_bnk_mn() -> None: + print('Testing "bmk, bnk -> mn":') + for s in (129, 4096, 8192): + for m, n, k in [(128, 384, 128), (256, 256, 256), (384, 128, 384)]: + for dtype in (torch.float, torch.bfloat16): + a = torch.randn((s, m, k), dtype=torch.bfloat16, device='cuda') + b = torch.randn((s, n, k), dtype=torch.bfloat16, device='cuda') + d = torch.randn((m, n), dtype=dtype, device='cuda') + c = d if dtype == torch.float else None + + # Test correctness + ref_d = (c if dtype == torch.float else 0) + torch.bmm(a.float(), b.float().mT).sum(0) + deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c) + assert calc_diff(d, ref_d) < 1e-5 + + t = bench_kineto(lambda: deep_gemm.einsum('bmk,bnk->mn', a, b, d, c=c), 'bmn_bnk_mn_gemm_impl', suppress_kineto_output=True) + print(f' > Perf (b={s:4.0f}, {m=}, {n=}, {k=}, {"FP32" if dtype == torch.float else "BF16"}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * s * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b) + (d.numel() * 4)) / 1e9 / t:4.0f} GB/s') + print() + + +def test_bhr_hdr_bhd(): + print('Testing "bhr, hdr -> bhd":') + for h, r, d in [(128, 512, 128), (8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) + y = fy[:, :, :r] + ref_z = torch.einsum('bhr,hdr->bhd', x, y) + z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16) + deep_gemm.einsum('bhr,hdr->bhd', x, y, z) + assert calc_diff(z, ref_z) < 1e-10 + + t = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z), 'gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +def test_bhd_hdr_bhr(): + print('Testing "bhd, hdr -> bhr":') + for h, r, d in [(128, 512, 128), (8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + fy = torch.randn((h, d, r + 128), device='cuda', dtype=torch.bfloat16) + y = fy[:, :, :r] + ref_z = torch.einsum('bhd,hdr->bhr', x, y) + z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16) + deep_gemm.einsum('bhd,hdr->bhr', x, y, z) + assert calc_diff(z, ref_z) < 1e-10 + + t = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z), 'gemm', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x, y, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +def test_fp8_bhr_hdr_bhd(use_ue8m0: bool = True): + print('Testing FP8 "bhr, hdr -> bhd":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + y = torch.randn((h, d, r), device='cuda', dtype=torch.bfloat16) + ref_z = torch.einsum('bhr,hdr->bhd', x, y) + + x_fp8 = per_token_cast_to_fp8(x.view(-1, r), use_ue8m0=use_ue8m0) + x_fp8 = x_fp8[0].view(b, h, r), x_fp8[1].view(b, h, ceil_div(r, 128)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((h, ceil_div(d, 128), ceil_div(r, 128)), device='cuda', dtype=torch.float)) + for i in range(h): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], use_ue8m0=use_ue8m0) + z = torch.empty((b, h, d), device='cuda', dtype=torch.bfloat16) + + deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhr,hdr->bhd', x_fp8, y_fp8, z), 'gemm_', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhr,hdr->bhd', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +@test_filter(lambda: get_arch_major() >= 10) +def test_fp8_bhd_hdr_bhr(use_ue8m0: bool = True): + print('Testing FP8 "bhd, hdr -> bhr":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4, 32, 128, 4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + y = torch.randn((h, d, r), device='cuda', dtype=torch.bfloat16) + ref_z = torch.einsum('bhd,hdr->bhr', x, y) + + x_fp8 = per_token_cast_to_fp8(x.view(-1, d), use_ue8m0=use_ue8m0) + x_fp8 = x_fp8[0].view(b, h, d), x_fp8[1].view(b, h, ceil_div(d, 128)) + y_fp8 = (torch.empty_like(y, dtype=torch.float8_e4m3fn), + torch.empty((h, ceil_div(d, 128), ceil_div(r, 128)), device='cuda', dtype=torch.float)) + for i in range(h): + y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i], use_ue8m0=use_ue8m0) + z = torch.empty((b, h, r), device='cuda', dtype=torch.bfloat16) + + deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,hdr->bhr', x_fp8, y_fp8, z), 'gemm_', suppress_kineto_output=True) + t_cublaslt = bench_kineto(lambda: deep_gemm.einsum('bhd,hdr->bhr', x, y, z, use_cublaslt=True), 'nvjet', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z)) / t / 1e9:4.0f} GB/s | ' + f'{t_cublaslt / t:4.2f} x') + print() + + +@test_filter(lambda: get_arch_major() >= 10) +def test_fp8_bhd_bhr_hdr(use_ue8m0: bool = True): + print('Testing FP8 "bhd, bhr -> hdr":') + for h, r, d in [(8, 4096, 1024)]: + for b in (4096, 8192): + x = torch.randn((b, h, d), device='cuda', dtype=torch.bfloat16) + y = torch.randn((b, h, r), device='cuda', dtype=torch.bfloat16) + z_0 = torch.randn((h, d, r), device='cuda', dtype=torch.float32) * 10 + ref_z = z_0 + torch.einsum('bhd,bhr->hdr', x, y) + + x_fp8 = per_channel_cast_to_fp8(x.view(b, -1), use_ue8m0=use_ue8m0) + y_fp8 = per_channel_cast_to_fp8(y.view(b, -1), use_ue8m0=use_ue8m0) + x_fp8 = (x_fp8[0].view(b, h, d), x_fp8[1].view(ceil_div(b, 128), h, d)) + y_fp8 = (y_fp8[0].view(b, h, r), y_fp8[1].view(ceil_div(b, 128), h, r)) + z = z_0.clone() + deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)) + assert calc_diff(z, ref_z) < 1e-3 + + t = bench_kineto(lambda: deep_gemm.fp8_einsum('bhd,bhr->hdr', x_fp8, y_fp8, z, z, recipe=(1, 1, 128)), 'gemm_', suppress_kineto_output=True) + print(f' > Perf ({b=:4.0f}, {h=}, {r=}, {d=}): ', + f'{t * 1e6:4.0f} us | ' + f'{2 * b * h * r * d / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes((x_fp8, y_fp8, z, z)) / t / 1e9:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_bmk_bnk_mn() + test_bhr_hdr_bhd() + test_bhd_hdr_bhr() + + test_fp8_bhr_hdr_bhd() + test_fp8_bhd_hdr_bhr() + test_fp8_bhd_bhr_hdr() diff --git a/third_party/DeepGEMM/tests/test_fp8_fp4.py b/third_party/DeepGEMM/tests/test_fp8_fp4.py new file mode 100644 index 00000000..4e9f54f7 --- /dev/null +++ b/third_party/DeepGEMM/tests/test_fp8_fp4.py @@ -0,0 +1,222 @@ +import copy +import numpy as np +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + ignore_env, get_arch_major +) + +from generators import ( + KernelType, get_ue8m0_usage, layout_masked_to_psum, align, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous, + get_mk_alignment_for_contiguous_layout +) + + +def test_gemm() -> None: + print('Testing GEMM:') + scores = [] + for kernel_type, quant_config, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=(kernel_type.is_1d1d() and accumulate)) + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + func_name = f'fp8_fp4_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else (a[0].T, a[1].T) + b = b if major_b.is_k_major() else (b[0].T, b[1].T) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + t = bench_kineto(lambda: deep_gemm.fp8_fp4_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b), + 'gemm_', suppress_kineto_output=True) + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) \ + if not quant_config.is_fp4_a and not quant_config.is_fp4_b else (0, 0) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') + if cublas_t > 0: + scores.append((cublas_t + split_k_t) / t) + print(f"Average FP8xFP8 GEMM speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + # Select best alignment + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + + for test_alias in (False, True): + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + func_name = f"m_grouped_fp8_fp4_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else (b[0].mT, b[1].mT) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + if use_psum_layout: + for j in range(num_groups): + start = 0 if j == 0 else align(grouped_layout[j - 1], get_mk_alignment_for_contiguous_layout()) + end = grouped_layout[j] + diff = calc_diff(d[start : end], ref_d[start : end]) + assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + else: + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + t = bench_kineto(test_func, 'gemm_', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}, psum={use_psum_layout}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for kernel_type, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + # Select best alignment + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout(int(expected_m_per_group * 1.2)) + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + if use_psum_layout: + a_psum = (layout_masked_to_psum(a[0], psum_m), layout_masked_to_psum(a[1], psum_m)) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, disable_ue8m0_cast=disable_ue8m0_cast, + use_psum_layout=True, expected_m_for_psum_layout=int(expected_m_per_group * 1.2), + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + else: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, int(expected_m_per_group * 1.2), disable_ue8m0_cast=disable_ue8m0_cast, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + test_func() + for j in range(num_groups): + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], get_mk_alignment_for_contiguous_layout()): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) + assert diff < quant_config.max_diff(), f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'gemm_', suppress_kineto_output=True) + + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'{kernel_opt}, psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') + print() + + +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \ + else deep_gemm.k_grouped_fp8_gemm_tn_contiguous + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group, gran_k in enumerate_k_grouped_contiguous(torch.float8_e4m3fn): + recipe = (1, 1, gran_k) + use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) + + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups and len(ks) > 1: + new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0, gran_k=gran_k) + new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') + k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c, recipe=recipe) + + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0, gran_k=gran_k) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c, recipe=recipe) + + t = bench_kineto(test_func, 'gemm_', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}, gran_k={gran_k:3}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() diff --git a/third_party/DeepGEMM/tests/test_hyperconnection.py b/third_party/DeepGEMM/tests/test_hyperconnection.py new file mode 100644 index 00000000..24faf22c --- /dev/null +++ b/third_party/DeepGEMM/tests/test_hyperconnection.py @@ -0,0 +1,57 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + test_filter, + bench_kineto, + calc_diff, count_bytes +) +from deep_gemm.utils import align +from generators import get_arch_major + + +@test_filter(lambda: get_arch_major() >= 9) +def test_hc_prenorm_gemm() -> None: + # Needs TF32 precision for PyTorch GEMMs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + print('Testing hyperconnection prenorm GEMM:') + for m in (13, 137, 4096, 8192): + for n, k in [(24, 28672), (24, 7680), (24, 7168)]: + for num_splits in [None, 16]: + a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda') + b = torch.randn((n, k), dtype=torch.float, device='cuda') + d = torch.empty((m, n), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m, n), dtype=torch.float, device='cuda') + s = torch.empty((m, ), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m), dtype=torch.float, device='cuda') + deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits) + final_d = d if num_splits is None else d.sum(0) + final_s = s if num_splits is None else s.sum(0) + + ref_d = a.float() @ b.T + ref_s = a.float().square().sum(-1) + + diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s)) + assert diff < 1e-8, f'{m=}, {n=}, {k=}, {diff:.10f}' + + t = bench_kineto(lambda: deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits), 'tf32_hc_prenorm_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, num_splits={(num_splits or 0):2}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d, s) / 1e9 / t:4.0f} GB/s') + print() + + + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_hc_prenorm_gemm() diff --git a/third_party/DeepGEMM/tests/test_layout.py b/third_party/DeepGEMM/tests/test_layout.py new file mode 100644 index 00000000..a0d4a02e --- /dev/null +++ b/third_party/DeepGEMM/tests/test_layout.py @@ -0,0 +1,112 @@ +import torch +import random +from deep_gemm.testing import bench_kineto, count_bytes, get_arch_major +from deep_gemm.utils import ( + align, ceil_div, + per_token_cast_to_fp8, per_channel_cast_to_fp8, + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor +) + +from generators import ( + enumerate_sf_layout, + enumerate_k_grouped_sf_layout +) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(x: torch.Tensor) -> torch.Tensor: + assert x.dtype == torch.float and x.dim() in (2, 3) + + # First, convert into UE8M0 `uint8_t` + ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8) + + # Second, make padded packed tensors + mn, k = x.shape[-2], x.shape[-1] + remove_dim = False + if x.dim() == 2: + x, remove_dim = x.unsqueeze(0), True + b = x.shape[0] + aligned_mn = get_tma_aligned_size(mn, 4) + aligned_k = align(k, 4) + padded = torch.zeros((b, aligned_mn, aligned_k), device=x.device, dtype=torch.uint8) + padded[:, :mn, :k] = ue8m0_tensor + padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn, aligned_k // 4) + + # Finally, transpose + transposed = torch.zeros((b, aligned_k // 4, aligned_mn), device=x.device, dtype=torch.int).mT + transposed[:, :, :] = padded + aligned_x = transposed[:, :mn, :] + return aligned_x.squeeze(0) if remove_dim else aligned_x + + +def test_sf_layout_kernels() -> None: + print('Testing SF layout kernels:') + for mn, k, with_transpose, use_ue8m0, num_groups, gran_k in enumerate_sf_layout(): + x = torch.randn((num_groups * mn, k), dtype=torch.bfloat16, device='cuda') + x, fp32_sf = per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + fp32_sf = fp32_sf if num_groups == 1 else fp32_sf.view(num_groups, mn, -1) + fp32_sf = fp32_sf if with_transpose else fp32_sf.transpose(-1, -2).contiguous().transpose(-1, -2) + + # Correctness + if use_ue8m0: + impl, name = get_mn_major_tma_aligned_packed_ue8m0_tensor, 'pack_fp32_into_ue8m0' + packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf) + ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(fp32_sf) + assert torch.equal(packed_sf, ref_packed_sf), f'{mn=}, {k=}, {with_transpose=}, {num_groups=}' + assert packed_sf.shape == ref_packed_sf.shape + assert all([packed_sf.stride(i) == ref_packed_sf.stride(i) for i in range(packed_sf.dim())]) + else: + impl, name = get_mn_major_tma_aligned_tensor, 'transpose' + transposed_sf = get_mn_major_tma_aligned_tensor(fp32_sf) + tma_aligned_mn, sf_k = get_tma_aligned_size(mn, fp32_sf.element_size()), ceil_div(k, gran_k) + if num_groups > 1: + assert transposed_sf.size(0) == num_groups + assert transposed_sf.stride(0) == tma_aligned_mn * sf_k + assert transposed_sf.shape[-2:] == (mn, sf_k) + assert transposed_sf.stride()[-2:] == (1, tma_aligned_mn) + assert torch.equal(fp32_sf, transposed_sf) + + # Performance + try: + t = bench_kineto(lambda: impl(fp32_sf), name) + except AssertionError as e: + # Some cases may fallback to PyTorch impl + t = 0 + print(f' > Perf ({num_groups=:2}, {mn=:5}, {k=:5}, transpose={int(with_transpose)}, use_ue8m0={int(use_ue8m0)}, gran_k={gran_k:3}): ' + f'{t * 1e6:4.0f} us | {count_bytes(fp32_sf, impl(fp32_sf)) / 1e9 / t if t else 0:4.0f} GB/s') + print() + + +def test_k_grouped_sf_layout_kernels() -> None: + print('Testing k-grouped SF layout kernels:') + for mn, ks, num_groups, gran_k in enumerate_k_grouped_sf_layout(): + sf_ks = [k // gran_k for k in ks] + packed_sf_ks = [ceil_div(k, gran_k * 4) for k in ks] + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + x = torch.randn((sum(ks), mn), dtype=torch.bfloat16, device='cuda') + x, fp32_sf = per_channel_cast_to_fp8(x, use_ue8m0=True, gran_k=gran_k) + + # Correctness + packed_sf = get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks, gran_k) + split_packed_sf = packed_sf.split(packed_sf_ks) + split_fp32_sf = fp32_sf.split(sf_ks) + for i in range(num_groups): + ref_packed_sf = get_mn_major_tma_aligned_packed_ue8m0_tensor_torch_impl(split_fp32_sf[i].T).T + assert torch.equal(split_packed_sf[i], ref_packed_sf), f'{i=}' + + # Performance + t = bench_kineto(lambda: get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(fp32_sf, ks_tensor, ks, gran_k), 'pack_fp32_into_ue8m0') + print(f' > Perf ({num_groups=:3}, {mn=:5}, sum_k={sum(ks):5}, gran_k={gran_k:3}):' + f'{t * 1e6:4.0f} us | ' + f'{count_bytes(fp32_sf, packed_sf, ks_tensor) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(1) + random.seed(1) + + test_sf_layout_kernels() + test_k_grouped_sf_layout_kernels() diff --git a/third_party/DeepGEMM/tests/test_lazy_init.py b/third_party/DeepGEMM/tests/test_lazy_init.py new file mode 100644 index 00000000..17a3a121 --- /dev/null +++ b/third_party/DeepGEMM/tests/test_lazy_init.py @@ -0,0 +1,20 @@ +import argparse +import torch +import torch.multiprocessing as mp +import deep_gemm + + +def main(local_rank: int): + torch.cuda.set_device(local_rank) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Test lazy initialization') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + args = parser.parse_args() + + procs = [mp.Process(target=main, args=(i, ), ) for i in range(args.num_processes)] + for p in procs: + p.start() + for p in procs: + p.join() diff --git a/third_party/DeepGEMM/tests/test_legacy.py b/third_party/DeepGEMM/tests/test_legacy.py new file mode 100644 index 00000000..4456799f --- /dev/null +++ b/third_party/DeepGEMM/tests/test_legacy.py @@ -0,0 +1,90 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes +) +from generators import ( + enumerate_m_grouped_contiguous, enumerate_k_grouped_contiguous, + generate_m_grouped_contiguous, generate_k_grouped_contiguous, +) + +def test_m_grouped_gemm_contiguous_tl() -> None: + print('Testing m-grouped contiguous Triton GEMM:') + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, _ in enumerate_m_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for expand in (False, True): + for test_alias in (False, True): + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + func_name = f"{'a_fused_' if expand else ''}m_grouped_bf16_gemm_{major_opt.lower() if test_alias else 'nt'}_contiguous_tl" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else b.mT + assert a[0].is_contiguous() and b[0].is_contiguous() + if expand: + m_row_indices = torch.arange(0, m, dtype=torch.int32, device='cuda') + getattr(deep_gemm.legacy, func_name)(a, b, d, (m_indices, m_row_indices)) + else: + getattr(deep_gemm.legacy, func_name)(a, b, d, m_indices) + d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' + m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.legacy.m_grouped_bf16_gemm_nt_contiguous_tl(a, b, d, m_indices) + + t = bench_kineto(test_func, 'm_grouped_bf16_gemm_contiguous_tl_impl', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_k_grouped_gemm_contiguous_tl() -> None: + print('Testing k-grouped contiguous Triton GEMM:') + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.bfloat16): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + + for fused_operand in ('a', 'b'): + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=False, use_bf16=True) + func_name = f"{fused_operand}_fused_k_grouped_bf16_gemm_{major_opt.lower()}_contiguous_tl" + k_indices = torch.arange(0, k, dtype=torch.int32, device='cuda') + k_start = torch.empty(len(ks), dtype=torch.int32, device='cuda') + k_end = torch.empty(len(ks), dtype=torch.int32, device='cuda') + for i, group_k in enumerate(ks): + k_start[i] = k_end[i-1] if i > 0 else 0 + k_end[i] = k_start[i] + group_k + getattr(deep_gemm.legacy, func_name)(a, b, c, (k_indices, k_start, k_end), True) + diff = calc_diff(c, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}' + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=False, use_bf16=True) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.legacy.b_fused_k_grouped_bf16_gemm_tn_contiguous_tl(a, b, c, (k_indices, k_start, k_end), True) + + t = bench_kineto(test_func, 'b_fused_k_grouped_bf16_gemm_contiguous_tl_impl', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_m_grouped_gemm_contiguous_tl() + test_k_grouped_gemm_contiguous_tl() diff --git a/third_party/DeepGEMM/tests/test_mega_moe.py b/third_party/DeepGEMM/tests/test_mega_moe.py new file mode 100644 index 00000000..83e8d622 --- /dev/null +++ b/third_party/DeepGEMM/tests/test_mega_moe.py @@ -0,0 +1,295 @@ +import argparse +import os +import random +import sys +import torch +import torch.distributed as dist +from typing import Tuple + +import deep_gemm +from deep_gemm.utils import per_token_cast_to_fp4, per_token_cast_to_fp8 +from deep_gemm.utils.dist import dist_print, init_dist, uneven_all_gather +from deep_gemm.testing import bench_kineto + + +def import_baseline(): + # Load legacy implements from third-party + deep_ep, tilelang_ops, do_bench, is_legacy_loaded = None, None, None, False + # noinspection PyBroadException + try: + import deep_ep + import importlib.util + from tilelang.profiler.bench import do_bench + spec = importlib.util.spec_from_file_location( + 'tilelang_ops', + os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'third-party', 'tilelang_ops', '__init__.py')) + tilelang_ops = importlib.util.module_from_spec(spec) + sys.modules['tilelang_ops'] = tilelang_ops + spec.loader.exec_module(tilelang_ops) + is_legacy_loaded = True + except Exception as ex: + dist_print(f'Failed to load legacy code: {ex}, skip baseline benchmarking', once_in_node=True) + dist_print(once_in_node=True) + return deep_ep, tilelang_ops, do_bench, is_legacy_loaded + + +# TODO: skip the test for SM90 +# noinspection PyUnboundLocalVariable,PyShadowingNames +def test(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + rank_idx, num_ranks, group = init_dist(local_rank, num_local_ranks) + torch.manual_seed(rank_idx) + random.seed(rank_idx) + + # Settings + num_max_tokens_per_rank = args.num_max_tokens_per_rank + num_tokens = max(0, args.num_max_tokens_per_rank - random.randint(0, args.num_max_removed_tokens)) \ + if args.num_tokens == 0 else args.num_tokens + hidden, intermediate_hidden = args.hidden, args.intermediate_hidden + num_experts, num_topk = args.num_experts, args.num_topk + num_experts_per_rank = num_experts // num_ranks + assert num_tokens <= num_max_tokens_per_rank + + # Allocate symmetric memory + buffer = deep_gemm.get_symm_buffer_for_mega_moe( + group, num_experts, + num_max_tokens_per_rank, num_topk, + hidden, intermediate_hidden + ) + + # Create inputs + # noinspection PyGlobalUndefined + def create_inputs(): + global x, topk_idx, topk_weights, l1_weights, l2_weights, transformed_l1_weights, transformed_l2_weights + global cumulative_local_expert_recv_stats_fused + global cumulative_local_expert_recv_stats_baseline + x = torch.randn((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + l1_weights = torch.randn( + (num_experts_per_rank, intermediate_hidden * 2, hidden), dtype=torch.bfloat16, device='cuda') + l2_weights = torch.randn( + (num_experts_per_rank, hidden, intermediate_hidden), dtype=torch.bfloat16, device='cuda') + scores = torch.randn((num_tokens, num_experts), dtype=torch.float, device='cuda') + topk_weights, topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False) + cumulative_local_expert_recv_stats_fused = torch.randint( + 0, 100, (num_experts_per_rank, ), dtype=torch.int, device='cuda') + cumulative_local_expert_recv_stats_baseline = cumulative_local_expert_recv_stats_fused.clone() + if args.masked_ratio > 0: + rand_mask = torch.rand_like(topk_idx, dtype=torch.float) + topk_idx.masked_fill_(rand_mask < args.masked_ratio, -1) + topk_weights.masked_fill_(topk_idx < 0, 0) + + # Check SF requirements + assert hidden % 128 == 0 + assert intermediate_hidden % 128 == 0 + assert l1_weights.shape[2] % 128 == 0 and l2_weights.shape[2] % 128 == 0 + + # Cast inputs to FP8 with per-32 UE8M0 SF + x = per_token_cast_to_fp8(x, use_ue8m0=True, gran_k=32, use_packed_ue8m0=True) + + # Cast grouped BF16 weights to FP4 with MN-major SF + # TODO: merge with `cast_fp8_fp4_with_major` + def cast_grouped_weights_to_fp4(bf16_weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + num_groups, n, k = bf16_weights.shape + w = torch.empty((num_groups, n, k // 2), device='cuda', dtype=torch.int8) + w_sf = torch.empty((num_groups, n, k // 32), device='cuda', dtype=torch.float) + for i in range(num_groups): + w[i], w_sf[i] = per_token_cast_to_fp4(bf16_weights[i], use_ue8m0=True, gran_k=32) + w_sf = deep_gemm.transform_sf_into_required_layout(w_sf, n, k, (1, 32), num_groups) + return w, w_sf + + l1_weights = cast_grouped_weights_to_fp4(l1_weights) + l2_weights = cast_grouped_weights_to_fp4(l2_weights) + transformed_l1_weights, transformed_l2_weights = deep_gemm.transform_weights_for_mega_moe(l1_weights, l2_weights) + + # Run fused mega MoE + # NOTES: copy x into buffer before each call because debug mode zeros the entire buffer + def run_fused(): + buffer.x[:num_tokens].copy_(x[0]) + buffer.x_sf[:num_tokens].copy_(x[1]) + buffer.topk_idx[:num_tokens].copy_(topk_idx) + buffer.topk_weights[:num_tokens].copy_(topk_weights) + + y = torch.empty((num_tokens, hidden), dtype=torch.bfloat16, device='cuda') + # noinspection PyTypeChecker + deep_gemm.fp8_fp4_mega_moe( + y, + transformed_l1_weights, transformed_l2_weights, + buffer, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_fused, + activation_clamp=args.activation_clamp, + fast_math=bool(args.fast_math) + ) + return y, cumulative_local_expert_recv_stats_fused + + dist_print('Config:', once_in_node=True) + dist_print(f' > Tokens: {num_tokens}/{num_max_tokens_per_rank}', once_in_node=True) + dist_print(f' > Hidden: {hidden}', once_in_node=True) + dist_print(f' > Intermediate: {intermediate_hidden}', once_in_node=True) + dist_print(f' > Experts: {num_topk}/{num_experts}', once_in_node=True) + dist_print(f' > Buffer: {buffer.buffer.nbytes / 2 ** 30:.3f} GiB', once_in_node=True) + dist_print(once_in_node=True) + + # Only do NCU profiling + if args.ncu_profile_only: + create_inputs() + dist_print(f'Run fused kernel:', once_in_node=True) + run_fused() + dist_print(f' > Done, exiting', once_in_node=True) + + # Destroy and exit + dist.barrier() + buffer.destroy() + dist.destroy_process_group() + return + + # Non-overlapped baseline: EP dispatch + GEMM + EP combine + deep_ep, tilelang_ops, tilelang_bench, is_legacy_loaded = import_baseline() + alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout() + deep_gemm.set_mk_alignment_for_contiguous_layout(alignment) + ep_buffer = deep_ep.ElasticBuffer( + group, + num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden, + num_topk=num_topk, use_fp8_dispatch=True, + explicitly_destroy=True, + allow_multiple_reduction=False, + num_gpu_timeout_secs=10, num_cpu_timeout_secs=30 + ) if is_legacy_loaded else None + + def run_baseline(): + recv_x, _, recv_topk_weights, handle, _ = ep_buffer.dispatch( + x, topk_idx=topk_idx, topk_weights=topk_weights, + cumulative_local_expert_recv_stats=cumulative_local_expert_recv_stats_baseline, + num_experts=num_experts, expert_alignment=alignment, + do_cpu_sync=False, do_handle_copy=False, + do_expand=True, use_tma_aligned_col_major_sf=True, + ) + n = recv_x[0].size(0) + l1_y = torch.empty((n, intermediate_hidden * 2), dtype=torch.bfloat16, device='cuda') + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + recv_x, l1_weights, l1_y, handle.psum_num_recv_tokens_per_expert, + use_psum_layout=True, recipe=(1, 1, 32)) + # noinspection PyCallingNonCallable + l1_y = tilelang_ops.swiglu_apply_weight_to_fp8( + x=l1_y, + topk_weights=recv_topk_weights, + avail_tokens=handle.psum_num_recv_tokens_per_expert[-1], + num_per_channels=32, + use_col_major_scales=True, + round_scale=True, + ue8m0_scale=True, + output_bf16=False, + clamp_value=args.activation_clamp, + fast_math=bool(args.fast_math) + ) + l2_y = torch.empty((n, hidden), dtype=torch.bfloat16, device='cuda') + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous( + l1_y, l2_weights, l2_y, handle.psum_num_recv_tokens_per_expert, + use_psum_layout=True, recipe=(1, 1, 32)) + return ep_buffer.combine(l2_y, handle=handle)[0], cumulative_local_expert_recv_stats_baseline + + # Check correctness (must be bitwise identical) + num_correctness_tests = 1 if args.num_correctness_tests is None else args.num_correctness_tests + # noinspection PyBroadException + if is_legacy_loaded and num_correctness_tests > 0: + dist_print('Running correctness tests:', once_in_node=True) + for i in range(num_correctness_tests): + create_inputs() + for fused_result, baseline_result in zip(run_fused(), run_baseline()): + assert torch.equal(fused_result, baseline_result) + if (i + 1) % 100 == 0 or i == num_correctness_tests - 1: + dist_print(f' > Correctness test #{i + 1}/{num_correctness_tests} passed', once_in_node=True) + dist_print(once_in_node=True) + else: + create_inputs() + + # Count local received tokens + gathered_topk_idx = uneven_all_gather(topk_idx, group=group) + gathered_topk_idx[(gathered_topk_idx < rank_idx * num_experts_per_rank) | \ + (gathered_topk_idx >= (rank_idx + 1) * num_experts_per_rank)] = -1 + num_recv_tokens = (gathered_topk_idx != -1).sum().item() + + # Benchmark + t_fused = bench_kineto( + run_fused, 'mega_moe', + barrier=lambda: ep_buffer.barrier(use_comm_stream=False) if ep_buffer else dist.barrier(), + trace_path=None if not args.dump_profile_traces else f'{args.dump_profile_traces}/mega_moe_rank{rank_idx}.json') + t_baseline = tilelang_bench(run_baseline, _n_warmup=5, _n_repeat=1, backend='cudagraph', return_mode='median') / 1e3 if is_legacy_loaded else 0 + + # TFLOPS: 3 matmuls (L1 left, L1 right, L2), each 2 * M * N * K + safe_div = lambda a, b: float('nan') if b == 0 else a / b + tflops = safe_div(2 * num_recv_tokens * (hidden * intermediate_hidden * 3) / 1e12, t_fused) + + # HBM bytes: weights (FP4 packed = 0.5 bytes) + activations (FP8 = 1 byte) + output (BF16 = 2 bytes) + num_touched_experts = torch.unique(gathered_topk_idx.flatten()).numel() - 1 # NOTES minus 1 to exclude "-1" + num_hbm_bytes = ( + num_touched_experts * intermediate_hidden * 2 * hidden // 2 + # L1 weights (FP4) + num_touched_experts * hidden * intermediate_hidden // 2 + # L2 weights (FP4) + num_recv_tokens * hidden + # L1 acts read (FP8) + num_recv_tokens * intermediate_hidden + # L1 output write (FP8) + num_recv_tokens * intermediate_hidden + # L2 acts read (FP8) + num_recv_tokens * hidden * 2 # L2 output write (BF16) + ) + hbm_gbs = safe_div(num_hbm_bytes / 1e9, t_fused) + + # NVLink bytes: dispatch pull + combine write-back + num_nvlink_bytes = num_recv_tokens * hidden * 3 + nvlink_gbs = safe_div(num_nvlink_bytes / 1e9, t_fused) + + # Combine reduction (serial) time approximation + t_reduction = num_tokens * hidden * 2 * (1 + num_topk) / 6.5e12 + + # Summary + approx_factor = t_fused / (t_fused - t_reduction) + dist_print('Performance:', once_in_node=True) + dist_print(f' > EP: {rank_idx:2}/{num_ranks} | ' + f'{tflops:4.0f} TFLOPS | ' + f'overlap: ' + f'{tflops * approx_factor:4.0f} TFLOPS, ' + f'HBM {hbm_gbs * approx_factor:4.0f} GB/s, ' + f'NVL {nvlink_gbs * approx_factor:3.0f} GB/s | ' + f'{t_fused * 1e6:4.0f} us, ' + f'reduction: {t_reduction * 1e6:4.1f} us | ' + f'{safe_div(t_baseline, t_fused):.2f}x legacy') + + # Exit + dist.barrier() + buffer.destroy() + ep_buffer.destroy() if is_legacy_loaded else None + dist.destroy_process_group() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Test PyTorch symmetric memory') + + # Resource settings + parser.add_argument('--ncu-profile-only', action='store_true', help='Only run profiling without correctness test') + parser.add_argument('--num-processes', type=int, default=8, help='Number of processes to spawn (default: 8)') + + # Model settings + parser.add_argument('--num-max-tokens-per-rank', type=int, default=8192, help='Number of maximum tokens per rank') + parser.add_argument('--num-tokens', type=int, default=0, help='Number of tokens per rank (follow max minus removed if 0)') + parser.add_argument('--num-max-removed-tokens', type=int, default=0, help='Maximum number of tokens to remove') + parser.add_argument('--hidden', type=int, default=7168, help='Hidden size') + parser.add_argument('--intermediate-hidden', type=int, default=3072, help='Intermediate hidden size') + parser.add_argument('--activation-clamp', type=float, default=10, help='Clamp value for activation') + parser.add_argument('--num-experts', type=int, default=384, help='Number of experts') + parser.add_argument('--num-topk', type=int, default=6, help='Number of expert selections') + parser.add_argument('--masked-ratio', type=float, default=0.0, help='Mask some expert selections') + parser.add_argument('--fast-math', type=int, default=1, help='Enable fast math (0 or 1, default: 1)') + + # Test settings + parser.add_argument('--num-correctness-tests', type=int, default=None, help='Pressure test') + parser.add_argument('--dump-profile-traces', type=str, default='', help='Dump profiling trace JSONs') + parser.add_argument('--local-rank-idx', type=int, default=None, help='Run as single process with this local rank (e.g. for NCU prof)') + args = parser.parse_args() + + # Create dump trace directories + if args.dump_profile_traces: + os.makedirs(args.dump_profile_traces, exist_ok=True) + + if args.local_rank_idx is not None: + # Single-process mode: each process is launched separately (e.g. by NCU) + test(args.local_rank_idx, args.num_processes, args) + else: + # Launch tests + num_processes = args.num_processes + torch.multiprocessing.spawn(test, args=(num_processes, args), nprocs=num_processes) diff --git a/third_party/DeepGEMM/tests/test_sanitizer.py b/third_party/DeepGEMM/tests/test_sanitizer.py new file mode 100644 index 00000000..75ab10e6 --- /dev/null +++ b/third_party/DeepGEMM/tests/test_sanitizer.py @@ -0,0 +1,79 @@ +import argparse +import importlib +import inspect +import os +import subprocess +import sys + +import deep_gemm + + +# Single test template +script_dir = os.path.dirname(os.path.abspath(__file__)) +test_template = """ +import random +import sys +import torch + +# Necessary for `generators.py` +sys.path.append('{script_dir}') + +torch.manual_seed(0) +random.seed(0) + +from {module_name} import {func_name} +{func_name}() +""" + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--funcs', type=str, default='all') + parser.add_argument('--tools', type=str, default='memcheck,synccheck') + args = parser.parse_args() + + if args.funcs != 'all': + funcs = [] + for name in [x.strip() for x in args.funcs.split(',')]: + module_name, func_name = name.split('.') + funcs.append((module_name, func_name)) + else: + # Get all test functions except those related to cuBLAS + files = [f for f in os.listdir(script_dir) if f.endswith('.py')] + exclude_files = ['test_sanitizer.py', 'generators.py', 'test_mega_moe.py'] + funcs = [ + (module_name, name) + for module_name in [os.path.splitext(f)[0] for f in files if f not in exclude_files] + for name, obj in inspect.getmembers(importlib.import_module(module_name)) + if inspect.isfunction(obj) and name.startswith('test') and 'test_filter' not in name + ] + tools = [x.strip() for x in args.tools.split(',')] + + env = os.environ.copy() + env['CUDA_LAUNCH_BLOCKING'] = '1' + env['DG_JIT_PTXAS_CHECK'] = '1' + env['DG_USE_NVIDIA_TOOLS'] = '1' + env['DG_USE_TEMP_CUBLASLT_WORKSPACE'] = '1' # Avoid holding CUDA tensor that crashes during shutdown + env['PYTORCH_NO_CUDA_MEMORY_CACHING'] = '1' + env['TORCH_SHOW_CPP_STACKTRACES'] = '1' + + print(f'Library path: {deep_gemm.__path__}') + for module_name, func_name in funcs: + for tool in tools: + cmd = [ + '/usr/local/cuda/bin/compute-sanitizer', + f'--tool={tool}', + '--target-processes=application-only', + '--destroy-on-device-error=context', + '--force-blocking-launches', + '--check-api-memory-access=no', + '--kernel-name-exclude', 'kns=nvjet', + 'python', + '-c', + test_template.format(module_name=module_name, func_name=func_name, script_dir=script_dir) + ] + print(f'\n{"=" * 60}') + print(f'Running {module_name}.{func_name} with compute-sanitizer {tool}') + result = subprocess.run(cmd, env=env) + if result.returncode != 0: + sys.exit(result.returncode) diff --git a/third_party/DeepGEMM/third-party/tilelang_ops/__init__.py b/third_party/DeepGEMM/third-party/tilelang_ops/__init__.py new file mode 100644 index 00000000..47f9665e --- /dev/null +++ b/third_party/DeepGEMM/third-party/tilelang_ops/__init__.py @@ -0,0 +1 @@ +from .swiglu_apply_weight_to_fp8 import swiglu_apply_weight_to_fp8 diff --git a/third_party/DeepGEMM/third-party/tilelang_ops/swiglu_apply_weight_to_fp8.py b/third_party/DeepGEMM/third-party/tilelang_ops/swiglu_apply_weight_to_fp8.py new file mode 100644 index 00000000..7dc1dfc3 --- /dev/null +++ b/third_party/DeepGEMM/third-party/tilelang_ops/swiglu_apply_weight_to_fp8.py @@ -0,0 +1,212 @@ +import deep_gemm +import tilelang +import torch +from math import gcd +from tilelang import language as T + +from .utils import get_sf_and_inv, get_sf_shape + + +@tilelang.jit +def _swiglu_apply_weight_to_fp8_tl( + half_hidden: int, + num_per_channels: int, + use_col_major_scales: bool, + round_scale: bool, + ue8m0_scale: bool, + num_ctas: int, + has_topk_weights: bool, + has_avail_tokens: bool, + has_clamp_value: bool, + output_bf16: bool, + fast_math: bool, +) -> None: + in_dtype = T.bfloat16 + w_dtype = T.float32 + out_dtype = T.float8_e4m3fn + out_sf_dtype = T.uint8 if ue8m0_scale else T.float32 + num_dtype = T.int32 + num_tokens = T.dynamic("num_tokens") + assert half_hidden % 16 == 0 + if not output_bf16: + assert half_hidden % num_per_channels == 0 + + num_block_h = max(half_hidden // gcd(half_hidden, 16 * 1024), 1) + assert num_block_h <= num_ctas, "not supported hidden size" + blk_h = half_hidden // num_block_h + + layout_h = blk_h // 16 + + blk_n = 1024 // layout_h + + def local_layout(i, j): # noqa: ANN001 + thread_id = i * layout_h + j // 16 + local_id = j % 16 + return thread_id, local_id + + def local_layout_3d(i, j, k): # noqa: ANN001 + return local_layout(i, j * num_per_channels + k) + + @T.macro + def main( + bi: int, + bh: int, + num_ctas: int, + x: T.Tensor[(num_tokens, half_hidden * 2), in_dtype], # type: ignore + topk_weights: T.Tensor[num_tokens, w_dtype], # type: ignore + avail_tokens: T.Tensor[1, num_dtype], # type: ignore + out: T.Tensor[(num_tokens, half_hidden), out_dtype], # type: ignore + out_sf: T.Tensor[get_sf_shape(num_tokens, half_hidden, num_per_channels, ue8m0_scale, use_col_major_scales), out_sf_dtype], # type: ignore + out_bf16: T.Tensor[(num_tokens, half_hidden), T.bfloat16], # type: ignore + clamp_value: T.float32, + ): + gate_frag = T.alloc_fragment((blk_n, blk_h), T.float32) + up_frag = T.alloc_fragment((blk_n, blk_h), T.float32) + y_frag = T.alloc_fragment((blk_n, blk_h // num_per_channels, num_per_channels), T.float32) + y_f8_frag = T.alloc_fragment((blk_n, blk_h), out_dtype) + T.annotate_layout( + { + gate_frag: T.Fragment(gate_frag.shape, forward_fn=local_layout), + up_frag: T.Fragment(up_frag.shape, forward_fn=local_layout), + y_frag: T.Fragment(y_frag.shape, forward_fn=local_layout_3d), + y_f8_frag: T.Fragment(y_f8_frag.shape, forward_fn=local_layout), + } + ) + + T.assume(0 <= bh * blk_h + blk_h <= half_hidden) + + for i, j in T.Parallel(blk_n, blk_h): + gate_frag[i, j] = x[bi + i * num_ctas, bh * blk_h + j] + for i, j in T.Parallel(blk_n, blk_h): + up_frag[i, j] = x[bi + i * num_ctas, half_hidden + bh * blk_h + j] + + topk_weight = T.alloc_fragment((blk_n,), T.float32) + for i in T.Parallel(blk_n): + topk_weight[i] = topk_weights[bi + i * num_ctas] if has_topk_weights else 1.0 + + zero = T.alloc_var(T.float32, 0.0) + for i, j in T.Parallel(blk_n, blk_h): + if has_clamp_value: + up_frag[i, j] = T.min(clamp_value, T.max(-clamp_value, up_frag[i, j])) + gate_frag[i, j] = T.min(clamp_value, gate_frag[i, j]) + y_frag[i, j // num_per_channels, j % num_per_channels] = ( + gate_frag[i, j] / (1 + T.exp(-gate_frag[i, j])) * up_frag[i, j] * topk_weight[i] + zero + ) # HACK : + 0 for vectorize + + y_max_frag = T.alloc_fragment((blk_n, blk_h // num_per_channels), T.float32) + sf_inv_frag = T.alloc_fragment((blk_n, blk_h // num_per_channels), T.float32) + T.reduce_absmax(T.reshape(y_frag, (blk_n, blk_h // num_per_channels, num_per_channels)), y_max_frag) + for i, j in T.Parallel(blk_n, blk_h // num_per_channels): + clamped_amax = T.max(y_max_frag[i, j], 1e-4) + sf, sf_inv = get_sf_and_inv(clamped_amax, round_scale, ue8m0_scale) + i_index = bi + i * num_ctas + j_index = blk_h // num_per_channels * bh + j + # Store SF + if ue8m0_scale: + out_sf[j_index // 4, i_index * 4 + j_index % 4] = sf + elif use_col_major_scales: + out_sf[j_index, i_index] = sf + else: + out_sf[i_index, j_index] = sf + sf_inv_frag[i, j] = sf_inv + + for i, j in T.Parallel(blk_n, blk_h): + y_f8_frag[i, j] = y_frag[i, j // num_per_channels, j % num_per_channels] * sf_inv_frag[i, j // num_per_channels] + + for i, j in T.Parallel(blk_n, blk_h): + out[bi + i * num_ctas, blk_h * bh + j] = y_f8_frag[i, j] + + if output_bf16: + for i, j in T.Parallel(blk_n, blk_h): + out_bf16[bi + i * num_ctas, blk_h * bh + j] = y_frag[i, j // num_per_channels, j % num_per_channels] + + @T.prim_func + def _swiglu_apply_weight_to_fp8( + x: T.Tensor[(num_tokens, half_hidden * 2), in_dtype], # type: ignore + topk_weights: T.Tensor[num_tokens, w_dtype], # type: ignore + avail_tokens: T.Tensor[1, num_dtype], # type: ignore + out: T.Tensor[(num_tokens, half_hidden), out_dtype], # type: ignore + out_sf: T.Tensor[get_sf_shape(num_tokens, half_hidden, num_per_channels, ue8m0_scale, use_col_major_scales), out_sf_dtype], # type: ignore + out_bf16: T.Tensor[(num_tokens, half_hidden), T.bfloat16], # type: ignore + clamp_value: T.float32, + ): + # we actually don't use this + _ = num_tokens + # simplest schedule: one token one block, but pipelined as persistent + with T.Kernel(num_ctas, threads=1024) as cta_id: + avail_tokens_l = avail_tokens[0] if has_avail_tokens else num_tokens + T.pdl_sync() # avail_tokens must be const. + T.assume(0 <= avail_tokens_l <= num_tokens) + thread_idx = T.get_thread_binding() + new_num_ctas = num_ctas // num_block_h + if cta_id >= new_num_ctas * num_block_h: + T.thread_return() + for bi in T.serial(cta_id // num_block_h, avail_tokens_l - thread_idx // layout_h * new_num_ctas, new_num_ctas * blk_n): + main(bi, cta_id % num_block_h, new_num_ctas, x, topk_weights, avail_tokens, out, out_sf, out_bf16, clamp_value) + + return _swiglu_apply_weight_to_fp8 + + +def swiglu_apply_weight_to_fp8( + x: torch.Tensor, + topk_weights: torch.Tensor | None, + avail_tokens: torch.Tensor | None, + num_per_channels: int, + use_col_major_scales: bool, + round_scale: bool, + ue8m0_scale: bool, + clamp_value: float | None = None, + fmt: str = "e4m3", + num_sms: int | None = None, + output_bf16: bool = False, + fast_math: bool = True, +) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + assert fmt == "e4m3" + if num_sms is None: + num_sms = deep_gemm.get_num_sms() + + num_tokens, hidden_size = x.shape + assert hidden_size % (2 * num_per_channels) == 0 + + y = torch.empty( + (num_tokens, hidden_size // 2), + device=x.device, + dtype=torch.float8_e4m3fn, + ) + y_sf = torch.empty( + get_sf_shape(num_tokens, hidden_size // 2, num_per_channels, ue8m0_scale, use_col_major_scales), + device=x.device, + dtype=(torch.uint8 if ue8m0_scale else torch.float32), + ) + + y_bf16 = torch.empty((num_tokens, hidden_size // 2), device=x.device, dtype=torch.bfloat16) if output_bf16 else None + + if num_tokens > 0: + _swiglu_apply_weight_to_fp8_tl.pass_configs = { + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: fast_math, + } + kernel = _swiglu_apply_weight_to_fp8_tl( + hidden_size // 2, + num_per_channels, + use_col_major_scales, + round_scale, + ue8m0_scale, + num_sms, + topk_weights is not None, + avail_tokens is not None, + clamp_value is not None, + output_bf16, + fast_math + ) + kernel(x, topk_weights, avail_tokens.view(1) if avail_tokens is not None else None, y, y_sf, y_bf16, clamp_value or 0.0) + + if ue8m0_scale: + if num_tokens == 0: + y_sf.as_strided_(y_sf.size(), (0, 1)) + y_sf = y_sf.view(dtype=torch.int32) + if output_bf16: + return y, y_sf.T[:num_tokens], y_bf16 + else: + return y, y_sf.T[:num_tokens] diff --git a/third_party/DeepGEMM/third-party/tilelang_ops/utils.py b/third_party/DeepGEMM/third-party/tilelang_ops/utils.py new file mode 100644 index 00000000..3f3ca46d --- /dev/null +++ b/third_party/DeepGEMM/third-party/tilelang_ops/utils.py @@ -0,0 +1,47 @@ +from typing import Any +from tilelang import language as T + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def get_sf_shape( + num_tokens: int, + hidden: int, + num_per_channels: int, + use_ue8m0: bool, + use_col_major_sf: bool, +) -> tuple[int, int]: + num_scales = ceil_div(hidden, num_per_channels) + num_scales = ceil_div(num_scales, 4) if use_ue8m0 else num_scales + + # For col-major SF, TMA must be aligned into 16 bytes + # For UE8M0, we must use col-major SF, and 4 UE8M0 are expanded into the inner dim (token) + num_sf_tokens = num_tokens + if use_col_major_sf: + num_sf_tokens = align(num_tokens, 4) + num_sf_tokens = num_sf_tokens * 4 if use_ue8m0 else num_sf_tokens + + return (num_scales, num_sf_tokens) if use_col_major_sf else (num_sf_tokens, num_scales) + + +def get_sf_and_inv(amax: float, round_sf: bool, use_ue8m0: bool) -> tuple[Any, Any]: + sf = amax / 448.0 + if not round_sf: + return sf, 448.0 / amax + + # Round into 2's power + bits = T.reinterpret("uint32", sf) + exp = (bits >> 23) & 0xFF + man_bits = bits & ((1 << 23) - 1) + exp_scale = T.reinterpret("int32", exp - 127 + (man_bits != 0)) + if use_ue8m0: # noqa: SIM108 + sf = T.Cast("uint8", exp_scale + 127) + else: + sf = T.reinterpret("float", (127 + exp_scale) << 23) + return sf, T.reinterpret("float", (127 - exp_scale) << 23) \ No newline at end of file