nvfp4-megamoe-kernel/SMEM_P_GUIDANCE_REQUEST.md


SMEM-P Copy Problem — CUTLASS CuTeDSL Guidance Request

Summary

I need to write a register→SMEM copy that writes P (attention weight) values from softmax warp registers to a swizzled SMEM buffer (sP), so the MMA warp can read P as the A-operand of a PV GEMM. The source register layout comes from a TMEM load partition (Ld32x32bOp), and the destination is a swizzled SMEM layout from make_smem_layout_a for the PV MMA's A-operand. I cannot find a make_tiled_copy* variant that correctly handles both the source and destination layouts simultaneously.

Architecture

  • 6-warp FMHA kernel on Blackwell (SM100)
    • Warps 0-3: Softmax + Epilogue (128 threads)
    • Warp 4: MMA (32 threads)
    • Warp 5: TMA (32 threads)
  • QK GEMM: (128, 128) @ (128, 128) → S(128, 128) — A,B from SMEM, C to TMEM
  • PV GEMM: P(128, 128) @ V(128, N) → O(128, N) — for hd>64, P from SMEM (A-operand), V from SMEM (B-operand), C to TMEM
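The warp split above amounts to a 192-thread CTA with role dispatch on the warp index. Here is a plain-Python sketch of that dispatch. The role names are illustrative labels and are not identifiers from the kernel:

```python
from collections import Counter

THREADS_PER_WARP = 32
NUM_WARPS = 6


def warp_role(warp_idx):
    """Role of a warp in the 6-warp layout (labels are illustrative only)."""
    if 0 <= warp_idx <= 3:
        return "softmax_epilogue"   # warps 0-3, 128 threads
    if warp_idx == 4:
        return "mma"                # warp 4 issues the QK and PV MMAs
    if warp_idx == 5:
        return "tma"                # warp 5 drives TMA
    raise ValueError(f"unexpected warp index {warp_idx}")


def role_of_thread(tid):
    return warp_role(tid // THREADS_PER_WARP)


if __name__ == "__main__":
    counts = Counter(role_of_thread(t) for t in range(NUM_WARPS * THREADS_PER_WARP))
    print(dict(counts))  # 128 softmax_epilogue, 32 mma, 32 tma
```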

The Data Flow

  1. QK GEMM writes S to TMEM (C-fragment layout)
  2. Softmax warps read S from TMEM via Ld32x32bOp(Repetition(32)) into registers
  3. Softmax computes P = exp2(S * scale_log2 - row_max) in registers
  4. [THIS IS THE PROBLEM] Softmax warps write P to sP (SMEM, swizzled, PV A-operand layout)
  5. MMA warp reads P from sP via pv_mma.make_fragment_A(sP) (standard SMEM A-operand load)
  6. MMA warp executes P @ V → O GEMM
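Step 3 relies on the identity exp(x) = 2^(x·log2 e), which lets the kernel fold the softmax scale and log2(e) into one multiplier. The following plain-Python check applies the identity to one row. It assumes that scale_log2 = softmax_scale × log2(e) and that row_max is kept in the same pre-scaled units. The document does not define either quantity, so both are assumptions here.

```python
import math


def softmax_numerators_exp2(s_row, softmax_scale):
    """Unnormalized softmax weights for one row of S, computed the exp2 way.

    Assumes scale_log2 = softmax_scale * log2(e) and that row_max is stored
    pre-multiplied by scale_log2 (an assumption, not taken from the kernel).
    """
    scale_log2 = softmax_scale * math.log2(math.e)
    row_max = max(s_row) * scale_log2
    return [2.0 ** (s * scale_log2 - row_max) for s in s_row]


def softmax_numerators_reference(s_row, softmax_scale):
    """Same quantity via the textbook exp(scale * (s - max)) formulation."""
    m = max(s_row)
    return [math.exp(softmax_scale * (s - m)) for s in s_row]


if __name__ == "__main__":
    row = [0.5, -1.25, 3.0, 2.0]
    scale = 1.0 / math.sqrt(128)  # e.g. 1/sqrt(head_dim) for head_dim = 128
    p_fast = softmax_numerators_exp2(row, scale)
    p_ref = softmax_numerators_reference(row, scale)
    assert all(math.isclose(a, b, rel_tol=1e-12) for a, b in zip(p_fast, p_ref))
    print(p_fast)  # the row maximum maps to exactly 1.0
```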

Layout Details (from diagnostics on B200)

QK C-fragment (TMEM accumulator):

tStS shape: ((128, 128), 1, 1)
tStS layout: ((128,128),1,1):((65536,1),0,0)

Softmax TMEM load partition (what each thread owns):

tTMEM_LOADtS shape: (((32, 32), 1), 4, 1, 1)  -- source partition
tTMEM_LOADcS shape: ((32, 1), 4, 1, 1)          -- coordinate partition
tTMEM_LOADcS layout: ((32,1),4,1,1):((1@1,0),32@1,0,0)

Each softmax thread has 128 S/P coordinates (32 × 4 fragments). Total: 128 threads × 128 elements = 16384 = 128×128 ✓
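Decoding the coordinate layout above shows which P entries a single thread holds. A stride written k@1 advances coordinate mode 1 (the column) by k. No stride touches mode 0, so within one thread only the column varies. The plain-Python sketch below decodes the layout. It assumes that, with the 32x32b TMEM load shape, softmax thread t reads TMEM lane t and therefore row t of S. The diagnostics do not print this mapping.

```python
# Per-thread view of tTMEM_LOADcS: shape ((32, 1), 4, 1, 1) with column strides
# ((1, 0), 32, 0, 0); every printed stride is "@1", i.e. it moves along columns.
VALUE_SHAPE, VALUE_STRIDE = (32, 1), (1, 0)
FRAG_SHAPE, FRAG_STRIDE = 4, 32


def owned_coords(thread_idx):
    """(row, col) coordinates of S/P held by one softmax thread, in CuTe order.

    Assumption (not from the diagnostics): thread t reads TMEM lane t, so its
    row is t. Columns follow the coordinate layout; mode 0 is fastest-varying.
    """
    row = thread_idx
    coords = []
    for frag in range(FRAG_SHAPE):
        for v1 in range(VALUE_SHAPE[1]):
            for v0 in range(VALUE_SHAPE[0]):
                col = v0 * VALUE_STRIDE[0] + v1 * VALUE_STRIDE[1] + frag * FRAG_STRIDE
                coords.append((row, col))
    return coords


if __name__ == "__main__":
    c = owned_coords(5)
    print(len(c), c[:3], c[-1])  # 128 values, all in row 5; value i is column i
```

Under that assumption, value i of a thread is simply column i of the thread's row, and the 128 threads together tile the full 128×128 matrix.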

sP (PV A-operand SMEM, swizzled):

sP shape: ((128, 16), 1, (4, 2), 1)
sP_2d shape: (((128, 16), 1, (4, 2)), 1)  rank=2
sP_2d layout: (((128,16),1,(4,2)),1):(((64,1),0,(16,8192)),0)
Swizzle: S<3,4,3>

Total elements: 128 × 16 × 4 × 2 = 16384 = 128×128 ✓
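The sP layout and swizzle can be evaluated by hand to see where each P element lands. This sketch rests on two assumptions that the diagnostics do not state:

- The flattened K coordinate splits as k = k0 + 16·kb + 64·atom, which matches the (16, (4, 2)) sub-modes and the strides (1, (16, 8192)).
- S<3,4,3> is applied to byte addresses, XOR-ing bits 7..9 into bits 4..6. This is the usual 128-byte swizzle when the swizzle rides on the SMEM pointer.

Under those assumptions, every run of 8 consecutive K values in one row (16 bytes of BF16) stays contiguous and 16-byte aligned after swizzling. Consecutive rows of one column are never adjacent.

```python
BF16_BYTES = 2


def swizzle_343(byte_offset):
    """CuTe Swizzle<3,4,3>: XOR bits [7, 10) into bits [4, 7).

    Assumption: applied to byte offsets (16-byte chunks within 128-byte rows).
    """
    return byte_offset ^ ((byte_offset >> 3) & (0b111 << 4))


def sp_linear_offset(m, k):
    """Element offset of P[m, k] in sP_2d before swizzling.

    Strides from the diagnostics: (((64, 1), 0, (16, 8192)), 0), with k
    assumed to split as k0 + 16 * kb + 64 * atom.
    """
    assert 0 <= m < 128 and 0 <= k < 128
    k0, kb, atom = k % 16, (k // 16) % 4, k // 64
    return m * 64 + k0 * 1 + kb * 16 + atom * 8192


def sp_byte_address(m, k):
    """Swizzled byte address of P[m, k] relative to the start of sP."""
    return swizzle_343(sp_linear_offset(m, k) * BF16_BYTES)


if __name__ == "__main__":
    print(sp_byte_address(0, 0), sp_byte_address(1, 0))  # 0 144
```

For example, sp_byte_address(1, 0) is 144 rather than 128, because row 1's 16-byte chunks are permuted by XOR with 1.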

PV A-operand fragments (what MMA warp reads from sP):

For SMEM-P, the MMA warp uses:

  • tCrP = pv_mma.make_fragment_A(sP) — SMEM load fragment
  • tOrP = pv_thr.make_fragment_A(sP)[(None,None,None,0)] — TMEM/SMEM read fragment

What I've Tried

Attempt 1: make_tiled_copy_C(copy_atom, qk_mma)

This creates a copy where the thread partition follows the QK MMA's C-fragment thread mapping.

Result: Source and destination partitions both have rank 3 (matching!):

partition_S(qk_C_2d) shape: ((8, (16, 128)), 128, 1)  rank=3
partition_D(sP_2d)   shape: ((8, (16, 128)), (64, 2), 1)  rank=3

Problem: The CopyUniversalOp with 128-bit vectorization fails:

'cute.copy' op cannot vectorize copy to 8 elements (static strides must be 1)
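Source strides: ((1,(8,128)),(16384,0))  -- leading size-8 mode: stride 1
Dest strides:   ((64,(512,0)),((1,8192),0))  -- leading size-8 mode: stride 64

Reading the printed strides, the leading size-8 mode has stride 1 on the source but stride 64 on the destination. This is the mode holding the values the copy tries to move as one 128-bit access. A stride of 64 elements is one full sP row, so on the destination side those 8 values run down a column of P (along M) and are not contiguous. The source's stride-8 sub-mode lies outside the 8-element vector. CopyUniversalOp can only vectorize a mode whose static stride is 1 on both sides, so the destination appears to be at least part of the failure.

Root cause: make_tiled_copy_C builds its thread partition from the QK MMA's C-fragment thread mapping, not from the 128 softmax threads. A tcgen05 MMA is issued by a single thread (the 1-SM CUTLASS C++ UMMA atoms have a ThrID of size 1). That mapping therefore most likely does not describe how P is actually spread across the softmax warps, and the resulting partition groups values along M rather than along K.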

Key insight: The softmax threads have P values in the TMEM load register layout (rP_bf16 shares layout with tTMEM_LOADrS), NOT the QK C-fragment register layout. These are different layouts with different stride patterns.

Attempt 2: get_smem_store_op + make_tiled_copy_D

This uses the cutlass.utils.blackwell_helpers pattern from the CUTLASS FMHA correction epilogue:

smem_store_atom = get_smem_store_op(c_layout, c_dtype, acc_dtype, tiled_tmem_load)
tiled_smem_store = cute.make_tiled_copy_D(smem_store_atom, tiled_tmem_load)

Problem: Size mismatch in mode-1:

partition_S(tTMEM_LOADtS) → source has 4 in mode-1
partition_D(sP)            → destination has 8 in mode-1

Root cause: the TMEM load partition and the sP SMEM layout tile the K dimension differently. tTMEM_LOADtS has shape (((32, 32), 1), 4, 1, 1), which is 4 fragments of 32 columns each. sP has shape ((128, 16), 1, (4, 2), 1), which is 8 K-blocks of 16 columns: 4 per 64-column swizzle atom, times 2 atoms. Both cover the same 128 columns. However, make_tiled_copy_D partitions the destination by sP's 8 K-blocks while the source keeps its 4 fragments, so mode 1 does not match (4 vs 8).
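The mode-1 mismatch is a difference in granularity, not in coverage: 4 TMEM-load fragments of 32 columns and 8 sP K-blocks of 16 columns span the same 128 columns. The plain-Python sketch below maps between the two. It assumes sP's (4, 2) K mode numbers K-blocks as kb + 4·atom, with k = k0 + 16·kb + 64·atom. This ordering is consistent with the strides (16, 8192), but the diagnostics do not print it explicitly.

```python
FRAG_COLS = 32          # columns per Ld32x32bOp(Repetition(32)) fragment
KBLOCK_COLS = 16        # K extent of one sP K-block
KBLOCKS_PER_ATOM = 4    # inner extent of sP's (4, 2) K mode


def column_of(frag, v):
    """Column held in value v (0..31) of TMEM-load fragment frag (0..3)."""
    return frag * FRAG_COLS + v


def kblock_coord(col):
    """sP K coordinate (k0, (kb, atom)), assuming k = k0 + 16*kb + 64*atom."""
    k0 = col % KBLOCK_COLS
    flat_kb = col // KBLOCK_COLS          # 0..7, the size-8 mode-1 index
    return k0, (flat_kb % KBLOCKS_PER_ATOM, flat_kb // KBLOCKS_PER_ATOM)


def kblocks_of_fragment(frag):
    """Flat K-block indices (0..7) touched by one source fragment."""
    return sorted({column_of(frag, v) // KBLOCK_COLS for v in range(FRAG_COLS)})


if __name__ == "__main__":
    for f in range(4):
        print(f, kblocks_of_fragment(f))   # fragment f covers K-blocks 2f and 2f+1
```

Each source fragment lands in exactly two consecutive K-blocks. Pairing the two partitions therefore means either splitting each 32-column fragment in half or viewing sP's K-blocks in pairs.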

The Core Question

How do I create a TiledCopy that:

  1. Source: Uses the TMEM load partition's thread mapping (128 softmax threads, each owning 128 P values in ((32, 1), 4, 1, 1) coordinate layout)
  2. Destination: Writes to sP in the PV A-operand SMEM layout (swizzled with S<3,4,3>)
  3. Copy atom: Handles the swizzled SMEM layout correctly

OR: Is there a completely different approach I'm missing?
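One way to reason about alternatives to make_tiled_copy_* starts from the fact that each softmax thread appears to hold one full row of P. If so, the store can in principle be expressed per thread as 16 chunks of 8 BF16 values, each going to one 16-byte-aligned slot of sP. The sketch below models only that index arithmetic on the host and is not CuTeDSL code. It makes three assumptions:

- thread t owns row t;
- sP uses the strides printed above, with k = k0 + 16·kb + 64·atom;
- S<3,4,3> acts on byte addresses.

The model writes every row chunk by chunk into a simulated SMEM buffer, then checks that reading back through the sP layout returns the original P. Integer tokens stand in for BF16 bit patterns.

```python
NUM_ROWS, NUM_COLS = 128, 128
VEC = 8                 # 8 x BF16 = 16 bytes, i.e. one 128-bit store
ELEM_BYTES = 2


def swizzle_343(addr):
    # Assumed byte-level Swizzle<3,4,3>: XOR bits [7,10) into bits [4,7).
    return addr ^ ((addr >> 3) & 0x70)


def sp_byte_addr(m, k):
    # Element strides: 64 for m, 1 for k0, 16 for kb, 8192 for atom.
    offset = m * 64 + (k % 16) + 16 * ((k // 16) % 4) + 8192 * (k // 64)
    return swizzle_343(offset * ELEM_BYTES)


def store_row(smem, thread_idx, row_values):
    """Model of one thread writing its row as 16-byte vector stores."""
    m = thread_idx                                  # assumed: thread t owns row t
    for k in range(0, NUM_COLS, VEC):
        base = sp_byte_addr(m, k)
        assert base % 16 == 0                       # 128-bit aligned chunk
        chunk = b"".join(v.to_bytes(ELEM_BYTES, "little") for v in row_values[k:k + VEC])
        smem[base:base + VEC * ELEM_BYTES] = chunk  # one contiguous 16-byte write


def load_via_layout(smem, m, k):
    """Model of the MMA side reading P[m, k] through the swizzled sP layout."""
    a = sp_byte_addr(m, k)
    return int.from_bytes(smem[a:a + ELEM_BYTES], "little")


if __name__ == "__main__":
    P = [[m * NUM_COLS + k for k in range(NUM_COLS)] for m in range(NUM_ROWS)]
    smem = bytearray(NUM_ROWS * NUM_COLS * ELEM_BYTES)
    for t in range(NUM_ROWS):
        store_row(smem, t, P[t])
    assert all(load_via_layout(smem, m, k) == P[m][k]
               for m in range(NUM_ROWS) for k in range(NUM_COLS))
    print("round trip ok")
```

The model ignores shared-memory bank conflicts and synchronization. Because the MMA reads sP through the async proxy, a real kernel would also need a proxy fence (fence.proxy.async.shared::cta in PTX) plus a barrier between the softmax stores and the MMA warp.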

The CUTLASS FMHA reference (12-warp layout) uses TMEM-P exclusively and doesn't have this problem. But our 6-warp layout requires SMEM-P for head_dim > 64 because TMEM can't hold S, P, and O simultaneously.

Environment

  • CuTeDSL (Python DSL for CUTLASS) on Blackwell (B200, SM100)
  • cutlass.utils.blackwell_helpers: get_tmem_load_op, get_smem_store_op
  • cutlass.utils.gemm.sm100: epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition, make_smem_layout_a, make_trivial_tiled_mma
  • tcgen05.copy: Ld32x32bOp, St32x32bOp
  • cute.nvgpu.CopyUniversalOp

Additional Context

The P matrix is (128, 128) — same dimensions as the QK attention matrix. The first 128 is the M-dimension (query heads), and the second 128 is the K-dimension of the PV GEMM (which equals n_kv_tokens, the number of key/value positions). For a single KV tile (n_kv=128), the P matrix is the full 128×128.

For head_dim=64 (TMEM-P, working), P is stored in TMEM at column offset tmem_p0_offset=32 using the register bridge pattern (FP32 exp2 values stored as BF16 via recast_ptr). For head_dim>64, P must go to SMEM instead.
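The register bridge relies on two BF16 values sharing one 32-bit word. A row of 128 BF16 P values therefore needs 64 of TMEM's 32-bit columns instead of 128. The plain-Python illustration below shows that packing. It assumes round-to-nearest-even FP32→BF16 conversion and little-endian element order, with element 2i in the low half of word i. The function names are hypothetical, and the snippet is not the kernel's code.

```python
import struct


def fp32_to_bf16_bits(x):
    """Round an FP32 value to BF16 (round-to-nearest-even); return the 16-bit pattern.

    NaN/overflow handling is omitted; softmax numerators are finite and in [0, 1].
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def bf16_bits_to_float(b):
    """Widen a BF16 bit pattern back to a Python float."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


def pack_row_bf16(values):
    """Pack FP32 values into 32-bit words, two BF16 per word (low half first)."""
    assert len(values) % 2 == 0
    halves = [fp32_to_bf16_bits(v) for v in values]
    return [halves[i] | (halves[i + 1] << 16) for i in range(0, len(halves), 2)]


if __name__ == "__main__":
    row = [2.0 ** (-i / 16) for i in range(128)]   # stand-in for exp2 outputs
    words = pack_row_bf16(row)
    print(len(words))                               # 64 words, i.e. 64 TMEM columns
```

With this packing, BF16 P for a 128-wide tile occupies half the TMEM columns of the FP32 S tile.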

The sP SMEM buffer is allocated with smem.allocate_tensor(element_type=BF16, layout=p_smem_s.outer, byte_alignment=128, swizzle=p_smem_s.inner) and has the PV MMA's A-operand SMEM layout, which is designed for pv_mma.make_fragment_A(sP) to read correctly during the PV GEMM.