biondizzle 3fb3c925af Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00
2026-05-19 09:37:38 +00:00

DSV4 NVFP4 Kernel

Status (May 21, 2026 — 15:40 UTC)

Stage  Status    Description
A      COMPLETE  Q@K^T via tcgen05.mma → TMEM → GMEM
B      COMPLETE  QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving)
C      🔨 NEXT   Real softmax: row max, exp, rescale, row sum
D      TODO      Full decode attention: paged KV cache, multi-query, causal mask
E      TODO      Production kernel: integrate into cutedsl/, PyTorch custom op, vLLM bridge

Where Things Live

Canonical Test Files (current working code)

File                             Stage          What It Tests
tests/test_fmha_v3.py            A+B            Full QK→softmax→PV with KV-tile interleaving. This is the canonical Stage B kernel.
tests/test_pv64_with_softmax.py  B              Single AB pipeline variant. Simpler, also passing.
tests/test_128_128_vdiag.py      A+B            (128,128) PV baseline. Square case.
tests/test_qkonly.py             A              QK with split Q/KV pipelines, no softmax, no PV.
tests/test_qk_softmax.py         A + partial B  QK + softmax writes P to TMEM. No PV.

Reference Implementations (non-CuTe)

File                            Description
cutedsl/blackwell_attention.py  Pure PyTorch reference: RoPE, KV cache, causal attention, SWA
cutedsl/csa_attention.py        CSA/HCA sparse attention reference
cutedsl/native_swa_decode.py    SWA decode reference

CuTeDSL Kernel Modules (production code)

File                              Description
cutedsl/kernel/moe/               NVFP4 MoE GEMM kernels (fused SwiGLU, grouped MM)
cutedsl/kernel/blockscaled_gemm/  Block-scaled GEMM
cutedsl/nvfp4_linear.py           NVFP4 linear layer wrapper
cutedsl/runner.py                 MoE runner
cutedsl/moe_pipeline.py           MoE pipeline orchestration

Obsolete Tests (do not use, can delete)

100+ test files from debugging stages A and B. Key patterns:

  • test_stage_b_v1.py through test_stage_b_v30.py — incremental Bug 4b debugging
  • test_128_16_*.py — early (128,16) PV attempts with wrong head_dim
  • test_tmem_*.py, test_bf16_*.py — standalone TMEM/copy debugging
  • test_pv64_no_softmax.py, test_pv64_fmha_v.py, test_pv64_kmajor_v.py — Bug 4b root cause isolation

These can be deleted once the canonical tests are stable and the kernel is extracted.


Stage C: Real Softmax

What We Have Now

Identity softmax: load S FP32 from TMEM, convert to BF16, store P back to TMEM. This proves the TMEM pipeline works but isn't a real softmax.

What We Need

FMHA-style online softmax per KV-tile:

For each KV tile:
  1. QK → S (FP32 in TMEM)
  2. Load S and compute the row max for this tile: tile_max[j] = max(S[j,:])
  3. Compute new row max: new_max[j] = max(old_max[j], tile_max[j])
  4. Rescale O: O[j,:] *= exp(old_max[j] - new_max[j])
  5. Compute P: P[j,i] = exp(S[j,i] - new_max[j])
  6. Store P to TMEM (BF16, same C-fragment composition store)
  7. Update row sum: row_sum[j] = row_sum[j] * exp(old_max[j] - new_max[j]) + sum(P[j,:])
  8. PV: O[j,:] += sum over i of P[j,i] * V[i,:]  (i ranges over this tile's KV rows)
After all tiles:
  9. O[j,:] /= row_sum[j]  (final normalization)
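The per-tile recurrence above can be sanity-checked on the host with a small pure-Python model. This is only a sketch for a single query row j, not the CuTeDSL kernel: the names `online_softmax_row` and `reference_row` are hypothetical, all arithmetic uses Python floats (no BF16 cast of P), and the running state is assumed to start at old_max = -inf and row_sum = 0.

```python
import math
import random


def online_softmax_row(scores, values, tile):
    """Attention output for one query row, processed tile by tile.

    scores: list of floats, S[j, :] for a single row j.
    values: list of rows, values[i] is V[i, :].
    tile:   number of KV positions handled per iteration.
    """
    head_dim = len(values[0])
    old_max = float("-inf")   # assumed initial running max
    row_sum = 0.0             # assumed initial running sum
    out = [0.0] * head_dim    # unnormalized O[j, :]

    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]

        new_max = max(old_max, max(s_tile))               # steps 2-3
        rescale = math.exp(old_max - new_max)             # exp(-inf) == 0.0 on the first tile
        p_tile = [math.exp(s - new_max) for s in s_tile]  # step 5
        row_sum = row_sum * rescale + sum(p_tile)         # step 7
        out = [
            o * rescale + sum(p * v[d] for p, v in zip(p_tile, v_tile))  # steps 4 and 8
            for d, o in enumerate(out)
        ]
        old_max = new_max

    return [o / row_sum for o in out]                     # step 9


def reference_row(scores, values):
    """Plain softmax(scores) @ values, computed in one pass."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    total = sum(p)
    return [sum(pi * v[d] for pi, v in zip(p, values)) / total
            for d in range(len(values[0]))]


if __name__ == "__main__":
    rng = random.Random(0)
    scores = [rng.uniform(-8.0, 8.0) for _ in range(300)]
    values = [[rng.uniform(-1.0, 1.0) for _ in range(8)] for _ in range(300)]
    tiled = online_softmax_row(scores, values, tile=128)
    full = reference_row(scores, values)
    print("max abs difference:", max(abs(a - b) for a, b in zip(tiled, full)))
```

A sequence length that is not a multiple of the tile size (300 vs 128 here) exercises the partial last tile, which is worth keeping in any host-side check.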

Key Challenges

  1. Row max across tiles — Must track per-row maximum across KV-tiles and rescale O when a new max is found. This is the core of online softmax.
  2. Row sum accumulation — Must accumulate the sum of exp(S - max) across tiles, rescaling the running sum whenever the row max changes.
  3. FP32 precision — Row max, rescale, and row sum must stay in FP32 for numerical stability. Only P (the exp values) gets cast to BF16 for the TMEM store.
  4. O rescale in TMEM — When a new row max is found, the existing O in TMEM must be multiplied by exp(old_max - new_max). This requires loading O, rescaling, and storing back. Same TMEM load/store machinery as softmax P.
  5. Final normalization — After all KV-tiles, divide O by row_sum. Can be done as part of the epilogue.
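To get a feel for what the BF16 cast of P costs (challenge 3), the rounding can be emulated on the host. The sketch below is an illustration only: `to_bf16` and `row_sum_error` are hypothetical helpers, Python floats (FP64) stand in for the FP32 statistics, and NaN/Inf handling is left out.

```python
import math
import struct


def to_bf16(x):
    """Round a finite Python float to the nearest BF16 value (ties to even).

    The result is returned as a Python float. NaN/Inf inputs are not handled.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # FP32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round to nearest, ties to even
    bits &= 0xFFFF0000                                    # keep sign, exponent, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def row_sum_error(scores):
    """Relative error of a row sum built from BF16-rounded P instead of full-precision P."""
    m = max(scores)
    p_full = [math.exp(s - m) for s in scores]  # stands in for the FP32 values
    p_bf16 = [to_bf16(p) for p in p_full]       # what would be stored to TMEM
    exact = sum(p_full)
    return abs(sum(p_bf16) - exact) / exact


if __name__ == "__main__":
    scores = [0.1 * k for k in range(128)]
    print("relative row-sum error with BF16 P:", row_sum_error(scores))
    print("per-element bound 2**-8:", 2.0 ** -8)
```

BF16 keeps 8 significant bits, so each stored P value is off by at most about 2**-8 in relative terms. That is tolerable for the MMA operand, but it is the reason the running max and sum are kept at higher precision.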

Expected Structure

The softmax epilogue warps will expand significantly:

  • Currently: load S → convert BF16 → store P (identity softmax)
  • After Stage C: load S → compute tile_max → compare with old_max → rescale O → compute exp → store P → update row_sum

The MMA loop remains the same (QK → softmax → PV per tile). The softmax just does more work between QK completion and PV start.


Stage D: Full Decode Attention

What We Have After Stage C

A working QK → real softmax → PV kernel for a single query sequence against a contiguous KV block. Fixed dimensions (128×128 QK, 128×64 PV), single CTA.

What We Need

  1. Paged KV cache — KV comes from a paged cache (fp8 or bf8 with per-token inverse scale), not a contiguous tensor. TMA loads must follow page tables.
  2. Multi-query — Multiple query sequences in flight, each with different KV lengths. Requires grid dimensions > 1, possibly persistent kernel.
  3. Causal masking — QK must mask future positions. For decode (1 query vs N KVs), this is trivial (no mask needed). For prefill, need a causal mask.
  4. Variable sequence length — Each CTA handles a different number of KV tiles. The loop bound n_kv_tiles becomes dynamic.
  5. Multiple head dimensions — HEAD_DIM=16, 64, 128 all need to work. Currently only HEAD_DIM=64 is tested.
  6. CSA/HCA sparse attention — For compress_ratio > 1, KV is read from compressed cache instead of full KV cache. Different TMEM layouts, different attention patterns.
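Items 1 and 4 above mostly come down to index arithmetic, which can be modeled on the host before any TMA work. The sketch below assumes a page table that is a per-sequence list of physical page indices and a cache stored as a list of fixed-size pages. Those formats and the names `num_kv_tiles` and `gather_kv` are hypothetical, and dequantization of the fp8 cache is not modeled.

```python
def num_kv_tiles(kv_len, tile):
    """Dynamic loop bound: number of KV tiles needed to cover kv_len positions."""
    return (kv_len + tile - 1) // tile


def gather_kv(pages, page_table, kv_len, page_size):
    """Return the logical KV rows 0..kv_len-1 for one sequence.

    pages:      list of physical pages, each a list of page_size rows.
    page_table: page_table[n] is the physical page holding logical page n.
    """
    rows = []
    for pos in range(kv_len):
        physical_page = page_table[pos // page_size]   # which page
        slot = pos % page_size                         # which row inside that page
        rows.append(pages[physical_page][slot])
    return rows


if __name__ == "__main__":
    page_size = 4
    # Three physical pages; each row is tagged (page, slot) so the gather order is visible.
    pages = [[(p, s) for s in range(page_size)] for p in range(3)]
    page_table = [2, 0]            # logical page 0 -> physical 2, logical page 1 -> physical 0
    print(gather_kv(pages, page_table, kv_len=5, page_size=page_size))
    print(num_kv_tiles(kv_len=300, tile=128))
```

A gather like this doubles as a reference: feeding its contiguous output to the existing contiguous-KV kernel gives the expected result for a paged run.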

Key Question: Do We Need Stage D As A Separate Stage?

Stage D is really about scaling the Stage C kernel, not adding fundamentally new compute. The core pipeline (QK → softmax → PV) doesn't change. What changes is:

  • Where the data comes from (paged cache vs contiguous tensor)
  • How many CTAs run (grid size)
  • Whether we need causal masking

This could be folded into the production kernel directly rather than being a separate test stage.
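The masking difference between decode and prefill is small enough to show directly. This sketch assumes the q_len queries are the last q_len positions of a kv_len-long sequence (so kv_len >= q_len). The helper name `causal_mask` is hypothetical.

```python
def causal_mask(q_len, kv_len):
    """mask[i][j] is True when query i may attend to KV position j.

    Assumes query i sits at absolute position (kv_len - q_len + i).
    """
    offset = kv_len - q_len
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]


if __name__ == "__main__":
    # Decode: one query against 5 cached positions, nothing is masked.
    print(causal_mask(1, 5))
    # Prefill: 3 queries against their own 3 positions, lower-triangular mask.
    for row in causal_mask(3, 3):
        print(row)
```

With q_len = 1 every entry is True, which is why the decode path can skip the mask entirely.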


Stage E: Production Kernel

Goal

Replace cutedsl/blackwell_attention.py (pure PyTorch) with a CuTeDSL kernel that runs on the Blackwell tensor cores.

Steps

  1. Extract kernel from test_fmha_v3.py → cutedsl/kernel/attention/fmha_kernel.py

    • Class FmhaKernel with @cute.jit __call__
    • Clean parameter interface: Q, K, V, O tensors + config
    • No hardcoded dimensions — all derived from MMA shapes
  2. Add real softmax (Stage C) to the extracted kernel

  3. Add paged KV cache support (Stage D)

    • Page table TMA or gather-style loads
    • Per-sequence KV length tracking
  4. Wrap as PyTorch custom op → cutedsl/custom_ops.py

    • blackwell_fmha_forward(q, k, v, ...) -> o
    • Autograd support (or torch.compile integration)
    • torch.library custom op registration
  5. Integrate with vLLM → vllm/attention/ops/blackwell_fmha.py

    • Replace the broken FlashMLA Blackwell path
    • Hook into vLLM's paged attention interface
    • Support both prefill and decode modes
  6. Benchmark and tune

    • Profile against PyTorch SDPA baseline
    • Tune tile sizes, pipeline stages, SMEM usage
    • Verify numerical accuracy vs reference across head dims and sequence lengths

File Structure (target)

cutedsl/
  kernel/
    attention/
      __init__.py
      fmha_kernel.py       ← extracted, clean CuTeDSL kernel
      fmha_softmax.py       ← real softmax (Stage C)
      fmha_epilogue.py      ← row sum normalization, O output
  blackwell_attention.py    ← PyTorch reference (keep for testing)
  custom_ops.py             ← PyTorch custom op wrappers

TMEM Layout (Current)

Col:  0          32          64          96          128         192        256
      |---- S ----|---- P ----|           |---- O ----|
      |  QK acc   | Softmax P |  (gap)    |  PV acc   |
      |  128 FP32 |  64 FP32  |  32 col   |  64 FP32  |

For Stage C, we'll need additional TMEM regions:

  • row_max — per-row FP32 max (128 rows × 1 col = 128 FP32 values, can use 4 TMEM columns)
  • row_sum — per-row FP32 sum (128 rows × 1 col, 4 TMEM columns)
  • old_max — per-row FP32 previous max (4 TMEM columns)

These are tiny (4-8 TMEM columns each). They can go in the gap at columns 96-128 or after O.
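When adding regions, it can help to plan column ranges explicitly and check them for overlap before touching the kernel. The following is a host-side sketch only: the helper names are hypothetical, the widths are copied from the labels in the diagram plus the 4-column estimate, and regions are laid out back to back with no aliasing, so the total is not the kernel's real footprint.

```python
from itertools import combinations


def next_power_of_two(n):
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def plan_columns(regions):
    """Lay regions out back to back.

    regions: list of (name, width_in_columns).
    Returns ({name: (start, end)}, allocation size rounded up to a power of two).
    """
    layout = {}
    cursor = 0
    for name, width in regions:
        layout[name] = (cursor, cursor + width)   # half-open column range
        cursor += width
    return layout, next_power_of_two(cursor)


def find_overlaps(layout):
    """Return every pair of regions whose half-open column ranges intersect."""
    return [
        (a, b)
        for (a, (a0, a1)), (b, (b0, b1)) in combinations(layout.items(), 2)
        if a0 < b1 and b0 < a1
    ]


if __name__ == "__main__":
    # Illustrative widths only.
    regions = [("S", 128), ("P", 64), ("O", 64),
               ("row_max", 4), ("row_sum", 4), ("old_max", 4)]
    layout, alloc_cols = plan_columns(regions)
    print(layout)
    print("columns to allocate:", alloc_cols)
    print("overlaps:", find_overlaps(layout))
    # A hand-placed layout where O starts inside P is flagged.
    print(find_overlaps({"P": (32, 96), "O": (64, 128)}))
```

Running the overlap check on hand-written offsets is cheap insurance against the kind of P/O overlap described in the lessons below.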


Key Lessons From Stages A & B

  1. NEVER use find_tmem_tensor_col_offset() to decide TMEM placement. It returns the footprint size, not a safe column offset. The P/O overlap bug consumed the entire Bug 4b debugging session.

  2. FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe. The DLPack shape (n, hd) has wrong logical modes for PV B-operand.

  3. The TMEM allocation size must be a power of 2. TmemAllocator.allocate() asserts this.

  4. P/A alias works. QK C-fragment composition store + PV A-fragment read alias the same physical TMEM columns. Proven for (128,64) and (128,128).

  5. Square hides bugs. (128,128) PV worked for every wrong approach because both dims are 128. Always test non-square cases.

  6. St32x32bOp MUST use Float32, NOT BFloat16. BFloat16 causes illegal memory access.

  7. The first PV MMA must use ACCUMULATE=False. Otherwise it adds uninitialized TMEM contents to the output.
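Lesson 5 is easy to reproduce with a toy stride bug. The example below is not the actual Bug 4b, only an illustration of the failure mode: a row-major offset computed with the wrong dimension is invisible when both dimensions are equal. The function names are hypothetical.

```python
def offset_correct(i, j, n_rows, n_cols):
    """Row-major offset of element (i, j)."""
    return i * n_cols + j


def offset_buggy(i, j, n_rows, n_cols):
    """Same offset, but with the row stride mistakenly taken from n_rows."""
    return i * n_rows + j


def mismatches(n_rows, n_cols):
    """Count elements whose buggy offset differs from the correct one."""
    return sum(
        offset_correct(i, j, n_rows, n_cols) != offset_buggy(i, j, n_rows, n_cols)
        for i in range(n_rows)
        for j in range(n_cols)
    )


if __name__ == "__main__":
    print("(128,128):", mismatches(128, 128))  # square case: the bug is hidden
    print("(128,64): ", mismatches(128, 64))   # non-square case: every row after the first is wrong
```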
