From 0ed3b949d0427265e5e832a9f5e7f7d15f5e3eda Mon Sep 17 00:00:00 2001 From: Chenggang Zhao Date: Mon, 29 Sep 2025 17:10:12 +0800 Subject: [PATCH] Update README --- README.md | 55 +++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index a80558d..0c20869 100644 --- a/README.md +++ b/README.md @@ -8,11 +8,13 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert ## News +- 2025.09.28: DeepGEMM now supports scoring kernels (weighted ReLU MQA logits) for the lightning indexer for DeepSeek v3.2. + - Please see [#200](https://github.com/deepseek-ai/DeepGEMM/pull/200) for more details. - 2025.07.20: DeepGEMM now supports both SM90/SM100, and has a full refactor with a low-CPU-overhead JIT CPP module. - - NVRTC and post-compilation SASS optimization are all disabled - - NVRTC will be supported later - - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported - - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details + - NVRTC and post-compilation SASS optimization are all disabled. + - NVRTC will be supported later. + - As NVCC 12.9 will automatically do the FFMA interleaving, all post optimizations will be no longer supported. + - Please see [#112](https://github.com/deepseek-ai/DeepGEMM/pull/112) for more details. - 2025.05.14: DeepGEMM now offers weight gradient kernels for dense and MoE backward! See [#95](https://github.com/deepseek-ai/DeepGEMM/pull/95) for details. - 2025.05.07: DeepGEMM now supports NVRTC with up to 10x compilation speedup! See [#94](https://github.com/deepseek-ai/DeepGEMM/pull/94) for details. Please use `DG_JIT_USE_NVRTC=1` to enable it (may have performance loss with some cases). - 2025.04.18: DeepGEMM now achieves up to **1550 TFLOPS** on H800! See [#74](https://github.com/deepseek-ai/DeepGEMM/pull/74), [#78](https://github.com/deepseek-ai/DeepGEMM/pull/78), [#81](https://github.com/deepseek-ai/DeepGEMM/pull/81), [#86](https://github.com/deepseek-ai/DeepGEMM/pull/86) and [340d988](https://github.com/deepseek-ai/DeepGEMM/commit/340d9880f4a418d943d34260d20a79f41f4c0526) for details. @@ -46,9 +48,9 @@ Despite its lightweight design, DeepGEMM's performance matches or exceeds expert - Python 3.8 or higher - Compilers with C++20 support - CUDA Toolkit: - - CUDA 12.3 or higher for SM90 - - **We highly recommend 12.9 or higher for the best performance** - - CUDA 12.9 or higher for SM100 + - CUDA 12.3 or higher for SM90 + - **We highly recommend 12.9 or higher for the best performance** + - CUDA 12.9 or higher for SM100 - PyTorch 2.1 or higher - CUTLASS 4.0 or higher (could be cloned by Git submodule) - `{fmt}` library (could be cloned by Git submodule) @@ -66,6 +68,7 @@ cat develop.sh # Test all GEMM implements python tests/test_layout.py +python tests/test_attention.py python tests/test_bf16.py python tests/test_fp8.py python tests/test_lazy_init.py @@ -109,6 +112,30 @@ During the inference decoding phase, when CUDA graph is enabled and the CPU is u Use `m_grouped_fp8_gemm_nt_masked` for this purpose and consult the relevant documentation. An example usage is to use the output of low-latency kernels from [DeepEP](https://github.com/deepseek-ai/DeepEP) as input. +#### V3.2 MQA kernels for the indexer + +The kernel family has two versions, non-paged (for prefilling) and paged (for decoding). +Take the non-paged version `fp8_mqa_logits` as an example. It has 6 inputs: + +- `q`, E4M3 tensor with shape `[seq_len, num_heads, head_dim]` +- `kv`, E4M3 tensor (shaped as `[seq_len_kv, head_dim]`) with float SF (shaped as `[seq_len_kv]`) +- `weights`, float tensor with shape `[seq_len, num_heads]` +- `cu_seq_len_k_start` and `cu_seq_len_k_end`, int tensor with shape `[seq_len]` +- `clean_logits`, whether to clean the unfilled logits into `-inf` + +The output tensor is shaped as `[seq_len, seq_len_kv]`, indicating token-to-token logits. +For each token `i` in `q`, it will iterate all tokens `j` from `[cu_seq_len_k_start[i], cu_seq_len_k_end[i])`, +and calculate the logit `out[i, j]` as: + +```python +kv_j = kv[0][j, :] * kv[1][j].unsqueeze(1) # [head_dim] +out_ij = q[i, :, :] @ kv_j # [num_heads] +out_ij = out_ij.relu() * weights[i, :] # [num_heads] +out_ij = out_ij.sum() # Scalar +``` + +For more details and the paged version `fp8_paged_mqa_logits`, please refer to `tests/test_attention.py`. + #### Utilities The library provides some utility functions besides the above kernels: @@ -127,17 +154,17 @@ The library provides some utility functions besides the above kernels: The library also provides some environment variables, which may be useful: - General - - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default + - `DG_JIT_DEBUG`: `0` or `1`, print more JIT debugging information, `0` by default - JIT cache related - - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default + - `DG_JIT_CACHE_DIR`: string, the cache directory to store compiled kernels, `$HOME/.deep_gemm` by default - NVCC/NVRTC selections - - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default - - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default + - `DG_JIT_USE_NVRTC`: `0` or `1`, use NVRTC instead of NVCC, faster compilation but maybe have lower performance for some cases, `0` by default + - `DG_JIT_NVCC_COMPILER`: string, specified NVCC compiler path; will find in `torch.utils.cpp_extension.CUDA_HOME` by default - Compiler options - - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default - - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default + - `DG_JIT_PTXAS_VERBOSE`: `0` or `1`, show detailed PTXAS compiler output, `0` by default + - `DG_JIT_PRINT_COMPILER_COMMAND`: `0` or `1`, print NVCC compilation command, `0` by default - Heuristic selection - - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default + - `DG_PRINT_CONFIGS`: `0` or `1`, print selected configs for each shape, `0` by default For additional examples and details, please refer to [the test code](tests/test_core.py) or review the corresponding Python documentation.