DeepSeek-V4 NVFP4 Kernel Suite

CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use cutlass.cute (CuTeDSL) with Blackwell tensor cores.

Status (May 21, 2026 — 06:45 UTC)

Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM — COMPLETE

File: tests/test_stage_a_v2.py
Result: Q(128,128) @ K^T(128,128) → S(128,128), cosine 0.999999

🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS

Pipeline deadlock: FIXED. Kernel runs without deadlock.

  • Bug 1 (V MN-major): Fix applied.
  • Bug 2 (softmax packing): Confirmed correct (V=I test: cosine 1.0).
  • Bug 3 (ACCUMULATE): Fix applied.
  • Bug 4 (non-square PV): 🔨 ACTIVE. Two approaches attempted, both blocked.

Bug 4 (CURRENT): PV MMA Broken for (128,64) Output

Root Cause: The (128,64) PV MMA's A-fragment reads P from TMEM with a different layout than the softmax packing writes it.

The softmax packing writes P using the QK C-fragment layout (MMA atom = (128,128,16), N_MMA=128). The PV MMA reads P using its A-fragment layout (MMA atom = (128,64,16), N_MMA=64). These two layouts produce different physical TMEM addresses for the same logical (m,k) coordinate.

Evidence:

  • Truncated identity V (64,128) MN-major: O[m,d] ≈ P[m, 2d] — the MMA reads every other column of P
  • All-ones V: cosine 0.999999 (uniform data hides the layout mismatch)
  • Single-element V: cosine 1.0 (sparse data also hides it)
  • (128,128) PV with same softmax packing: cosine 0.999999 (N_MMA=128 matches QK, no mismatch)
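
The probe results above are what one would expect if the MMA sees the columns of P in a permuted order. The toy model below is a sketch under that assumption: `misread_columns` is a hypothetical fault model (even columns first, then odd columns), picked only because it reproduces O[m,d] = P[m,2d]. It is not derived from the real TMEM layouts. It illustrates why an all-ones V and a single-element V at (0,0) cannot detect a column permutation, while a truncated identity can.

```python
import math
import random

def pv(P, V):
    """O[m][d] = sum_k P[m][k] * V[d][k], with V stored as (head_dim, seq)."""
    return [[sum(p * v for p, v in zip(p_row, v_row)) for v_row in V] for p_row in P]

def misread_columns(P):
    """Hypothetical fault model: the reader sees even columns first, then odd ones."""
    K = len(P[0])
    perm = list(range(0, K, 2)) + list(range(1, K, 2))
    return [[row[j] for j in perm] for row in P]

def cosine(A, B):
    a = [x for row in A for x in row]
    b = [x for row in B for x in row]
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

M, K, D = 4, 8, 4  # toy stand-ins for (128, 128, 64)
rng = random.Random(0)
P = [[rng.random() for _ in range(K)] for _ in range(M)]
P_seen = misread_columns(P)

V_ones = [[1.0] * K for _ in range(D)]
V_single = [[1.0 if (d, k) == (0, 0) else 0.0 for k in range(K)] for d in range(D)]
V_trunc_id = [[1.0 if d == k else 0.0 for k in range(K)] for d in range(D)]

cos_ones = cosine(pv(P, V_ones), pv(P_seen, V_ones))        # permutation-invariant row sums
cos_single = cosine(pv(P, V_single), pv(P_seen, V_single))  # column 0 maps to itself
cos_trunc = cosine(pv(P, V_trunc_id), pv(P_seen, V_trunc_id))
O_trunc_seen = pv(P_seen, V_trunc_id)                       # equals P[m][2*d]
```

A practical consequence: diagnostic V tensors should have distinct values per column, otherwise a layout bug of this kind can pass with cosine 1.0.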

C++ TMEM Fragment Layout (from mma_traits_sm100.hpp):

Layout tmem_atom = Layout<Shape <_128, Int<N_MMA>>,
                          Stride<  _1,     _128>>{};
  • QK C-fragment: N_MMA=128 → 128 TMEM columns, stride 128
  • PV A-fragment (128,64): N_MMA=64 → 64 TMEM columns, stride 128

Approach 1: (128,64) PV MMA — BLOCKED (v28, deadlocks)

  • Created PV MMA with pv_mma_tiler[:2] = (128, 64) as FMHA does
  • Kernel compiles but deadlocks at runtime inside epilogue_tma_store
  • The (128,64) PV MMA changes tOtO shape from (128,128) to (128,64), affecting TMEM allocation, epilogue partitioning, and tCgC partitioning
  • The deadlock is NOT in the MMA or softmax — it's specifically in epilogue_tma_store which calls acc_pipeline.consumer_wait()
  • Diagnostics show: MMA warp stuck at mma_si acquire (waiting for softmax), EPI warps complete softmax but deadlock in epilogue
  • All three known deadlock fixes are applied (no cta_layout_vmnk on mma_si, TMA warp doesn't call wait_for_alloc, PipelineTmaStore)
  • New deadlock root cause unknown — likely related to epilogue partitioning mismatch with (128,64) output shape

Approach 2: Pad V to (128,128), use (128,128) PV MMA — BLOCKED (v29, deadlocks)

  • Pad V from (64,128) to (128,128) with zeros, keep (128,128) PV MMA
  • Should produce O=(128,128) where first 64 columns are correct
  • Also deadlocks, even with V=I(128,128) which works in test_pv_diag.py
  • test_pv_diag.py uses the exact same pipeline and kernel structure
  • Difference between v29 and test_pv_diag: likely a subtle code issue (b_dtype vs q_dtype, pv_mma_tiler, etc.)
  • Needs bisecting: change test_pv_diag into v29 one step at a time, rerunning after each step, to find the breaking change

Approach 3 (NOT YET TRIED): FMHA-style (128,16) PV MMA with N-tiling

  • FMHA uses pv_mma_tiler = (128, 16, 128) with MN = (128, 16)
  • For head_dim=64, FMHA tiles the N dimension 4 times (64/16=4)
  • Requires restructuring the kernel to loop over N tiles
  • This is the "correct" approach but requires more code changes
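
Approach 3 rests on the fact that output columns of P@V are independent, so the N dimension can be computed in chunks that all read the same P. The sketch below checks this numerically in plain Python with toy sizes. `pv_tile` is a hypothetical stand-in for one narrow PV MMA; TMEM, pipelines, and the real tcgen05 instructions are not modeled.

```python
def pv_tile(P, V_tile):
    """One N-tile of the PV product: O_tile[m][n] = sum_k P[m][k] * V_tile[n][k]."""
    return [[sum(p * v for p, v in zip(p_row, v_row)) for v_row in V_tile] for p_row in P]

def pv_n_tiled(P, V, n_tile=16):
    """Loop over head_dim in chunks of n_tile, reusing the same P for every chunk."""
    head_dim = len(V)
    assert head_dim % n_tile == 0, "head_dim must be a multiple of the N tile"
    O = [[] for _ in P]
    for n0 in range(0, head_dim, n_tile):
        O_tile = pv_tile(P, V[n0:n0 + n_tile])   # V slice for this tile
        for o_row, t_row in zip(O, O_tile):
            o_row.extend(t_row)                  # columns n0 .. n0 + n_tile - 1
    return O

M, K, HEAD_DIM = 8, 32, 64  # toy M and K; head_dim matches the text
P = [[((m * 7 + k * 3) % 11) / 11.0 for k in range(K)] for m in range(M)]
V = [[((d * 5 + k) % 13) / 13.0 for k in range(K)] for d in range(HEAD_DIM)]

O_tiled = pv_n_tiled(P, V, n_tile=16)
O_full = pv_tile(P, V)
num_tiles = HEAD_DIM // 16
```

In the kernel, the loop body would additionally have to select the matching output slice in TMEM for each tile; that bookkeeping is the restructuring cost mentioned above.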

Approach 4 (NOT YET TRIED): Fix softmax packing to write P in PV A-fragment layout

  • Instead of composing tStS.layout → (128, 64), write P using a layout derived from the PV A-fragment
  • FMHA does this: tStS_P_layout = cute.composition(tStS.layout, (128, tilePlikeFP32))
  • The composition writes 64 packed FP32 columns that alias the PV A-fragment's 64 TMEM columns
  • This should fix the alias for (128,64) PV MMA, but still blocked by the epilogue deadlock

V SMEM Layouts (confirmed correct):

  • PV(128,64) V SMEM: outer=((64,16),1,8,1):((1,64),0,1024,0), inner=S<3,4,3>
  • PV(128,128) V SMEM: outer=(((64,2),16),1,8,1):(((1,8192),64),0,1024,0), inner=S<3,4,3>

FMHA Softmax Packing Bridge (reference trace):

FMHA's softmax writes P in a packed format that aliases the PV A-fragment:

# P store columns: qk_mma_tiler[1] * p_dtype.width // qk_acc_dtype.width
# For 128 BF16 columns: 128 * 16 / 32 = 64 packed FP32 columns
tilePlikeFP32 = qk_mma_tiler[1] // Float32.width * o_dtype.width  # = 64

# Destination: packed-P physical TMEM view (64 FP32 columns, not 128)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + tmem_p0_offset, tStS_P_layout)

# Register packing: Float32 backing → BF16 recast view
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, Float32)
tTMEM_STORErS_x4_e = cute.make_tensor(
    cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=BFloat16),
    tTMEM_LOADrS.layout)  # 128 BF16 logical elements

The consumer (PV A-fragment) reads from the same TMEM columns:

tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
tOrP0 = cute.make_tensor(
    tOrP.iterator + Float32.width // BFloat16.width * tmem_p0_offset,
    tOrP.layout)

The subtle point: tilePlikeFP32 = 64 means "64 packed FP32 TMEM columns to store 128 logical BF16 P columns", NOT "PV output N=64". It's a coincidence that they're both 64 for head_dim=64.
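
To make the column arithmetic concrete, here is a small stdlib-only sketch of the packing. The helper names are hypothetical, BF16 conversion is done by truncation (hardware conversion normally rounds), and placing the even element in the low half-word is an assumption based on little-endian recasting.

```python
import struct

def packed_fp32_columns(logical_cols, elem_bits, word_bits=32):
    """Number of 32-bit columns needed to hold logical_cols narrow elements."""
    return logical_cols * elem_bits // word_bits

def f32_to_bf16_bits(x):
    """Keep the top 16 bits of the FP32 encoding (truncation, for illustration)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16

def bf16_bits_to_f32(h):
    """Widen a BF16 bit pattern back to a Python float."""
    (x,) = struct.unpack("<f", struct.pack("<I", h << 16))
    return x

def pack_pair(p_even, p_odd):
    """Assumed layout: even logical element in the low half of the 32-bit word."""
    return f32_to_bf16_bits(p_even) | (f32_to_bf16_bits(p_odd) << 16)

def unpack_pair(word):
    return bf16_bits_to_f32(word & 0xFFFF), bf16_bits_to_f32(word >> 16)

tile_p_like_fp32 = packed_fp32_columns(128, 16)   # 128 BF16 columns -> 64 words

row = [0.5, 1.0, 0.25, 2.0]                       # values exactly representable in BF16
words = [pack_pair(row[i], row[i + 1]) for i in range(0, len(row), 2)]
round_trip = [x for w in words for x in unpack_pair(w)]
```

The count depends only on the logical QK N extent (128) and the element widths, which is why it stays 64 regardless of head_dim.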


Bug 1: V B-Operand Must Be MN-Major — FIX APPLIED

V must be shaped (head_dim, seq) = (64, 128) with strides (1, 64) — MN-major. PV MMA uses v_major (OperandMajorMode.MN) instead of b_major (K).

V must use as_strided — default PyTorch (64,128) gives strides (128,1) which is K-major.
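
The stride bookkeeping can be checked without a GPU. The sketch below models shapes and strides in plain Python; the helper names are hypothetical. In PyTorch, the same (64,128) view with strides (1,64) results from calling `.t()` on a contiguous (seq, head_dim) = (128,64) tensor, or from `as_strided` as noted above.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) element strides, as PyTorch assigns by default."""
    strides, step = [], 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))

def transpose_view(shape, strides):
    """Swap the two modes without moving data, like a 2D transpose view."""
    return tuple(shape[::-1]), tuple(strides[::-1])

def major_mode(strides):
    """For a (head_dim, seq) B operand: unit stride on mode 0 is MN-major, on mode 1 K-major."""
    if strides[0] == 1:
        return "MN"
    if strides[1] == 1:
        return "K"
    raise ValueError("no unit-stride mode")

HEAD_DIM, SEQ = 64, 128

# Allocating (head_dim, seq) directly gives K-major strides.
default_strides = contiguous_strides((HEAD_DIM, SEQ))

# Allocating (seq, head_dim) and transposing the view gives MN-major strides.
v_shape, v_strides = transpose_view((SEQ, HEAD_DIM), contiguous_strides((SEQ, HEAD_DIM)))
```

Asserting the strides in the test harness before launching the kernel is cheap and catches this class of mistake early.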

Bug 2: C-Fragment Composition Store — CONFIRMED CORRECT

FP32→BF16 packing via C-fragment composition store (FMHA pattern) is correct. Proven by V=I test (cosine 1.0) and random V 128x128 test (cosine 0.999999).

FOOTGUN: St32x32bOp MUST use Float32, NOT BFloat16. ⚠️ The recast view for P packing uses the LOAD layout (128 BF16 elements), not the store composition shape.

Bug 3: First PV Must Use ACCUMULATE=False — FIX APPLIED

If ACCUMULATE=True on the first PV, O = P@V + old_O adds uninitialized TMEM. Always ACCUMULATE=False for first PV, then True for subsequent tiles.
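
A minimal plain-Python model of the flag, assuming a toy `mma` function that stands in for the PV MMA and a `stale` accumulator that stands in for whatever TMEM held before the kernel started (both hypothetical):

```python
def mma(P, V, acc, accumulate):
    """Toy PV MMA: computes P @ V^T and adds the old accumulator only if asked to."""
    out = []
    for m, p_row in enumerate(P):
        out_row = []
        for d, v_row in enumerate(V):
            dot = sum(p * v for p, v in zip(p_row, v_row))
            out_row.append(acc[m][d] + dot if accumulate else dot)
        out.append(out_row)
    return out

def run_tiles(tiles, stale_acc, first_accumulate):
    """Run every KV tile; only the first tile's flag is configurable."""
    acc = [row[:] for row in stale_acc]
    for idx, (P, V) in enumerate(tiles):
        accumulate = first_accumulate if idx == 0 else True
        acc = mma(P, V, acc, accumulate)
    return acc

tiles = [
    ([[1.0, 2.0]], [[1.0, 0.0], [0.0, 1.0]]),    # tile 0: contributes [1, 2]
    ([[3.0, 4.0]], [[1.0, 1.0], [1.0, -1.0]]),   # tile 1: contributes [7, -1]
]
stale = [[1000.0, -1000.0]]                      # uninitialized accumulator contents

good = run_tiles(tiles, stale, first_accumulate=False)  # [[8.0, 1.0]]
bad = run_tiles(tiles, stale, first_accumulate=True)    # [[1008.0, -999.0]]
```

The stale values survive into the final result when the first tile accumulates, which matches the symptom described above.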


🔨 Stage C: Online Softmax — AFTER B

Per the pseudocode: epilogue warps compute per-row tile_max, rescale, exp, store P back to TMEM.

🔨 Stage D: FP8 Paged KV Gather — AFTER C

Replace BF16 TMA load with FP8 paged KV gather + per-position dequant.


Pipeline Deadlock — FIXED (May 21)

v20-v25 all deadlocked on GPU. Three root causes found and fixed:

Fix 1: PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk

FMHA's mma_s0/mma_s1 PipelineUmmaAsync calls do NOT pass cta_layout_vmnk. Removing it fixes the deadlock.

Fix 2: TMA Warp Must NOT Call tmem.wait_for_alloc()

The tmem allocation barrier has num_threads = 32 * (mma_warp + epilogue_warps). The TMA warp is NOT part of this barrier. Calling wait_for_alloc() from the TMA warp corrupts the barrier.

Fix 3: PipelineTmaStore (not TmaStorePipeline)

pipeline.TmaStorePipeline does not exist. The correct name is pipeline.PipelineTmaStore.


DEADLOCK FIX #4: num_tma_load_bytes Must Include V Bytes (May 21, 07:10 UTC)

Root cause: num_tma_load_bytes only accounted for Q + K, not V. The TMA barrier's tx-count underflowed when V bytes arrived, wrapping the 20-bit counter. The barrier never reached zero → MMA warp waits forever.

Fix: Add cute.size_in_bytes(self.b_dtype, v_smem) to num_tma_load_bytes.

Why it was sneaky: test_pv_diag.py had the same bug but ran fine because V=I(128,128) is small enough that the race was benign. v29 with larger or differently-strided V exposed it.

Remaining issue with v28 (128,64 PV MMA): The (128,64) PV MMA itself still needs investigation — the softmax-to-PV TMEM alias must be adapted for the different A-fragment layout.


DEAD TEST: test_stage_b_v21.py — DELETED, DO NOT RECREATE

v21 attempted both Bug 1 and Bug 2 fixes in a hand-rolled pipeline kernel. It deadlocks on GPU. Root cause: pipeline synchronization mismatch. Do not recreate. Write from scratch using fmha.py as the reference.


FOOTGUNS — CUTLASS CuTeDSL Landmines

🔴🔴🔴 0. num_tma_load_bytes MUST Include ALL TMA-Loaded Tensors (Q + K + V) — DEADLOCK IF MISSING

This is the #1 landmine. It cost us hours of debugging.

PipelineTmaUmma.create() takes a tx_count parameter (via num_tma_load_bytes) that tells the TMA barrier how many bytes to expect. If you load Q, K, and V via TMA but only budget Q+K bytes, the barrier's tx-count underflows when V's bytes arrive. On SM100 the 20-bit tx-count wraps, the barrier never reaches zero, and the consumer (MMA warp) waits forever.

# ❌ WRONG — missing V bytes → DEADLOCK
self.num_tma_load_bytes = (
    cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)

# ✅ CORRECT — include ALL three TMA loads
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
    cute.size_in_bytes(self.q_dtype, a_smem)
    + cute.size_in_bytes(self.b_dtype, b_smem)
    + cute.size_in_bytes(self.b_dtype, v_smem)  # ← DO NOT FORGET THIS
) * cute.size(qk_mma.thr_id.shape)

Why it's sneaky: With small V tensors, the race might be benign (V completes before the consumer reads). With larger V, the underflow actually traps the barrier. So it can work on small tests and deadlock on larger ones.
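
One way to guard against this is to derive the byte budget from a single list of every TMA load and to compare it against the configured value. The sketch below is plain Python with hypothetical helper names. The tile shapes are assumptions taken from the Stage B test sizes, and the `thr_id` multiplier from the snippet above is left out.

```python
def tile_bytes(shape, elem_bits):
    """Bytes occupied by one tile of the given shape and element width."""
    count = 1
    for extent in shape:
        count *= extent
    return count * elem_bits // 8

def tma_load_bytes(tiles):
    """tiles maps a tensor name to (shape, elem_bits) for every TMA load on the barrier."""
    return sum(tile_bytes(shape, bits) for shape, bits in tiles.values())

def check_budget(budget, tiles):
    """Raise if the configured budget differs from what the listed loads deliver."""
    expected = tma_load_bytes(tiles)
    if budget != expected:
        raise ValueError(f"tx budget {budget} B != {expected} B delivered by {sorted(tiles)}")
    return budget

# Assumed per-stage tiles (BF16 = 16 bits per element).
loads = {
    "Q": ((128, 128), 16),
    "K": ((128, 128), 16),
    "V": ((64, 128), 16),
}

qk_only = tma_load_bytes({name: loads[name] for name in ("Q", "K")})  # 65536
full = tma_load_bytes(loads)                                          # 81920

try:
    check_budget(qk_only, loads)
    caught = False
except ValueError:
    caught = True
```

Keeping the load list as the single source of truth means that adding a fourth TMA-loaded tensor later updates the budget automatically.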


1. St32x32bOp with 16-bit dtype → ILLEGAL MEMORY ACCESS

St32x32bOp(Repetition(N), BFloat16) crashes at runtime. You MUST use St32x32bOp(Repetition(N), Float32) and pack 2×16-bit values into 1×Float32 backing words via cute.recast_ptr. The 16-bit type only appears in the recast view, never in the store atom itself.

2. V B-Operand Major Mode ≠ K Major Mode

FMHA requires v_major_mode == OperandMajorMode.MN. Passing K's K-major mode for V is WRONG. V must be shaped (head_dim, seq) with strides (1, head_dim) to produce MN-major. Standard PyTorch row-major (seq, head_dim) gives K-major.

3. CuTe Nested Layout Modes Flatten Sequentially

A layout like ((128,16),1,(4,2)):((65536,1),0,(16,64)) looks "non-sequential" but flattens to addr = m*65536 + k, where k = k0 + 16*k1 + 64*k2 (k0 varies fastest, which is CuTe's left-to-right, column-major coordinate order). Do NOT assume nested modes imply non-sequential physical addressing. The C-fragment composition and A-fragment alias the same TMEM columns, BUT ONLY WHEN N_MMA MATCHES (i.e., (128,128) PV). For (128,64) PV, N_MMA=64 and the alias breaks.
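
The flattening claim can be verified by brute force. The sketch below uses a simplified, hypothetical `crd2idx` (an inner product of a nested coordinate with a congruent nested stride) rather than the CuTe function of the same name, and checks every coordinate of the layout quoted above.

```python
from itertools import product

def crd2idx(coord, stride):
    """Inner product of a nested coordinate with a congruent nested stride."""
    if isinstance(coord, tuple):
        return sum(crd2idx(c, s) for c, s in zip(coord, stride))
    return coord * stride

SHAPE = ((128, 16), 1, (4, 2))
STRIDE = ((65536, 1), 0, (16, 64))

def check_flattening():
    """True if offset == m*65536 + k for every coordinate, with k0 fastest in k."""
    (m_ext, k0_ext), _, (k1_ext, k2_ext) = SHAPE
    seen_k = set()
    for m, k0, k1, k2 in product(range(m_ext), range(k0_ext), range(k1_ext), range(k2_ext)):
        offset = crd2idx(((m, k0), 0, (k1, k2)), STRIDE)
        k = k0 + 16 * k1 + 64 * k2
        seen_k.add(k)
        if offset != m * 65536 + k:
            return False
    return seen_k == set(range(128))   # k covers 0..127 with no gaps

flattens_sequentially = check_flattening()
sample_offset = crd2idx(((3, 5), 0, (2, 1)), STRIDE)   # 3*65536 + 5 + 32 + 64
```

The same brute-force check is a quick way to compare two printed layouts before assuming they alias.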

4. PipelineUmmaAsync Consumer Group = Thread Count, NOT Warp Count

# WRONG: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)
# CORRECT: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(warp_ids))

5. PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk

Passing cta_layout_vmnk to the mma_si PipelineUmmaAsync causes deadlock. FMHA does not pass it. Remove it.

6. TMA Warp Must NOT Call tmem.wait_for_alloc()

The tmem allocation barrier only includes MMA + epilogue warps. The TMA warp is excluded. Calling wait_for_alloc() from the TMA warp corrupts the barrier.

7. PV MMA ACCUMULATE Must Be False on First Tile

If ACCUMULATE=True on the first PV MMA, O = P@V + old_O adds uninitialized TMEM to the result. Always set ACCUMULATE=False for the first PV, then True for subsequent tiles. FMHA: pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0).

8. TMEM Pointer Arithmetic: Offset Units Depend on Pointer Type

When computing PV A-fragment offset from the softmax P offset:

# Softmax store: FP32 pointer + tmem_p0_offset (in FP32 elements)
tStS_P = cute.make_tensor(tStS.iterator + tmem_p0_offset, tStS_P_layout)

# PV A-fragment: BF16 pointer + scaled offset (in BF16 elements)
p_offset = acc_dtype.width // q_dtype.width * tmem_p0_offset  # e.g. 2 * 32 = 64 when tmem_p0_offset = 32
tOrP0 = cute.make_tensor(tOrP.iterator + p_offset, tOrP.layout)

Both must address the same physical TMEM column. The 2× scaling accounts for FP32→BF16 element size difference.
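
The unit conversion can be sanity-checked with a few lines of arithmetic. This sketch assumes 32-bit TMEM columns and uses hypothetical helper names. The value 32 for tmem_p0_offset is taken from the comment in the snippet above.

```python
def tmem_column(elem_offset, elem_bits, column_bits=32):
    """32-bit column reached by an element offset on a pointer of the given width."""
    return elem_offset * elem_bits // column_bits

def scale_offset(offset, from_bits, to_bits):
    """Re-express an element offset for a pointer with a different element width."""
    return offset * from_bits // to_bits

tmem_p0_offset = 32                                                  # in FP32 elements
p_offset = scale_offset(tmem_p0_offset, from_bits=32, to_bits=16)    # in BF16 elements

store_column = tmem_column(tmem_p0_offset, 32)   # where the softmax writes P
load_column = tmem_column(p_offset, 16)          # where the PV A-fragment reads P
unscaled_column = tmem_column(tmem_p0_offset, 16)  # forgetting the 2x lands elsewhere
```

Reusing the FP32 offset unchanged on the BF16 pointer addresses half the intended column distance, so the A-fragment would read the wrong region.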

9. C-Fragment → A-Fragment TMEM Alias Only Works When N_MMA Matches

The softmax packing writes P using the QK C-fragment layout. The PV A-fragment reads P. These alias correctly ONLY when both MMA atoms have the same N_MMA (i.e., both (128,128,16) → N_MMA=128). When the PV MMA uses (128,64,16) → N_MMA=64, the A-fragment has a different TMEM stride and reads garbage. The softmax packing must be adapted to write P in the PV A-fragment's layout.

10. epi_tile Must Match PV Output Shape, Not QK

compute_epilogue_tile_shape must use PV's cta_tile_shape_mnk, not QK's. Also, self.cta_tile_shape_mnk must be set to PV's cta tile before calling epilogue_tma_store (it reads gemm_kernel.cta_tile_shape_mnk internally). FMHA sets self.epi_tile = self.pv_mma_tiler[:2] directly.

11. GPU State Persists After Deadlocked Kernels

After a kernel deadlock, the GPU may remain in a bad state. Kill all python processes using the GPU before running new tests. nvidia-smi shows hanging processes. Use kill -9 <PID> to clean up. A deadlocked kernel can also cause subsequent runs to fail even if the code is correct.

12. V Padded to (128,128) Must Use MN-Major Strides

When padding V from (64,128) to (128,128), the padded tensor MUST use MN-major strides (1, 128), not the default PyTorch row-major strides (128, 1). Use as_strided or transpose() to get the correct layout. The wrong layout causes LayoutEnum.ROW_MAJOR instead of LayoutEnum.COL_MAJOR, which the PV MMA's OperandMajorMode.MN does not expect.


Architecture: Per-Tile Flow

For each KV tile:
  1. Load warp writes sKV[stage] (paged FP8 gather via indexed cp.async)
  2. MMA warp issues MMA1: sQ @ sKV[stage]^T → tmem_scores (accumulate=False)
     Signals scores_full_mbar (via PipelineUmmaAsync commit)
  3. Epilogue warps wait on mma_si consumer (scores ready), then:
     a. tcgen05.ld scores from TMEM → register fragments
     b. Compute tile_max, new_max, rescale = exp(old_max - new_max)
     c. Apply rescale to tmem_output IN PLACE (tmem_output *= rescale)
     d. tcgen05.st exp(scores - new_max) back to TMEM → P operand (via C-fragment composition)
     e. Release mma_si (softmax_done — MMA warp can re-acquire and issue PV MMA)
  4. MMA warp waits on mma_si acquire (softmax done), MMA2: P @ sV → tmem_output (accumulate=True)
  5. Stage released, load warp can refill it

After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast to BF16, store to GMEM
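
The arithmetic of steps 3b to 3d and the final normalization can be checked against a plain softmax on the CPU. The sketch below is a scalar Python reference, not the kernel: function names are hypothetical, V is stored as (seq, head_dim) for readability, no softmax scale factor is applied, and the accumulator starts at zero, which plays the role of accumulate=False on the first tile.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def online_attention(q_rows, kv_tiles):
    """Process KV tile by tile, keeping a running max, running sum and output per row."""
    outputs = []
    for q in q_rows:
        row_max, row_sum = -math.inf, 0.0
        acc = [0.0] * len(kv_tiles[0][1][0])           # zeroed output accumulator
        for k_tile, v_tile in kv_tiles:
            scores = [dot(q, k) for k in k_tile]       # MMA1: q @ K_tile^T
            new_max = max(row_max, max(scores))
            rescale = math.exp(row_max - new_max)      # 0.0 on the first tile
            p = [math.exp(s - new_max) for s in scores]
            row_sum = row_sum * rescale + sum(p)
            acc = [a * rescale + sum(p_j * v[d] for p_j, v in zip(p, v_tile))
                   for d, a in enumerate(acc)]         # rescale, then MMA2: P @ V_tile
            row_max = new_max
        outputs.append([a / row_sum for a in acc])     # final divide by row_sum
    return outputs

def reference_attention(q_rows, keys, values):
    """Single-pass softmax(q @ K^T) @ V for comparison."""
    outputs = []
    for q in q_rows:
        scores = [dot(q, k) for k in keys]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * v[d] for w, v in zip(weights, values)) / total
                        for d in range(len(values[0]))])
    return outputs

keys = [[((i * 3 + j) % 7) / 7.0 for j in range(4)] for i in range(8)]
values = [[((i + 2 * j) % 5) / 5.0 for j in range(3)] for i in range(8)]
q_rows = [[0.5, -1.0, 2.0, 0.25], [1.5, 0.0, -0.5, 1.0]]
kv_tiles = [(keys[i:i + 4], values[i:i + 4]) for i in range(0, 8, 4)]

online = online_attention(q_rows, kv_tiles)
reference = reference_attention(q_rows, keys, values)
max_err = max(abs(a - b) for ro, rr in zip(online, reference) for a, b in zip(ro, rr))
```

A reference like this is useful as the ground truth for Stage C tests, since it exercises the rescale path with more than one tile.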

Test Results

| File | Description | Cosine | Status |
|---|---|---|---|
| test_stage_a_v2.py | Q@K^T only | 0.999999 | PASS |
| test_mma_si_only.py | Q@K^T + mma_si pipeline (no PV) | 0.999999 | PASS |
| test_softmax_only.py | Q@K^T + softmax packing, output S | 0.52 | S overwritten by P (expected) |
| test_mma_si_pv.py | Q@K^T + softmax + P@V (V MN-major, 128x64) | 0.01 | PV output garbage |
| test_pv_diag.py | Q@K^T + softmax + P@V (V=I 128x128) | 1.0 | PASS |
| test_pv_diag.py | Q@K^T + softmax + P@V (random V 128x128) | 0.999999 | PASS |
| test_diag_v_truncid.py | Q@K^T + softmax + P@V (trunc identity 64x128, epi from PV) | 0.02 | O[m,d]≈P[m,2d] (TMEM alias mismatch) |
| test_diag_v_ones.py | All-ones V (64x128) | 0.999999 | uniform data hides mismatch |
| test_diag_v_ones.py | Single-element V (64x128) | 1.0 | sparse data hides mismatch |
| test_diag_layout.py | (128,64) PV with epi from PV cta_tile | 0.876 | partial fix: epi correct, TMEM alias still broken |
| test_diag_smem_layout.py | Print V SMEM layouts for (128,64) vs (128,128) | N/A | layouts confirmed correct |
| test_layout_compare.py | Print TMEM layouts for QK S and PV A-fragment | N/A | layout inspection |
| test_stage_b_v28.py | (128,64) PV MMA + epi from PV | | DEADLOCK in epilogue_tma_store |
| test_stage_b_v29.py | Padded V (128,128) + (128,128) PV MMA | | DEADLOCK (likely code bug, not fundamental) |
| test_stage_b_v30.py | Copy of test_pv_diag.py | 1.0 | PASS (sanity check) |

Critical APIs & Lessons

TMEM offset arithmetic

  • find_tmem_tensor_col_offset(fragment) — returns physical TMEM column count
  • QK accumulator: 128 TMEM columns
  • A-fragment offset: acc_dtype.width // q_dtype.width * tmem_p0_offset (F32/BF16=2)

pv_mma_tiler — FMHA Convention

pv_mma_tiler = (qk_mma_tiler[0], qk_mma_tiler[2], qk_mma_tiler[1])
# = (M, head_dim, QK_N) = (128, 64, 128) for head_dim=64

FMHA passes pv_mma_tiler[:2] = (128, head_dim) to make_trivial_tiled_mma, NOT the QK tiler (128, 128).

FMHA uses (128,16) PV MMA, NOT (128,64)

FMHA uses pv_mma_tiler = (128, 16, 128) with MN = (128, 16). For head_dim=64, FMHA tiles the N dimension 4 times. The softmax writes P once, and each tile reads the same P with a different V slice.

make_trivial_tiled_mma — Use New Overload

make_trivial_tiled_mma(a_dtype, b_dtype, a_leading_mode, b_leading_mode,
                        acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)

3D tensors required

Tensors must be 3D (M, K, L) for cute.local_tile — add L=1 dimension.

Other APIs

  1. cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...) — CuTe tensor from PyTorch
  2. PipelineTmaUmma.create(...).make_participants() — returns (producer, consumer) pair
  3. utils.gemm.sm100.epilogue_tma_store — handles transform + partition/dcopy. DO NOT hand-roll.
  4. smem.allocate_tensor() — for SMEM tensors
  5. LayoutEnum.from_tensor(a).mma_major_mode() — major mode from cute tensor

Environment

  • Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
  • venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
  • PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
  • Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
  • vLLM repo: /root/dsv4-nvfp4-workspace/vllm (modified for Blackwell)
  • Pseudocode: /root/fragile-kernel-example/README.md
  • fmha.py reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py
  • fmha_bwd.py reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py