nvfp4-megamoe-kernel/STAGE_D1.3.md
biondizzle 98e5b48470 Update all .md files with D5a/D5b progress, tOrP0 fix, LSE formula
- README.md: Updated Stage status table (D1 🟡, D5 🟢), D5 section with
  D5a/D5b results, tOrP0 bug fix docs, new CuTeDSL constraints #11-12
- STAGE_D1.3.md: Added progress update - TMEM-P works, SMEM-P still blocked,
  recommended next steps
- STAGE_D.md was already updated
2026-05-23 22:07:53 +00:00


STAGE_D1.3.md — SMEM-P Manual Addressing Notes

Problem Statement

Copy P (attention probabilities) from TMEM to SMEM with layout conversion:

  • Source: TMEM layout from QK C-fragment
  • Destination: SMEM layout for PV A-operand
  • Both represent the same 128×128 P matrix, tiled differently

Layouts Observed (from debug prints)

QK C-fragment Layout (TMEM)

tStS0 layout: ((128,128),1,1):((65536,1),0,0)
tStS0 shape: ((128, 128), 1, 1)

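Interpretation: shape (128,128) with stride (65536, 1). The row index has stride 65536 and the column index has stride 1, so consecutive columns are adjacent: this is row-major in the logical sense, not column-major. The row stride 65536 = 2^16 is consistent with TMEM addressing, where the lane index sits in the upper 16 bits of the address and the column index in the lower 16 bits, so row m most likely maps to TMEM lane m and column n to TMEM column n.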
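The stride reading above can be checked on the host with a few lines of plain Python. This is a minimal sketch, not CuTeDSL code: the helper names `tsts0_offset` and `split_tmem_address` are made up for illustration, and the `(lane << 16) | column` encoding is an assumption about how the printed stride relates to TMEM addresses.

```python
# Host-side sketch: evaluate the printed tStS0 layout by hand.
# Layout taken from the debug print: ((128,128),1,1):((65536,1),0,0)

LANE_SHIFT = 16  # assumption: TMEM address = (lane << 16) | column


def tsts0_offset(m: int, n: int) -> int:
    """Offset of logical element (m, n): inner product of coordinate and stride."""
    return m * 65536 + n * 1


def split_tmem_address(offset: int) -> tuple:
    """Split an offset into (lane, column) under the assumed encoding."""
    return offset >> LANE_SHIFT, offset & ((1 << LANE_SHIFT) - 1)


if __name__ == "__main__":
    for m, n in [(0, 0), (0, 1), (1, 0), (127, 127)]:
        off = tsts0_offset(m, n)
        print((m, n), hex(off), split_tmem_address(off))
```

If the assumption holds, every (m, n) decodes back to lane m and column n, which supports reading tStS0 as a view of the whole matrix rather than of one thread's slice.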

PV A-operand SMEM Layout

sP layout: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0)
sP shape: ((128, 16), 1, (4, 2), 1)

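Interpretation: four modes, (128,16), 1, (4,2), 1. Reading the strides, coordinate ((m, k0), 0, (k1, k2), 0) sits at element offset 64·m + k0 + 16·k1 + 8192·k2. In other words, the 128 columns split into two 128×64 row-major panels that are 8192 elements apart, and each 64-element row splits into four chunks of 16. This reading ignores any swizzle, which the printed layout does not show.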
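To make the nested modes concrete, the sP layout can be evaluated on the host as a plain inner product of coordinate and stride. The sketch below is pure Python, not a CuTeDSL call; `crd2offset` and `cosize` are illustrative names, and the shape and stride are copied from the debug print above. Swizzle is not modelled.

```python
# Host-side sketch: evaluate a nested (hierarchical) layout by hand.

SP_SHAPE = ((128, 16), 1, (4, 2), 1)
SP_STRIDE = ((64, 1), 0, (16, 8192), 0)


def crd2offset(coord, stride):
    """Inner product of a nested coordinate with a congruent nested stride."""
    if isinstance(coord, tuple):
        return sum(crd2offset(c, s) for c, s in zip(coord, stride))
    return coord * stride


def cosize(shape, stride):
    """Largest reachable offset plus one."""
    if isinstance(shape, tuple):
        return sum(cosize(s, d) - 1 for s, d in zip(shape, stride)) + 1
    return (shape - 1) * stride + 1


if __name__ == "__main__":
    print("cosize:", cosize(SP_SHAPE, SP_STRIDE))
    print("next row:", crd2offset(((1, 0), 0, (0, 0), 0), SP_STRIDE))
    print("next 16-chunk:", crd2offset(((0, 0), 0, (1, 0), 0), SP_STRIDE))
    print("next 64-panel:", crd2offset(((0, 0), 0, (0, 1), 0), SP_STRIDE))
```

The cosize equals 128×128, so the layout covers exactly one P matrix worth of elements with no padding.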

Failed Approaches

1. make_tiled_copy_C (Helpers are a trap)

  • Created tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
  • partition_S and partition_D produce incompatible tensors
  • Source partition: 65536 elements, rank 4
  • Destination partition: 2097152 elements, rank 5 (32× mismatch)
  • Confirmed: helpers assume compatible tiling, but layouts are incompatible

2. Manual Copy Attempts

Attempt 1: Get thread coordinates with cute.coord(tStS0) → AttributeError: module 'cutlass.cute' has no attribute 'coord'

Attempt 2: Simple test pattern with dynamic indexing → DSLRuntimeError: object cannot be interpreted as an integer

  • CuTeDSL JIT requires compile-time constants or vectorized loops
  • Dynamic SMEM offset computation not trivial

Key Unknowns Blocking Progress

  1. Coordinate Systems: How to get thread's logical position in QK C-fragment partition?
  2. Mapping Formula: Given QK C-fragment coordinate (i,j), what is corresponding PV A-operand coordinate?
  3. Offset Computation: How to compute SMEM offset from logical coordinate given a layout?
  4. Manual Pattern: What's the CuTeDSL pattern for manual layout conversion?

Hypotheses

Hypothesis A: Layouts represent same 128×128 data

  • QK C-fragment: tiles for QK MMA (C-fragment partitioning)
  • PV A-operand: tiles for PV MMA (A-operand partitioning)
  • Need coordinate transformation: QK tile → PV tile mapping

Hypothesis B: tStS0 is whole 128×128 matrix view

  • Not thread's slice but entire matrix in TMEM
  • Each thread accesses portion via its thread coordinates
  • Need thread's position within QK C-fragment partition

Hypothesis C: Manual transpose required

  • Implement 128×128 transpose between two tiling patterns
  • Could brute-force with nested loops over 128×128
  • But thread owns only subset (partitioned by QK MMA)

Questions for CUTLASS LLM

  1. Thread Coordinates: How to get thread's logical coordinates within a layout partition?

    • cute.coord() doesn't exist
    • Need: thread position in QK C-fragment partition
    • Possibly: cute.arch.thread_idx(), cute.arch.lane_idx(), or from MMA?
  2. Layout Mapping: Given source layout L1 and destination layout L2 for same logical space (128×128 matrix), how to compute coordinate transform?

    • L1: ((128,128),1,1):((65536,1),0,0) (QK C-fragment)
    • L2: ((128,16),1,(4,2),1):((64,1),0,(16,8192),0) (PV A-operand)
    • Need mapping: L1 coordinate → L2 coordinate
  3. Offset API: Is there cute.offset(layout, coord)? How to compute physical offset from logical coordinate given a layout?

  4. Manual Pattern: Example of manual copy between incompatible layouts in CuTeDSL?

    • Source: tensor with layout L1
    • Destination: tensor with layout L2
    • Manual element-wise copy with address computation
  5. MMA Partitioning: How does QK MMA partition the 128×128 C-fragment among threads?

    • Each thread owns subset of elements
    • Need to know which elements thread owns

Found Clue in vendored Code

In moe_torch_scaled_grouped_mm.py:

bidx, bidy, bidz = cute.arch.block_idx()
mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape)
  • Uses block index to compute MMA tile coordinate
  • tiled_mma.thr_id.shape gives thread partition shape
  • But we need thread coordinate within warp/CTA, not block

Current Code State

In fmha.py, SMEM-P section is:

else:
    # SMEM-P: Manual addressing (helpers are a trap)
    print(f"[SMEM-P MANUAL] Starting manual P write to SMEM")
    print(f"[SMEM-P MANUAL] tStS0 layout: {tStS0.layout}")
    print(f"[SMEM-P MANUAL] sP layout: {sP.layout}")
    
    # TODO: Implement proper mapping
    # For now, zero sP (wrong but compiles)
    for j in cutlass.range(cute.size(sP), vectorize=True):
        sP[j] = BFloat16(0.0)

Test Status

  • test_fmha_v3_stage_d1.py fails: cosine similarity against the reference is NaN for hd=64, -0.047 for hd=256, and 0.011 for hd=512
  • Zeroed sP causes PV MMA to read zeros → garbage output
  • TMEM-P path works for hd=64 but fails for hd>64 due to TMEM layout mismatch

Progress Update (2026-05-23 19:35 UTC)

CUTLASS LLM responded! Got complete solution:

Key Solution:

  1. Use make_identity_tensor(tStS0.shape) for coordinate tensor
  2. Partition coordinate tensor same way as data tensor
  3. Mapping formula: QK ((m, n), 0, 0) → PV ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)
  4. Use tensor indexing sP[dst_coord] = value, not manual offsets
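The mapping formula in step 3 can be sanity-checked on the host before it goes into the kernel. The sketch below is a plain-Python reconstruction from that formula, not the body of the kernel's `qk_to_pv_coord`; the stride tuple is copied from the sP debug print, `crd2offset` is an illustrative helper, and swizzle is ignored.

```python
# Host-side sketch: check that the QK -> PV coordinate mapping is a bijection
# onto the 128*128 offsets of the (un-swizzled) sP layout.

SP_STRIDE = ((64, 1), 0, (16, 8192), 0)


def qk_to_pv_coord(m: int, n: int):
    """QK ((m, n), 0, 0) -> PV ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)."""
    return ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0)


def crd2offset(coord, stride):
    """Inner product of a nested coordinate with a congruent nested stride."""
    if isinstance(coord, tuple):
        return sum(crd2offset(c, s) for c, s in zip(coord, stride))
    return coord * stride


def pv_offset(m: int, n: int) -> int:
    """Element offset in sP for logical P[m, n], before any swizzle."""
    return crd2offset(qk_to_pv_coord(m, n), SP_STRIDE)


if __name__ == "__main__":
    offsets = [pv_offset(m, n) for m in range(128) for n in range(128)]
    print("unique offsets:", len(set(offsets)), "of", len(offsets))
    print("closed form holds:", all(
        pv_offset(m, n) == 64 * m + (n % 64) + 8192 * (n // 64)
        for m in range(128) for n in range(128)))
```

On the host the formula is a bijection and matches the closed form 64·m + (n mod 64) + 8192·(n div 64), so if the kernel still lands values in the wrong place, the formula as an index computation is unlikely to be the only problem.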

Implementation Progress:

  • Coordinate mapping function implemented and works (qk_to_pv_coord)
  • Tensor indexing with coordinate works (sP[test_coord] = value)
  • Need to implement full 128-value mapping per thread
  • Need to get QK coordinates for each of thread's 128 P values

Implementation Status (2026-05-23 19:55 UTC)

Implemented full SMEM-P with coordinate mapping

  • Created coordinate tensor cS (already existed for row_max)
  • Partitioned as tTMEM_LOADcS_frg matching P value fragments
  • In softmax loop, for each (k,j):
    • Get QK coordinate (m,n) from tTMEM_LOADcS_frg[k,j]
    • Map to PV SMEM coordinate using formula
    • Write P value (or test pattern) to sP[pv_coord]

Result: Cosine ~0.02 (near zero correlation)

  • Kernel compiles and runs
  • PV reads SOMETHING from SMEM (output non-zero)
  • But mapping appears wrong (random correlation)
  • Output scaling huge (280k vs reference 0.2)

Possible Issues:

  1. Coordinate mapping formula wrong — PV A-operand layout might differ
  2. SMEM swizzle mismatch — tensor indexing might not handle swizzle correctly
  3. Thread collisions — Multiple threads writing same SMEM location
  4. P value normalization — Unnormalized P values cause scaling issues
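Issue 2 can be illustrated with a small host-side model of an XOR swizzle in the CuTe style (bits, base, shift). The parameters used in the demo are hypothetical: the notes do not record which swizzle, if any, the PV A-operand SMEM layout uses, so this only shows how far a write can land from its un-swizzled offset.

```python
# Host-side sketch: model of an XOR swizzle, to show how a writer and a
# reader that disagree about swizzling end up at different addresses.


def swizzle(offset: int, bits: int, base: int, shift: int) -> int:
    """XOR the `bits` bits at position base+shift into the bits at position base."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


def count_moved(num_elems: int, bits: int, base: int, shift: int) -> int:
    """Number of offsets whose swizzled address differs from the plain one."""
    return sum(1 for x in range(num_elems) if swizzle(x, bits, base, shift) != x)


if __name__ == "__main__":
    # Hypothetical parameters, chosen only for the demonstration.
    B, M, S = 3, 3, 3
    n = 128 * 128
    print("moved:", count_moved(n, B, M, S), "of", n)
    print("still a permutation:",
          sorted(swizzle(x, B, M, S) for x in range(n)) == list(range(n)))
```

A swizzle of this form is a permutation, so a write path that skips it while the MMA read path applies it would produce exactly the observed symptom: every value is stored once, but most are read from the wrong place.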

Debug Attempts:

  1. Test pattern (k+j)*0.01 → cosine 0.02
  2. Linear index m*128+n → cosine 0.006 (huge output as expected)
  3. Both patterns suggest the mapping is bijective but writes to the wrong locations

Next Actions:

  1. Verify coordinate mapping by computing SMEM offset manually
  2. Check if PV expects transposed P matrix
  3. Examine PV MMA tiler and SMEM layout generation
  4. Consider alternative: fix TMEM layout generation instead
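Next action 1 can be prepared on the host: build the SMEM image that the un-swizzled layout predicts for a test pattern, then compare it against a dump from the kernel. The sketch assumes a flat list `dump` of 16384 values copied out of sP by some means not shown here; all helper names are illustrative. One caveat that is certain: BFloat16 has 8 bits of significand precision, so a linear index m*128+n is not exactly representable above 256. The sketch therefore encodes only the row or only the column (both below 128) per run.

```python
# Host-side sketch: predict the flat sP image for a row-only or column-only
# test pattern and locate the first disagreement with a kernel dump.

ROWS, COLS = 128, 128


def pv_offset(m: int, n: int) -> int:
    """Un-swizzled sP offset for logical P[m, n], from the printed layout."""
    return 64 * m + (n % 64) + 8192 * (n // 64)


def expected_image(component: str):
    """Flat image if sP[pv_coord(m, n)] were written with m ('row') or n ('col')."""
    image = [None] * (ROWS * COLS)
    for m in range(ROWS):
        for n in range(COLS):
            image[pv_offset(m, n)] = m if component == "row" else n
    return image


def first_mismatch(dump, expected):
    """Return (offset, got, want) for the first differing element, else None."""
    for offset, (got, want) in enumerate(zip(dump, expected)):
        if got != want:
            return offset, got, want
    return None


if __name__ == "__main__":
    want = expected_image("col")
    fake_dump = list(want)          # stand-in for a real dump from the kernel
    fake_dump[16] = 99              # inject one wrong value for the demo
    print(first_mismatch(fake_dump, want))
```

Running the kernel once with the row pattern and once with the column pattern recovers the full (m, n) that landed at each offset, which separates a wrong formula from a swizzle or transpose mismatch.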

Progress Update (2026-05-23 21:30 UTC)

TMEM-P path now works at hd=64 (cos 0.973). The root cause of NaN/zeros was a missing TMEM column offset on tOrP0 — PV MMA was reading from column 0 (where S is) instead of column 32 (where P is stored by softmax warps). Fixed with const_expr conditional.

SMEM-P remains unsolved. The make_tiled_copy_C approach gives rank mismatch. Manual coordinate mapping compiles but produces near-zero cosine (wrong addresses). The CUTLASS reference FMHA uses TMEM-P exclusively (12-warp layout with more TMEM budget). For our 6-warp layout, SMEM-P is needed for hd>64.

Current D1 status:

  • hd=64 (TMEM-P): cos 0.973
  • hd=256 (SMEM-P stub): FAIL (zeros)
  • hd=512 (SMEM-P stub): FAIL (zeros)

Workaround for hd>64: The D5b milestone (Python SWA+sink merge) works at hd=64. SMEM-P for hd>64 is a production optimization, not a correctness blocker. The full DSV4 pipeline (CSA + HCA + SWA) can be tested at hd=64 with TMEM-P.

Key discoveries:

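  1. make_tiled_copy_C(store_atom, qk_mma) creates a copy that partitions threads by QK C-fragment layout, but the source and destination partitions have incompatible ranks (noted here as 4 vs 3; the earlier measurement under Failed Approaches recorded rank 4 vs rank 5 with a 32× element-count mismatch). This is a fundamental layout incompatibility.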
  2. Manual coordinate mapping (qk_to_pv_coord) compiles and runs but produces wrong results. The mapping formula may be incorrect, or SMEM swizzle may interfere with tensor indexing.
  3. The CUTLASS reference FMHA (12-warp) avoids SMEM-P entirely by using TMEM-P with more warps (more TMEM budget). A 12-warp layout would solve the SMEM-P problem architecturally.

Recommended next steps for SMEM-P:

  1. Try 12-warp layout (like CUTLASS reference) to avoid SMEM-P entirely
  2. OR: Use make_tiled_copy with pv_mma (not qk_mma) for the copy, since PV MMA knows the SMEM layout
  3. OR: Implement a two-stage copy: QK C-fragment → intermediate buffer → PV A-operand SMEM