Commit Graph

70 Commits

Author SHA1 Message Date
a554de8b24 fix: dispatch TMA byte counts for FP4 (kHidden/2), rename fp8→fp4 layout refs 2026-05-11 20:47:58 +00:00
b3d1aae038 feat: full FP4 activations for mxf4nvf4 - E2M1 packed A side + UE4M3 scales
mxf4nvf4 requires BOTH A and B to be FP4 (E2M1 packed).
Changes:
- a_dtype_t: float_e4m3_t → float_e2m1_unpacksmem_t
- UMMA_K: 32 → 64 (FP4 MMA atom)
- L1 epilogue: FP8 quant → E2M1 FP4 quantization with nearest-neighbor
- L1 output SMEM: packed E2M1 (2 per byte), TMA store uint8
- TMA descriptors: adjusted for FP4 packing (K/2 bytes per row)
- SymmBuffer: uint8 activations, shape (M, K//2)
- Staging kernel: BF16 → E2M1 packed + UE4M3 block16 scales
2026-05-11 20:29:08 +00:00
fbdddaccf4 revert: restore mxf4nvf4/block16 code (correct path for sm_100a)
Reverted to commit 36b439e's NVFP4 kernel code:
- kGranK=16, mxf4nvf4.block_scale.scale_vec::4X
- float_ue4m3_t instruction descriptor
- Block16 SF layout (4X TMEM)
- UE4M3 L1 epilogue
- No UE4M3→UE8M0 conversion, no block16→block32 merge

The mxf4nvf4.scale_vec::4X PTX instruction compiles successfully
on both sm_100 and sm_100f with CUDA 13.0. The previous build 17
error was likely from a different cause, not the arch flag.

Python: reverted transform_nvfp4_weights_for_mega_moe to use
pack_ue4m3_to_int32 with gran_k=16, no UE8M0 conversion.
2026-05-11 15:02:47 +00:00
03b8c99ee1 fix: use mxf8f6f4 (UE8M0) on SM100 — mxf4nvf4 requires SM103+
B200 (SM100) does NOT support kind::mxf4nvf4 at all (neither 2X nor 4X).
Only mxf8f6f4.block_scale with UE8M0 scales is available on SM100.

Strategy: keep NVFP4 E2M1 weights, convert UE4M3 block scales → UE8M0
in the weight transformation. This is a scale format adaptation for
hardware compatibility, not a format conversion.

Changes:
- Kernel: back to mxf8f6F4 instruction + float_ue8m0_t descriptor
- L1 epilogue: back to UE8M0 (>> 23) activation scales
- Python: merge block16→block32, convert UE4M3→float32→UE8M0
- Packing: uint8 (UE8M0) → int32, same as MXFP4
2026-05-11 09:28:45 +00:00
dcebe033e2 fix: use scale_vec::2X (block32) for SM100 B200 compatibility
scale_vec::4X (block16) requires SM103/SM120 (B300/GB300), not SM100 (B200).
Revert to block32 with UE4M3 scales. Same TMEM layout as MXFP4 but with
UE4M3 scale format instead of UE8M0.

Changes:
- kGranK: 16 → 32
- PTX: scale_vec::4X → scale_vec::2X
- SF layout: same as MXFP4 (K/32, K/128 for int32 packed)
- UTCCP: i*8 → i*4 (2X layout, same as MXFP4)
- TMEM columns: same as MXFP4 (SF_BLOCK_M/32, SF_BLOCK_N/32)
- Python: merge NVFP4 block16→block32 scales (max of adjacent pairs)
- recipe: (1,1,16) → (1,1,32)
2026-05-11 08:36:59 +00:00
36b439ee26 feat: NVFP4 mega MoE kernel (scale_vec::4X, UE4M3 block scales)
- New CUDA kernel: sm100_fp8_nvfp4_mega_moe_impl
  - kGranK=16 (NVFP4 group_size=16, vs MXFP4's 32)
  - kind::mxf4nvf4.block_scale.scale_vec::4X PTX instruction
  - float_ue4m3_t scale factor type in instruction descriptor
  - SF layout: scale_vec::4X (4 TMEM sub-columns per UMMA atom)
  - UTCCP column stride: i*8 (vs MXFP4's i*4) for 4X layout
  - L1 epilogue: UE4M3 activation scales (float→cutlass::float_e4m3_t)
  - SF loading: kNumSFUint32 = kHidden/64 (4 UE4M3 per int32)

- New PTX wrappers: SM100_MMA_MXF4NVF4_2x1SM_SS, SM100_MMA_MXF4NVF4_SS

- Python API:
  - fp8_nvfp4_mega_moe() with recipe=(1,1,16)
  - transform_nvfp4_weights_for_mega_moe() for UE4M3→int32 UTCCP packing
  - _pack_nvfp4_sf_for_utccp() helper

- C++ bindings:
  - mega_nvfp4.hpp with NVFP4-specific SymmBuffer (SF stride K/16)
  - JIT kernel header with kGranK=16 TMA descriptors
  - Registered in python_api.cpp

NOTE: Both SFA and SFB must use UE4M3 (scale_format_ is 1-bit, shared).
The L1 epilogue converts float→UE4M3 for activation scales.
2026-05-11 05:41:08 +00:00
Zhean Xu
891d57b4db Add various optimizations and Mega MoE benchmarks (#316)
* Merge with private repo

* Add Mega MoE Benchmark

* Minor fix

* Update

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2026-04-24 18:41:37 +08:00
Chenggang Zhao
7f2a703ed5 [Public release 26/04] Introducing Mega MoE, FP4 Indexer and other features/fixes (#304)
* Merge with private repo

* Update README

* Update README

* Update README

* Add PyTorch requirements

* Fix sync scopes for MQA logits (#256)

* Update README
2026-04-17 09:45:14 +08:00
Ray Wang
d30fc36c8f Fix sync issue of TMEM alloc/dealloc (#292) 2026-03-22 16:41:28 +08:00
Xin Qiu
35c4bc8771 fix: k_grouped_fp8_gemm_nt_contiguous crashes with n = 768 on H100 (#238) 2026-02-25 10:13:54 +08:00
Ray Wang
477618cd51 Fix a sync issue in SM100 MQA logits (#285) 2026-02-03 17:29:49 +08:00
Zhean Xu
0f5f266202 Multiple updates and refactorings (#280) 2026-01-16 17:06:52 +08:00
Ray Wang
38f8ef73a4 Multiple updates and refactorings (#231) 2025-11-21 17:49:47 +08:00
Zhean Xu
bb4424aad4 Fix sum_k * shape_m overflow 2025-11-19 11:51:36 +08:00
Ray Wang
ec5e9ed0b8 Fix SM90 MQA logits (#229) 2025-11-19 10:50:36 +08:00
Ray Wang
2f9d87877e Use larger MMA shape (#227) 2025-11-14 11:38:15 +08:00
Chenggang Zhao
c1bf4cae4b Fix version 2025-10-01 20:31:27 +08:00
Chenggang Zhao
07b82fb8cd Fix old CUDA compatibility 2025-10-01 20:29:15 +08:00
Simon Mo
59f2c07cf2 Add SM100 kernels (#201)
Signed-off-by: simon-mo <simon.mo@hey.com>
2025-09-29 17:07:28 +08:00
Chenggang Zhao
80ceeb2c76 Add SM90 kernels (#200) 2025-09-29 17:00:23 +08:00
Ray Wang
3f71de7aa9 Make various updates and fixes (#198) 2025-09-25 16:19:07 +08:00
zhonghui-J
2da871e304 Fix grouped gemms performance issue. (#168) 2025-08-22 17:35:43 +08:00
Chenggang Zhao
e38c2e3103 Remove comments 2025-08-22 17:32:04 +08:00
Chenggang Zhao
f20256fd50 Compatible with CUDA 13 2025-08-22 17:30:47 +08:00
xiweny
affdb1cd90 Add sm_100f support and make nvcc 13 happy (#157)
Signed-off-by: Xiwen Yu <13230610+VALLIS-NERIA@users.noreply.github.com>
2025-08-22 17:19:32 +08:00
Ray Wang
f85ec649d7 Make various updates and fixes: (#164)
- Add BF16 support for SM90 and SM100
- Refactor Python APIs
- Other fixes and code refactoring
2025-08-15 18:32:35 +08:00
Ray Wang
d9c363f86f Make various updates and fixes:
- Add support for legacy CUDA versions; now compatible with CUDA 12.3 and newer
- Add support for NVRTC compilation
- Other fixes and code refactoring
2025-08-02 19:52:22 -07:00
yukuai26
aff9da0aba Fix SM90 GEMM (#149)
* Fix sm90 GEMM

* Fix typo

---------

Co-authored-by: Kuai Yu <yukuai@deepseek.com>
2025-08-01 10:36:49 +08:00
Ray Wang
9da4a23561 Add more GPU architectures support (#112)
* Add more GPU architectures support

* Update layout.py

* Optimize performance, Add SM90 support, Add 1D2D SM100 support

* Add fmtlib submodule at commit 553ec11

---------

Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
2025-07-18 11:32:22 +08:00
shixianc
0c88cd0139 Fix illegal memory address when skipping -1 m indices (#113)
Co-authored-by: Shixian Cui <shixian@amazon.com>
2025-06-16 10:44:31 +08:00
yukuai26
8dfa329827 Grouped GEMM skip useless computation for unaligned Ms (#103)
* Grouped GEMM skip useless computation for unaligned Ms

* Update readme.md

* small typo

* Rename variables

* Restore previous indent

* Format

* Refactor tests

* Add `SkipComputation` types

* Bug fixed

* Format

* Fix tests

* Add assertions

* Minor fix

---------

Co-authored-by: yukuai <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2025-05-27 13:43:38 +08:00
Chenggang Zhao
78d8362e7a Add a missing #pragma once 2025-05-15 18:10:05 +08:00
Chenggang Zhao
816b39053a Refactor launch-related structures 2025-05-15 16:14:21 +08:00
Chenggang Zhao
e2d6a107ef Cleanup some useless staffs 2025-05-14 15:46:45 +08:00
Zhean Xu
04278f6dee Weight gradient kernels for dense and MoE models (#95)
* Init weight gradient kernels.

* Support unaligned n,k and gmem stride

* Update docs

* Several cleanups

* Remove restrictions on N

* Add stride(0) assertions

---------

Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2025-05-14 14:47:58 +08:00
Gabriel Wu
bfe983c4c2 Refactor JIT compilation (+NVRTC support) (#94)
* [wip] refactor: compile to .cubin

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* refactor: compile to .cubin and add NVRTC option

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* fix: compiler version

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: compat for old drivers

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: save kernel name to file

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: fix win compat

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* fix: windows compat

Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: make API more general

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* feat: drop support for CUDA<12.3

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* doc: update README

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>

* Some lints and refactor

* Refactor runtime

* Several fixes

* Refactor environment variables

* Code format

* Add a TODO

* Compatible with CUDA 12.3

* Fix indent

* Fix typing

* Drop support for Windows

* Add a TODO

---------

Signed-off-by: Zihua Wu <13583761+lucifer1004@users.noreply.github.com>
Signed-off-by: Gabriel Wu <13583761+lucifer1004@users.noreply.github.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2025-05-07 11:38:14 +08:00
yukuai26
95e81b3dd6 Indivisible TMA (#90)
Fix indivisible shapes for TMA multicast

---------

Co-authored-by: yukuai <yukuai@deepseek.com>
Co-authored-by: Chenggang Zhao <chenggangz@deepseek.com>
2025-04-23 14:55:14 +08:00
yukuai26
891f35adf5 Support TMA multicast on B with m_grouped_gemm_contiguous. (#88) 2025-04-21 09:43:17 +08:00
Chenggang Zhao
83aa960b9b Fix bugs 2025-04-18 11:55:51 +08:00
Chenggang Zhao
340d9880f4 Overlap TMA store 2025-04-18 11:18:23 +08:00
Zhean Xu
4499c4ccbb Refactor MMA template with CUTLASS (#87)
* Refactor MMA with cutlass

* Update README.md

---------

Co-authored-by: Zhean Xu <xza@deepseek.com>
2025-04-14 17:06:49 +08:00
Chenggang Zhao
37aa127451 Use swizzling instead of padding (#86)
* Add swizzling params

* Add TMA D descriptor

* Always use STSMx2

* Swizzling draft

* Compatible with padding

* Fix bugs

* Optimize swizzle performance

* Optimize expression

* Optimize TMA issues

* Fix README

* Stricter assertions
2025-04-14 15:20:58 +08:00
Chenggang Zhao
b0d64817a7 OOB bugs fixed 2025-04-11 11:00:47 +08:00
Chenggang Zhao
99eb6ec563 Remove useless STSM 2025-04-11 10:45:36 +08:00
Chenggang Zhao
8041ed7164 Use 1D TMA store 2025-04-11 10:42:01 +08:00
Chenggang Zhao
a77009cb14 Make partition pipelined 2025-04-10 18:07:25 +08:00
Chenggang Zhao
5bda27244b Add CMake support for CLion indexing 2025-04-10 09:57:54 +08:00
Chenggang Zhao
5a80e4bb96 Fix indent x2 2025-04-09 11:00:10 +08:00
Chenggang Zhao
bdca8b0624 Fix indent 2025-04-09 10:59:07 +08:00
Chenggang Zhao
4c0cc290c7 Refactor M repetition with loops 2025-04-09 10:50:44 +08:00