- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
DSV4 NVFP4 Kernel
Status (May 21, 2026 — 15:40 UTC)
| Stage | Status | Description |
|---|---|---|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | 🔨 NEXT | Real softmax: row max, exp, rescale, row sum |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: integrate into cutedsl/, PyTorch custom op, vLLM bridge |
Where Things Live
Canonical Test Files (current working code)
| File | Stage | What It Tests |
|---|---|---|
| tests/test_fmha_v3.py | A+B | Full QK → identity softmax → PV with KV-tile interleaving. This is the canonical Stage B kernel. |
| tests/test_pv64_with_softmax.py | B | Single AB pipeline variant. Simpler; also passing. |
| tests/test_128_128_vdiag.py | A+B | (128,128) PV baseline (square case). |
| tests/test_qkonly.py | A | QK with split Q/KV pipelines; no softmax, no PV. |
| tests/test_qk_softmax.py | A + partial B | QK + softmax writes P to TMEM; no PV. |
Reference Implementations (non-CuTe)
| File | Description |
|---|---|
| cutedsl/blackwell_attention.py | Pure PyTorch reference: RoPE, KV cache, causal attention, SWA |
| cutedsl/csa_attention.py | CSA/HCA sparse attention reference |
| cutedsl/native_swa_decode.py | SWA decode reference |
CuTeDSL Kernel Modules (production code)
| File | Description |
|---|---|
| cutedsl/kernel/moe/ | NVFP4 MoE GEMM kernels (fused SwiGLU, grouped MM) |
| cutedsl/kernel/blockscaled_gemm/ | Block-scaled GEMM |
| cutedsl/nvfp4_linear.py | NVFP4 linear layer wrapper |
| cutedsl/runner.py | MoE runner |
| cutedsl/moe_pipeline.py | MoE pipeline orchestration |
Obsolete Tests (do not use, can delete)
100+ test files from debugging Stages A and B. Key patterns:
- test_stage_b_v1.py through test_stage_b_v30.py: incremental Bug 4b debugging
- test_128_16_*.py: early (128,16) PV attempts with the wrong head_dim
- test_tmem_*.py, test_bf16_*.py: standalone TMEM/copy debugging
- test_pv64_no_softmax.py, test_pv64_fmha_v.py, test_pv64_kmajor_v.py: Bug 4b root-cause isolation
These can be deleted once the canonical tests are stable and the kernel is extracted.
Stage C: Real Softmax
What We Have Now
Identity softmax: load S FP32 from TMEM, convert to BF16, store P back to TMEM. This proves the TMEM pipeline works but isn't a real softmax.
What We Need
FMHA-style online softmax per KV-tile:
For each KV tile:
1. QK → S (FP32 in TMEM)
2. Compute the tile row max from S: tile_max[j] = max(S[j,:])
3. Compute new row max: new_max[j] = max(old_max[j], tile_max[j])
4. Rescale O: O[j,:] *= exp(old_max[j] - new_max[j])
5. Compute P: P[j,i] = exp(S[j,i] - new_max[j])
6. Store P to TMEM (BF16, same C-fragment composition store)
7. Update row sum: row_sum[j] = row_sum[j] * exp(old_max[j] - new_max[j]) + sum(P[j,:])
8. PV: O[j,:] += P[j,:] @ V[i,:]
After all tiles:
9. O[j,:] /= row_sum[j] (final normalization)
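As a numerical check on steps 1-9, here is a minimal pure-Python sketch that runs the per-tile recurrence and compares it with a plain softmax(QK^T)V over the whole KV block. It assumes Python floats (FP64) stand in for the FP32 TMEM accumulators, skips step 6 (the BF16 cast and TMEM store of P), and, like the step list, omits the usual 1/sqrt(head_dim) scaling of S. The function names are hypothetical and not part of the kernel.

```python
import math
from typing import List

Matrix = List[List[float]]


def online_softmax_attention(q: Matrix, k: Matrix, v: Matrix, tile_n: int) -> Matrix:
    """Stage C steps 1-9 for every query row, walking K/V in tiles of tile_n rows."""
    hd = len(v[0])
    out = []
    for q_row in q:                      # rows j are independent
        old_max = -math.inf              # running row max
        row_sum = 0.0                    # running sum of exp values
        o = [0.0] * hd                   # stands in for the O accumulator (assumption: FP64 here)
        for start in range(0, len(k), tile_n):
            k_tile, v_tile = k[start:start + tile_n], v[start:start + tile_n]
            s = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k_tile]  # 1. S = Q K^T
            new_max = max(old_max, max(s))                                     # 2-3. new row max
            scale = math.exp(old_max - new_max)  # 0.0 on the first tile (old_max = -inf)
            o = [x * scale for x in o]                                         # 4. rescale O
            p = [math.exp(x - new_max) for x in s]                             # 5. P
            row_sum = row_sum * scale + sum(p)                                 # 7. row sum
            for p_i, v_row in zip(p, v_tile):                                  # 8. O += P V
                for d in range(hd):
                    o[d] += p_i * v_row[d]
            old_max = new_max
        out.append([x / row_sum for x in o])                                   # 9. normalize
    return out


def reference_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Plain softmax(Q K^T) V over the whole KV block, for comparison."""
    out = []
    for q_row in q:
        s = [sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        z = sum(e)
        out.append([sum(e_i * v_row[d] for e_i, v_row in zip(e, v)) / z
                    for d in range(len(v[0]))])
    return out
```

Running it with tile_n = 1, with a tile size that leaves a partial last tile, and with one tile covering all of K exercises the rescale path in all three regimes; the results should agree with the reference to rounding error.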
Key Challenges
- Row max across tiles: track the per-row maximum across KV tiles and rescale O whenever a new max is found. This is the core of online softmax.
- Row sum accumulation: accumulate the sum of exp(S - row_max) across tiles, rescaling the running sum by exp(old_max - new_max) whenever the row max changes.
- FP32 precision: row max, rescale factors, and row sum must stay in FP32 for numerical stability. Only P (the exp values) is cast to BF16 for the TMEM store.
- O rescale in TMEM: when a new row max is found, the existing O in TMEM must be multiplied by exp(old_max - new_max). This requires loading O, rescaling it, and storing it back, using the same TMEM load/store machinery as the softmax P path.
- Final normalization: after all KV tiles, divide O by row_sum. This can be done as part of the epilogue.
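To make the FP32 requirement for row_sum concrete, the sketch below emulates BF16 by rounding the FP32 bit pattern to nearest-even on the 16 dropped bits (an assumption of this emulation: NaN and overflow are not handled) and accumulates a row of 1.0 values. Because P = exp(S - row_max) lies in (0, 1], a row of ones is an extreme but legal input. BF16 has 8 significant bits, so once the running sum reaches 256, adding 1.0 lands on a tie and rounds back to 256.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest BF16 value (round-to-nearest-even; NaN/overflow not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # FP32 bit pattern
    lsb = (bits >> 16) & 1                                # lowest bit that BF16 keeps
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000             # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def accumulate(values, round_fn) -> float:
    """Running sum where every partial result is rounded by round_fn."""
    acc = 0.0
    for v in values:
        acc = round_fn(acc + v)
    return acc


p_row = [1.0] * 1000                 # 1000 KV positions, each with P = 1.0
print(accumulate(p_row, to_fp32))    # 1000.0
print(accumulate(p_row, to_bf16))    # 256.0: the BF16 sum stalls
```

The same stall would hit an O accumulator kept in BF16, which is why only the P values handed to the PV MMA are narrowed.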
Expected Structure
The softmax epilogue warps will expand significantly:
- Currently: load S → convert BF16 → store P (identity softmax)
- After Stage C: load S → compute tile_max → compare with old_max → rescale O → compute exp → store P → update row_sum
The MMA loop remains the same (QK → softmax → PV per tile). The softmax just does more work between QK completion and PV start.
Stage D: Full Decode Attention
What We Have After Stage C
A working QK → real softmax → PV kernel for a single query sequence against a contiguous KV block. Fixed dimensions (128×128 QK, 128×64 PV), single CTA.
What We Need
- Paged KV cache: KV comes from a paged cache (fp8 or bf8 with per-token inverse scale), not a contiguous tensor. TMA loads must follow page tables.
- Multi-query: multiple query sequences are in flight, each with a different KV length. This requires grid dimensions > 1, possibly a persistent kernel.
- Causal masking: QK must mask future positions. For decode (1 query vs N KVs) this is trivial (no mask needed). For prefill, a causal mask is required.
- Variable sequence length: each CTA handles a different number of KV tiles, so the loop bound n_kv_tiles becomes dynamic.
- Multiple head dimensions: HEAD_DIM=16, 64, and 128 all need to work. Currently only HEAD_DIM=64 is tested.
- CSA/HCA sparse attention: for compress_ratio > 1, KV is read from the compressed cache instead of the full KV cache, with different TMEM layouts and different attention patterns.
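The index arithmetic behind the paged-cache, variable-length, and causal-mask items can be sketched on the host side as below. The page_table format (one physical page id per logical page), page_size, and all names are hypothetical; the real cache layout and the TMA/gather mechanics are not modeled.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class PagedSeq:
    kv_len: int            # valid KV tokens for this sequence (differs per sequence)
    page_table: List[int]  # hypothetical format: logical page index -> physical page id


def physical_slot(seq: PagedSeq, token: int, page_size: int) -> Tuple[int, int]:
    """Map a logical KV token index to (physical page id, row within that page)."""
    if not 0 <= token < seq.kv_len:
        raise IndexError(f"token {token} outside kv_len {seq.kv_len}")
    return seq.page_table[token // page_size], token % page_size


def n_kv_tiles(seq: PagedSeq, tile_n: int) -> int:
    """Dynamic loop bound: ceil(kv_len / tile_n), known only at run time."""
    return (seq.kv_len + tile_n - 1) // tile_n


def tile_slots(seq: PagedSeq, tile_idx: int, tile_n: int, page_size: int) -> List[Tuple[int, int]]:
    """Physical rows a gather-style load would fetch for one KV tile (the last may be partial)."""
    start = tile_idx * tile_n
    stop = min(start + tile_n, seq.kv_len)
    return [physical_slot(seq, t, page_size) for t in range(start, stop)]


def causal_keep(q_pos: int, kv_pos: int) -> bool:
    """Prefill causal mask: the query at q_pos may attend to keys at positions <= q_pos."""
    return kv_pos <= q_pos


seq = PagedSeq(kv_len=300, page_table=[7, 2, 9])
print(n_kv_tiles(seq, tile_n=128))             # 3
print(physical_slot(seq, 129, page_size=128))  # (2, 1)
print(len(tile_slots(seq, 2, 128, 128)))       # 44 valid rows in the partial last tile
```

For decode, the single query sits at position kv_len - 1, so causal_keep is true for every key, which matches the note that decode needs no mask.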
Key Question: Do We Need Stage D As A Separate Stage?
Stage D is really about scaling the Stage C kernel, not adding fundamentally new compute. The core pipeline (QK → softmax → PV) doesn't change. What changes is:
- Where the data comes from (paged cache vs contiguous tensor)
- How many CTAs run (grid size)
- Whether we need causal masking
This could be folded into the production kernel directly rather than being a separate test stage.
Stage E: Production Kernel
Goal
Replace cutedsl/blackwell_attention.py (pure PyTorch) with a CuTeDSL kernel that runs on the Blackwell tensor cores.
Steps
1. Extract the kernel from test_fmha_v3.py → cutedsl/kernel/attention/fmha_kernel.py
   - Class FmhaKernel with @cute.jit __call__
   - Clean parameter interface: Q, K, V, O tensors + config
   - No hardcoded dimensions: all derived from MMA shapes
2. Add real softmax (Stage C) to the extracted kernel
3. Add paged KV cache support (Stage D)
   - Page table TMA or gather-style loads
   - Per-sequence KV length tracking
4. Wrap as a PyTorch custom op → cutedsl/custom_ops.py
   - blackwell_fmha_forward(q, k, v, ...) -> o
   - Autograd support (or torch.compile integration)
   - torch.library custom op registration
5. Integrate with vLLM → vllm/attention/ops/blackwell_fmha.py
   - Replace the broken FlashMLA Blackwell path
   - Hook into vLLM's paged attention interface
   - Support both prefill and decode modes
6. Benchmark and tune
   - Profile against the PyTorch SDPA baseline
   - Tune tile sizes, pipeline stages, SMEM usage
   - Verify numerical accuracy vs. the reference across head dims and sequence lengths
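One way to honor "no hardcoded dimensions" is a small host-side config from which every MMA shape is derived. The sketch below is hypothetical (FmhaConfig and its fields are not existing code); it assumes one TMEM column per FP32 element of a 128-row accumulator, which matches the 128-column S and 64-column O in the layout diagram, and it encodes the one coupling that must always hold: the PV MMA's K extent is the QK MMA's N extent.

```python
from dataclasses import dataclass
from typing import Tuple

SUPPORTED_HEAD_DIMS = (16, 64, 128)  # the HEAD_DIM values Stage D says must work


@dataclass(frozen=True)
class FmhaConfig:
    """Hypothetical config: every kernel dimension is derived from head_dim and the tile shape."""
    head_dim: int
    tile_m: int = 128   # query rows per CTA tile (MMA M)
    tile_n: int = 128   # KV rows per tile

    def __post_init__(self) -> None:
        if self.head_dim not in SUPPORTED_HEAD_DIMS:
            raise ValueError(f"unsupported head_dim {self.head_dim}")

    @property
    def qk_mma_mnk(self) -> Tuple[int, int, int]:
        # S = Q @ K^T: (tile_m x head_dim) @ (head_dim x tile_n)
        return (self.tile_m, self.tile_n, self.head_dim)

    @property
    def pv_mma_mnk(self) -> Tuple[int, int, int]:
        # O += P @ V: (tile_m x tile_n) @ (tile_n x head_dim); PV's K is QK's N
        return (self.tile_m, self.head_dim, self.tile_n)

    @property
    def s_cols(self) -> int:
        return self.tile_n    # FP32 S accumulator columns (assumed 1 column per element)

    @property
    def o_cols(self) -> int:
        return self.head_dim  # FP32 O accumulator columns (assumed 1 column per element)


cfg = FmhaConfig(head_dim=64)
print(cfg.qk_mma_mnk, cfg.pv_mma_mnk)  # (128, 128, 64) (128, 64, 128)
```

With head_dim=64 this reproduces the current fixed 128×128 QK and 128×64 PV shapes, while head_dim=16 or 128 changes them without touching kernel code.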
File Structure (target)
cutedsl/
kernel/
attention/
__init__.py
fmha_kernel.py ← extracted, clean CuTeDSL kernel
fmha_softmax.py ← real softmax (Stage C)
fmha_epilogue.py ← row sum normalization, O output
blackwell_attention.py ← PyTorch reference (keep for testing)
custom_ops.py ← PyTorch custom op wrappers
TMEM Layout (Current)
Col: 0 32 64 96 128 192 256
|---- S ----|---- P ----| |---- O ----|
| QK acc | Softmax P | (gap) | PV acc |
| 128 FP32 | 64 FP32 | 32 col | 64 FP32 |
For Stage C, we'll need additional TMEM regions:
- row_max: per-row FP32 max (128 rows × 1 col = 128 FP32 values, can use 4 TMEM columns)
- row_sum: per-row FP32 sum (128 rows × 1 col, 4 TMEM columns)
- old_max: per-row FP32 previous max (4 TMEM columns)
These are tiny (4-8 TMEM columns each). They can go in the gap at columns 96-128 or after O.
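Adding regions is where a P/O-style overlap can creep back in, so it may help to plan TMEM columns explicitly on the host. The sketch below is a hypothetical planner: each region gets an explicit start column (never a footprint size used as an offset), any overlap is rejected (intentional aliases such as P are listed once), and the requested size is rounded up to a power of 2 because TmemAllocator.allocate() asserts that. The 512-column capacity and 32-column minimum are assumptions about Blackwell TMEM, and the offsets in the example are illustrative, not the kernel's actual layout.

```python
from typing import Dict, List, Tuple

TMEM_COLS = 512       # assumption: 512 columns x 128 lanes per SM
MIN_ALLOC_COLS = 32   # assumption: smallest legal allocation


def next_pow2(n: int) -> int:
    p = 1
    while p < n:
        p *= 2
    return p


def plan_tmem(regions: List[Tuple[str, int, int]]) -> Tuple[Dict[str, Tuple[int, int]], int]:
    """regions: (name, start_col, n_cols) with explicit start columns.

    Rejects any overlap and returns ({name: (start, end)}, columns to allocate).
    """
    spans: Dict[str, Tuple[int, int]] = {}
    for name, start, ncols in regions:
        end = start + ncols
        for other, (o_start, o_end) in spans.items():
            if start < o_end and o_start < end:
                raise ValueError(f"{name} [{start},{end}) overlaps {other} [{o_start},{o_end})")
        spans[name] = (start, end)
    used = max(end for _, end in spans.values())
    alloc = max(MIN_ALLOC_COLS, next_pow2(used))  # allocator requires a power of 2
    if alloc > TMEM_COLS:
        raise ValueError(f"needs {alloc} columns, TMEM has {TMEM_COLS}")
    return spans, alloc


# Illustrative offsets only.
spans, alloc = plan_tmem([
    ("S", 0, 128), ("O", 128, 64),
    ("row_max", 192, 4), ("row_sum", 196, 4), ("old_max", 200, 4),
])
print(alloc)  # 256
```

Placing O at column 64 in this planner raises immediately instead of silently corrupting S, which is the failure mode the lessons section describes.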
Key Lessons From Stages A & B
- NEVER use find_tmem_tensor_col_offset() as a TMEM placement decision. It returns the footprint size, not a safe column offset. The P/O overlap bug cost the entire Bug 4b debugging session.
- FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe. The DLPack shape (n, hd) has the wrong logical modes for the PV B-operand.
- TMEM allocation must be a power of 2. TmemAllocator.allocate() asserts this.
- P/A alias works. The QK C-fragment composition store and the PV A-fragment read alias the same physical TMEM columns. Proven for (128,64) and (128,128).
- Square hides bugs. (128,128) PV worked for every wrong approach because both dims are 128. Always test non-square cases.
- St32x32bOp MUST use Float32, NOT BFloat16. BFloat16 causes an illegal memory access.
- The first PV MMA must use ACCUMULATE=False. Otherwise uninitialized TMEM contents are added to the output.
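The V-layout lesson comes down to stride arithmetic, which can be checked without a GPU. The sketch below models a flat CuTe-style layout as offset = sum(coord × stride) and verifies that the row-major (s_k, hd) buffer from PyTorch, viewed as (hd, s_k) with strides (1, hd), addresses V[i, d] at B-coordinate (d, i), while a naive contiguous (hd, s_k) view does not. Non-square sizes are used on purpose, per the "square hides bugs" lesson. In CuTeDSL the corresponding view would presumably be built with something like cute.make_layout((hd, s_k), stride=(1, hd)); that call is not exercised here.

```python
from typing import Sequence, Tuple

S_K, HD = 128, 64  # deliberately non-square so a swapped mode cannot go unnoticed


def offset(coord: Sequence[int], stride: Sequence[int]) -> int:
    """Flat CuTe-style layout function: offset = sum(coord_i * stride_i)."""
    return sum(c * s for c, s in zip(coord, stride))


# V as handed over via DLPack: shape (s_k, hd), row-major, so strides (hd, 1).
V_STRIDE: Tuple[int, int] = (HD, 1)

# The same buffer reinterpreted as the PV B-operand: shape (hd, s_k), MN-major (hd mode stride 1).
B_STRIDE: Tuple[int, int] = (1, HD)

# A naive contiguous (hd, s_k) view of the same bytes: the wrong reconstruction.
NAIVE_STRIDE: Tuple[int, int] = (S_K, 1)


def matches(b_stride: Tuple[int, int]) -> bool:
    """True if B coordinate (d, i) addresses the same element as V[i, d] for every i, d."""
    return all(
        offset((d, i), b_stride) == offset((i, d), V_STRIDE)
        for i in range(S_K)
        for d in range(HD)
    )


print(matches(B_STRIDE))      # True
print(matches(NAIVE_STRIDE))  # False
```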