
DeepSeek-V4 NVFP4 Kernel Suite

CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use cutlass.cute (CuTeDSL) with Blackwell tensor cores.

File Map

cutedsl/
├── native_swa_decode.py         # SWA decode attention — IN PROGRESS (v3 tcgen05 rewrite)
├── native_sparse_decode.py      # Sparse (CSA/HCA) decode — NOT YET REWRITTEN
├── nvfp4_cutedsl.py             # NVFP4 MoE runner (CuTeDSL) — WORKING
├── moe_pipeline.py              # MoE fused SwiGLU pipeline — WORKING
├── blackwell_attention.py       # vLLM bridge for Blackwell attention path
├── csa_attention.py             # CSA/HCA sparse attention bridge
├── custom_ops.py                # Custom CUDA ops registration
└── kernel/
    └── blockscaled_gemm/
        └── dense_blockscaled_gemm_persistent.py  # REFERENCE: Blackwell TMEM/tcgen05 GEMM

tests/
├── test_stage_a_v2.py                # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM
├── test_stage_b_v7.py                # 🔨 Stage B: two MMAs + C-fragment softmax (runs, wrong output)
├── test_stage_b_afrag2.py            # 🔨 Stage B: A-fragment store pattern (compiles, wrong output)
├── test_tmem_pure_fp32.py            # ✅ FP32 ld→st roundtrip on C-fragment: cosine 0.999999
├── test_bf16_elemwise.py             # ✅ FP32→BF16→FP32 elemwise + FP32 st: cosine 0.999999
├── test_recast_minimal.py            # ✅ BF16 recast ld S0→st S1 via C-fragment: cosine 0.999999
├── test_bf16_recast_simple.py        # ❌ BF16 recast ld/st same region (S0): zero (can't overwrite MMA output)
├── test_tmem_copy_roundtrip.py       # ❌ BF16 recast + C→A mismatch: zero
├── test_stage_b_final.py             # ❌ C-fragment st + A-fragment read: NaN (physical layout mismatch)
├── test_afrag_roundtrip.py           # ❌ A-frag st corrupts S0 (overlapping TMEM region)
├── diag_tmem.py                      # Diagnostic: TMEM layout inspection
└── ...

Current Status

Stage A: Bare Q@K^T via tcgen05.mma — COMPLETE (May 20)

File: tests/test_stage_a_v2.py
Result: Q(128,128) @ K^T(128,128) → S(128,128), cosine 0.999999

Validates the full tcgen05.mma → TMEM → epilogue → GMEM path:

  • tcgen05.mma with BF16 inputs, FP32 TMEM accumulator
  • TMA load for A and B (cute.nvgpu.make_tiled_tma_atom_A/B)
  • TMA store for C (cpasync.CopyBulkTensorTileS2GOp)
  • Warp specialization: 4 epilogue warps + 1 MMA warp + 1 TMA warp = 192 threads
  • PipelineTmaUmma for AB pipeline, PipelineUmmaAsync for acc pipeline
  • TmemAllocator for TMEM allocation/deallocation
  • utils.gemm.sm100.epilogue_tma_store for the TMEM→reg→SMEM→TMA→GMEM epilogue
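
The results in this README are reported as a cosine against a reference. The test scripts are not reproduced here, so the following is only a minimal sketch of how such a metric could be computed on flattened outputs; the function name and the handling of an all-zero output are assumptions.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length flattened tensors."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same number of elements")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Assumption: an all-zero output is scored as 0.0 rather than raising.
        return 0.0
    return dot / (norm_a * norm_b)

reference = [1.0, 2.0, 3.0, 4.0]
matching = [1.0, 2.0, 3.0, 4.0]   # kernel output equal to the reference
zeroed = [0.0, 0.0, 0.0, 0.0]     # kernel wrote nothing useful
```

A score near 1 means the output points in the same direction as the reference; a score near 0 (as in the cos=-0.02 case) means the output is essentially uncorrelated with it.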

🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20-21)

Core Problem: The C-fragment (MMA accumulator) and A-fragment (MMA A-operand from TMEM) use different physical TMEM address mappings for the same logical (M,K) position. The softmax writes P through one mapping, but the PV MMA reads it through the other, so the PV MMA consumes the wrong data (NaN or a near-zero cosine in the tests below).

What's Been Proven

| Test | Pattern | Result | Why |
| --- | --- | --- | --- |
| test_tmem_pure_fp32 | FP32 ld→st, same C-fragment layout | cos=0.999999 | C-fragment addresses self-consistent |
| test_bf16_elemwise | FP32→BF16→FP32 elemwise, C-fragment st | cos=0.999999 | BF16 conversion works, C-fragment st works |
| test_recast_minimal | BF16 recast ld S0→st S1, C-fragment | cos=0.999999 | Recast works when writing to different region |
| test_bf16_recast_simple | BF16 recast ld/st same region S0 | zero | Can't overwrite MMA output in same region |
| test_stage_b_final | C-fragment st → A-fragment read (S1) | NaN | C-layout ≠ A-layout physical addresses |
| test_stage_b_afrag2 | A-fragment st (backward FMHA pattern) | cos=-0.02 | Store + PV MMA layout compatible, but register data flow wrong |

Root Cause: C-fragment vs A-fragment Physical TMEM Layout

From the CUTLASS source (mma_traits_sm100.hpp):

C-fragment (MMA accumulator, FP32):

  • Layout: ((128,128),1,1):((65536,1),0,0) (virtual layout)
  • Physical TMEM addresses determined by the MMA hardware's accumulator write path
  • St32x32bOp with C-fragment layout writes to C-fragment physical addresses

A-fragment (MMA A-operand from TMEM, BF16, K-major, M=128):

  • Layout: ((128,16),1,4):((65536,1),0,16) (physical TMEM layout)
  • A[m, k_inner] → tmem[dp=m, col=base + 16*mma_k + k_inner]
  • BK=64 is covered by four K=16 MMA atoms, not one K=64 atom
  • The 4D fragment partition order is NOT the physical TMEM order

The St32x32bOp with C-fragment composition writes to C-layout physical addresses. The PV MMA reads from A-layout physical addresses. These are different physical locations.
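
To show how the layout strings above turn into offsets, here is a minimal sketch in plain Python that evaluates a CuTe-style shape:stride layout as a dot product of coordinate and stride. The helper names are hypothetical. Reading the 65536 stride as "datapath lane in the upper 16 bits, column in the lower 16 bits" is an assumption, and the offsets are in units of each fragment's own element type (FP32 for C, BF16 for A), so equal numbers do not imply equal physical addresses. The sketch does not model where the hardware actually places the data.

```python
def layout_offset(coord, stride):
    """Dot product of a (possibly nested) coordinate with matching strides."""
    if isinstance(coord, tuple):
        return sum(layout_offset(c, s) for c, s in zip(coord, stride))
    return coord * stride

def split_offset(offset):
    """Assumption: lane (dp) is offset >> 16, column is the low 16 bits."""
    return offset >> 16, offset & 0xFFFF

# Strides copied from the two layouts quoted above.
C_STRIDE = ((65536, 1), 0, 0)    # ((128,128),1,1):((65536,1),0,0)
A_STRIDE = ((65536, 1), 0, 16)   # ((128,16),1,4):((65536,1),0,16)

def c_fragment_lane_col(m, n):
    """Logical accumulator element (m, n) -> (lane, column in FP32 elements)."""
    return split_offset(layout_offset(((m, n), 0, 0), C_STRIDE))

def a_fragment_lane_col(m, k):
    """Logical A element (m, k) -> (lane, column in BF16 elements)."""
    mma_k, k_inner = divmod(k, 16)   # BK=64 is four K=16 atoms
    return split_offset(layout_offset(((m, k_inner), 0, mma_k), A_STRIDE))

# col = 16*mma_k + k_inner, here 16*2 + 5 = 37 for k = 37
example_offset = layout_offset(((3, 5), 0, 2), A_STRIDE)
```

Both functions return the same pair for the same logical index, which is exactly why the unit matters: column 37 counted in FP32 elements and column 37 counted in BF16 elements are different places once the element width is taken into account.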

Forward FMHA's Approach (FP16 Only!)

Forward FMHA uses a recast pattern to pack 2×FP16 into 1×FP32 register, then St32x32bOp writes to a C-fragment composition subview. But forward FMHA explicitly rejects BF16:

if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}:
    raise ValueError("in_dtype must be Float8E4M3FN or Float16")

The recast softmax path is validated for FP16, NOT BF16. Our BF16 use is outside the tested path.
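
To make the recast idea concrete, here is a minimal sketch in plain Python of packing two 16-bit values into one 32-bit word, using BF16 since that is the case of interest here. The helper names, the round-to-nearest-even conversion, and the low-half-first ordering are assumptions for illustration and are not taken from fmha.py. NaN inputs are not handled.

```python
import struct

def f32_to_bf16_bits(x):
    """FP32 -> BF16 bit pattern with round-to-nearest-even (NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding) >> 16) & 0xFFFF

def bf16_bits_to_f32(b):
    """BF16 bit pattern -> FP32 value (BF16 is the top half of an FP32)."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def pack_pair(lo, hi):
    """Pack two values into one 32-bit word; 'lo' in the low half (assumption)."""
    return f32_to_bf16_bits(lo) | (f32_to_bf16_bits(hi) << 16)

def unpack_pair(word):
    """Inverse of pack_pair."""
    return bf16_bits_to_f32(word & 0xFFFF), bf16_bits_to_f32(word >> 16)

def pack_row(values):
    """Pack consecutive pairs, so N values become N/2 32-bit words."""
    if len(values) % 2:
        raise ValueError("need an even number of values")
    return [pack_pair(values[i], values[i + 1]) for i in range(0, len(values), 2)]

word = pack_pair(1.0, -2.5)
packed = pack_row([0.0] * 128)
```

Packing halves the register count: 128 FP32 values become 64 32-bit words, which is the same 2:1 ratio that shows up in the TMEM offset scaling.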

Backward FMHA's Approach (BF16 Supported)

Backward FMHA writes dV to TMEM using the A-fragment layout:

  1. tdVrP_iter = cute.recast_ptr(tSTtST.iterator, dtype=self.element_dtype) — recast C-fragment iterator to BF16
  2. tdVrP = cute.make_tensor(tdVrP_iter, tOrP.layout) — A-fragment layout, C-fragment base
  3. tmem_store_atom = cute.make_copy_atom(St32x32bOp(Repetition(8)), self.element_dtype) — BF16 store atom
  4. Quantize via make_rmem_tensor(input.shape, element_dtype) + .load()/.store(v.to(element_dtype)) — true BF16 register, NOT recast
  5. Reshape: cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape)) — match store partition shape

This compiles and runs for us (no crash), but the output is still wrong (cosine -0.02). The remaining issue is the register layout mismatch:

  • Load partition (C-fragment): 128 FP32 values per thread (full 128×128 QK tile)
  • Store partition (A-fragment): 64 BF16 values per thread (128×64 P tile for PV MMA K=64)
  • The backward FMHA uses quantize() + reshape, but our element counts differ because the QK tile is 128×128 while P only needs 128×64
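
The element-count mismatch above can be checked with simple arithmetic. This is a sketch that assumes each tile is split evenly across the 4 epilogue warps (128 threads) from Stage A; the function name is hypothetical.

```python
WARP_SIZE = 32
EPILOGUE_WARPS = 4  # from the Stage A warp specialization

def values_per_thread(tile_m, tile_n, num_threads=WARP_SIZE * EPILOGUE_WARPS):
    """Elements each thread owns if the tile is divided evenly (assumption)."""
    total = tile_m * tile_n
    if total % num_threads:
        raise ValueError("tile does not divide evenly across threads")
    return total // num_threads

load_count = values_per_thread(128, 128)   # FP32 scores from the QK tile
store_count = values_per_thread(128, 64)   # BF16 P values for the PV MMA
store_registers = store_count * 16 // 32   # 32-bit registers holding the BF16 values
```

Under this assumption a thread loads 128 FP32 values but may store only 64 BF16 values (32 registers), so half of the loaded columns must be dropped or handled in a second pass before the store.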

Next Steps for Stage B

  1. Fix the register data flow — properly subselect the P-relevant 64 BF16 columns from the 128 FP32 load columns, or use the backward FMHA's PdO MMA tiler (M=128, N=64) instead of (M=128, N=128)
  2. Verify A-fragment store roundtrip — write known BF16 values via A-fragment store, have PV MMA read them back via A-fragment, confirm the physical TMEM addresses match
  3. Once data flow is correct, add online softmax (Stage C)

🔨 Stage C: Online Softmax — AFTER B

The hard part. Per the pseudocode:

  • Epilogue warps tcgen05.ld scores from TMEM into register fragments
  • Compute per-row: tile_max, new_max, rescale = exp(old_max - new_max)
  • Apply rescale to tmem_output in place (tmem_output *= rescale)
  • Compute exp(scores - new_max), tcgen05.st back to TMEM as P operand for MMA2
  • Update row_sum = row_sum * rescale + new_tile_sum

The register fragment layout from tcgen05.ld is not (row, col); it is determined by the MMA instruction's partition of the accumulator. Per-row softmax operations therefore need the mapping from fragment indices to logical (head, kv_pos) positions, which is still to be worked out. fmha.py uses tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) for the row max, a built-in reduction that handles the layout.
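
The per-row update listed above can be checked against a plain softmax in a few lines of Python. This is a reference sketch for a single query row, not the kernel: function names are hypothetical, and `output` stands in for the row's slice of tmem_output.

```python
import math

def online_softmax_attention(score_tiles, value_tiles):
    """Process KV tiles one at a time, keeping a running max and sum."""
    dim = len(value_tiles[0][0])
    row_max = float("-inf")
    row_sum = 0.0
    output = [0.0] * dim
    for scores, values in zip(score_tiles, value_tiles):
        tile_max = max(scores)
        new_max = max(row_max, tile_max)
        rescale = math.exp(row_max - new_max)       # 0.0 on the first tile
        probs = [math.exp(s - new_max) for s in scores]
        output = [o * rescale for o in output]      # tmem_output *= rescale
        for p, v in zip(probs, values):
            for d in range(dim):
                output[d] += p * v[d]               # P @ V accumulate
        row_sum = row_sum * rescale + sum(probs)
        row_max = new_max
    return [o / row_sum for o in output]            # final divide by row_sum

def reference_attention(scores, values):
    """Softmax over all scores at once, then the weighted sum of values."""
    m = max(scores)
    probs = [math.exp(s - m) for s in scores]
    total = sum(probs)
    return [sum(p * v[d] for p, v in zip(probs, values)) / total
            for d in range(len(values[0]))]

score_tiles = [[1.0, 3.0], [5.0, 2.0]]
value_tiles = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 2.0], [-1.0, 4.0]]]
online = online_softmax_attention(score_tiles, value_tiles)
direct = reference_attention(sum(score_tiles, []), sum(value_tiles, []))
```

The two results agree up to floating-point rounding, which is the property a Stage C test would need to reproduce on the GPU.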

🔨 Stage D: FP8 Paged KV Gather — AFTER C

Replace BF16 TMA load of KV with:

  • Indexed cp.async gather from paged KV cache (fp8)
  • Per-position dequant scale (inv_scale) applied during or after gather
  • Keep KV in fp8 in SMEM, let the MMA's per-row scale handle dequant (like blockscaled GEMM)
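
As a host-side picture of the first two bullets, here is a minimal sketch of a paged gather with per-position dequant. Everything in it is hypothetical: the names, the block-table indirection, the page layout, and the choice to multiply by inv_scale. FP8 values are represented as already-decoded Python floats, so no FP8 bit decoding is shown.

```python
def gather_paged_kv(kv_pages, inv_scales, block_table, positions, page_size):
    """Gather KV rows for the given token positions and apply their scales.

    kv_pages[page][slot]   -> list of floats (one KV row, already decoded)
    inv_scales[page][slot] -> per-position scale (assumed multiplicative)
    block_table[block]     -> physical page holding that logical block
    """
    gathered = []
    for pos in positions:
        logical_block, slot = divmod(pos, page_size)
        page = block_table[logical_block]
        scale = inv_scales[page][slot]
        gathered.append([x * scale for x in kv_pages[page][slot]])
    return gathered

kv_pages = [[[1.0, 2.0], [3.0, 4.0]],      # physical page 0
            [[5.0, 6.0], [7.0, 8.0]]]      # physical page 1
inv_scales = [[0.5, 0.5], [2.0, 1.0]]
block_table = [1, 0]                       # logical block 0 lives in page 1
rows = gather_paged_kv(kv_pages, inv_scales, block_table, [0, 3], page_size=2)
```

The third bullet (keeping FP8 in SMEM and letting the MMA scale) would skip the multiply here and pass the scales alongside the gathered rows instead.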

Architecture: Per-Tile Flow (from /root/fragile-kernel-example/README.md)

For each KV tile:
  1. Load warp writes sKV[stage] (paged FP8 gather via indexed cp.async)
  2. MMA warp issues MMA1: sQ @ sKV[stage]^T → tmem_scores (accumulate=False)
     Signals scores_full_mbar (via PipelineUmmaAsync commit)
  3. Epilogue warps wait on mma_si consumer (scores ready), then:
     a. tcgen05.ld scores from TMEM → register fragments
     b. Compute tile_max, new_max, rescale = exp(old_max - new_max)
     c. Apply rescale to tmem_output IN PLACE (tmem_output *= rescale)
     d. tcgen05.st exp(scores - new_max) back to TMEM → now it's the P operand
     e. Release mma_si (softmax_done — MMA warp can re-acquire and issue PV MMA)
  4. MMA warp waits on mma_si acquire (softmax done), then MMA2: P @ sKV[stage] → tmem_output (accumulate=True)
  5. Stage released, load warp can refill it

After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast to BF16, store to GMEM

NVFP4 MoE (CuTeDSL) — WORKING

  • nvfp4_cutedsl.py + moe_pipeline.py
  • CuTeDSL NVFP4 Linear (q_a, kv, q_b, o_b) — cosine 0.994+
  • CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) — cosine 0.988
  • Fused SwiGLU epilogue (granularity-8 weight interleave) — cosine 0.988

FP8 KV Quantize/Dequant — WORKING

  • FP8 KV: cosine 0.9997
  • NVFP4 KV: cosine 0.9943 (2x smaller than FP8)
  • Paged KV cache read/write: cosine 1.0

Sparse Decode Attention — NOT YET REWRITTEN

native_sparse_decode.py still has the scalar FMA bug. Needs the same tcgen05.mma rewrite.

Full Attention Pipeline (standalone tests) — WORKING

  • FP8 KV → full attention: cosine 0.9997
  • CSA sparse attention (cr=4): works
  • HCA sparse attention (cr=128): works
  • Merged CSA+SWA attention: works

Critical APIs & Lessons

C-fragment ≠ A-fragment TMEM Physical Layout — THE MAY 20-21 FINDING

The St32x32bOp with C-fragment composition writes to C-layout physical TMEM addresses. The PV MMA reads from A-layout physical TMEM addresses. These are DIFFERENT physical locations for the same logical (M,K) position.

For the softmax to work, P must be written to TMEM using the A-fragment's physical layout, not the C-fragment's. The backward FMHA does this correctly by:

  1. Creating the store destination with A-fragment layout + recast C-fragment iterator
  2. Using a BF16 St32x32bOp atom
  3. True BF16 register (not FP32 recast) via quantize() pattern

Forward FMHA Recast Pattern — FP16 ONLY

The cute.recast_ptr + .store(v.to(FP16)) pattern for packing 2×16-bit into 1×FP32 register is validated for FP16 only. BF16 is rejected in forward FMHA. The BF16 recast produces zero output when writing to the same TMEM region as the MMA output, and NaN when writing to a different region read via A-fragment.

PipelineUmmaAsync consumer group size — thread count, NOT warp count

# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)  # warp count

# CORRECT (matches fmha.py):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(softmax_warp_ids))  # thread count

TMEM offset arithmetic

  • find_tmem_tensor_col_offset(fragment) — returns physical TMEM column count (with 0x8000 tag for A-fragments)
  • QK accumulator C fragment: 128 TMEM columns
  • PV A-fragment: offset 0x8020 = tag(0x8000) + col(32) — the 0x8000 is a TMEM memory-space identifier
  • tOrP0 = cute.make_tensor(tOrP.iterator + acc_dtype.width // q_dtype.width * tmem_p0_offset, tOrP.layout) — A-fragment offset scaled by dtype width ratio (F32/BF16 = 2)
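
The offset arithmetic in these bullets can be sketched in plain Python. The helper names are hypothetical, and whether tmem_p0_offset carries the 0x8000 tag is not stated above, so the sketch scales only the untagged column part.

```python
TMEM_TAG = 0x8000  # tag reported for A-fragment offsets, per the bullets above

def split_tagged_offset(offset):
    """Separate the 0x8000 tag from the column part of an offset."""
    return offset & TMEM_TAG, offset & ~TMEM_TAG

def scale_for_operand(col_offset, acc_width_bits=32, operand_width_bits=16):
    """Convert an offset in accumulator columns to operand-element units."""
    return (acc_width_bits // operand_width_bits) * col_offset

tag, col = split_tagged_offset(0x8020)
bf16_offset = scale_for_operand(col)  # F32/BF16 ratio of 2
```

With these numbers the tag is 0x8000, the column part is 32, and an iterator typed as BF16 would advance by 64 elements to reach the same position.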

A-fragment iterator must use recast C-fragment pointer

When creating the P tensor for PV MMA's A-operand, the iterator must be the C-fragment's iterator recast to BF16:

tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]

Without the recast, the A-fragment addresses are computed from an FP32 pointer base, which yields wrong physical TMEM addresses and crashes with an illegal memory access.

V SMEM aliasing (K and V share SMEM)

v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
tCrV = pv_mma.make_fragment_B(sV)

make_trivial_tiled_mma has two overloads

# New (preferred):
make_trivial_tiled_mma(a_dtype, b_dtype, a_leading_mode, b_leading_mode,
                        acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)

# Deprecated (still works, used by Stage A):
make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode,
                        acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)

Other APIs discovered from Stage A

  1. cute.Tensor API: cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)
  2. 3D tensors — Tensors must be 3D (M, K, L) for cute.local_tile — add L=1 dimension
  3. PipelineTmaUmma.create(...).make_participants() — returns (producer, consumer) pair
  4. utils.gemm.sm100.epilogue_tma_store — handles transform + partition/dcopy. DO NOT hand-roll.
  5. get_num_tmem_alloc_cols — correct TMEM allocation (accepts list of fragments, sums cols, rounds to power of 2)
  6. smem.allocate_tensor() — for SMEM tensors (not SharedStorage struct for A/B/C)
  7. LayoutEnum.from_tensor(a).mma_major_mode() — major mode from cute tensor
  8. Minimum valid N tile for tcgen05.mma BF16: 32 (step 32, range 32-256)
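
Item 5 describes summing the fragment column counts and rounding up to a power of 2. A minimal sketch of that rounding follows; the function name is hypothetical, and the 32-column minimum and 512-column maximum are assumptions about the TMEM allocation limits rather than something stated in this README.

```python
def round_tmem_cols(fragment_cols, min_cols=32, max_cols=512):
    """Sum per-fragment column counts and round up to a power of two."""
    total = sum(fragment_cols)
    cols = min_cols
    while cols < total:
        cols *= 2
    if cols > max_cols:
        raise ValueError(f"{total} columns exceed the assumed limit of {max_cols}")
    return cols

qk_plus_p = round_tmem_cols([128, 64])   # 192 columns requested
qk_only = round_tmem_cols([128])         # already a power of two
tiny = round_tmem_cols([20])             # clamped up to the assumed minimum
```

For example, a 128-column QK accumulator plus a 64-column region rounds up from 192 to 256 columns, so the allocation can be noticeably larger than the sum of its parts.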

Environment

  • Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
  • venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
  • PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
  • Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
  • vLLM repo: /root/dsv4-nvfp4-workspace/vllm (modified for Blackwell)
  • Pseudocode: /root/fragile-kernel-example/README.md — authoritative per-tile attention flow
  • fmha.py reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py
  • fmha_bwd.py reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py

4-Stage Build Plan

| Stage | Goal | Status |
| --- | --- | --- |
| A | Bare Q@K^T via tcgen05.mma → TMEM → GMEM | COMPLETE |
| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 A-fragment store compiles, register data flow needs fixing |
| C | Online softmax between MMA1 and MMA2 (the hard part) | TODO |
| D | FP8 paged KV gather + dequant (replace BF16 TMA load) | TODO |