diff --git a/README.md b/README.md index 69180c5a..4c744727 100644 --- a/README.md +++ b/README.md @@ -2,132 +2,105 @@ CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use `cutlass.cute` (CuTeDSL) with Blackwell tensor cores. -## File Map +## Status (May 21, 2026 — 04:10 UTC) -``` -cutedsl/ -├── native_swa_decode.py # SWA decode attention — IN PROGRESS (v3 tcgen05 rewrite) -├── native_sparse_decode.py # Sparse (CSA/HCA) decode — NOT YET REWRITTEN -├── nvfp4_cutedsl.py # NVFP4 MoE runner (CuTeDSL) — WORKING -├── moe_pipeline.py # MoE fused SwiGLU pipeline — WORKING -├── blackwell_attention.py # vLLM bridge for Blackwell attention path -├── csa_attention.py # CSA/HCA sparse attention bridge -├── custom_ops.py # Custom CUDA ops registration -└── kernel/ - └── blockscaled_gemm/ - └── dense_blockscaled_gemm_persistent.py # REFERENCE: Blackwell TMEM/tcgen05 GEMM - -tests/ -├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM -├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + C-fragment softmax (runs, wrong output) -├── test_stage_b_afrag2.py # 🔨 Stage B: A-fragment store pattern (compiles, wrong output) -├── test_tmem_pure_fp32.py # ✅ FP32 ld→st roundtrip on C-fragment: cosine 0.999999 -├── test_bf16_elemwise.py # ✅ FP32→BF16→FP32 elemwise + FP32 st: cosine 0.999999 -├── test_recast_minimal.py # ✅ BF16 recast ld S0→st S1 via C-fragment: cosine 0.999999 -├── test_bf16_recast_simple.py # ❌ BF16 recast ld/st same region (S0): zero (can't overwrite MMA output) -├── test_tmem_copy_roundtrip.py # ❌ BF16 recast + C→A mismatch: zero -├── test_stage_b_final.py # ❌ C-fragment st + A-fragment read: NaN (physical layout mismatch) -├── test_afrag_roundtrip.py # ❌ A-frag st corrupts S0 (overlapping TMEM region) -├── diag_tmem.py # Diagnostic: TMEM layout inspection -└── ... 
-``` - -## Current Status - -### ✅ Stage A: Bare Q@K^T via tcgen05.mma — COMPLETE (May 20) +### ✅ Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM — COMPLETE **File**: `tests/test_stage_a_v2.py` **Result**: Q(128,128) @ K^T(128,128) → S(128,128), cosine 0.999999 -Validates the full tcgen05.mma → TMEM → epilogue → GMEM path: -- tcgen05.mma with BF16 inputs, FP32 TMEM accumulator -- TMA load for A and B (cute.nvgpu.make_tiled_tma_atom_A/B) -- TMA store for C (cpasync.CopyBulkTensorTileS2GOp) -- Warp specialization: 4 epilogue warps + 1 MMA warp + 1 TMA warp = 192 threads -- PipelineTmaUmma for AB pipeline, PipelineUmmaAsync for acc pipeline -- TmemAllocator for TMEM allocation/deallocation -- utils.gemm.sm100.epilogue_tma_store for the TMEM→reg→SMEM→TMA→GMEM epilogue +### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS -### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20-21) +**Pipeline deadlock: FIXED. Kernel runs without deadlock.** +**Bug 1 (V MN-major): Fix applied.** +**Bug 2 (softmax packing): Fix applied, but PV output is garbage.** -**Core Problem**: The C-fragment (MMA accumulator) and A-fragment (MMA A-operand from TMEM) use **different physical TMEM address mappings** for the same logical (M,K) position. The softmax writes P via one mapping, but the PV MMA reads via the other. This produces garbage. +#### Bug 1: V B-Operand Must Be MN-Major — ✅ FIX APPLIED -#### What's Been Proven +V must be shaped (head_dim, seq) = (64, 128) with strides (1, 64) — MN-major. +PV MMA uses `v_major` (OperandMajorMode.MN) instead of `b_major` (K). 
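
The MN-major requirement comes down to which mode of V has unit stride. The following is a minimal host-side sketch in plain Python, not CuTeDSL code: the helpers `contiguous_strides`, `major_mode`, and `offset` are hypothetical names introduced for illustration, and the "MN"/"K" strings only stand in for `OperandMajorMode`. It assumes the B operand is indexed as (N, K) = (head_dim, seq).

```python
def contiguous_strides(shape):
    """Row-major (C-order) element strides, the default for a fresh tensor."""
    strides, step = [], 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))


def major_mode(strides):
    """Classify a 2D (MN, K) operand by which mode has unit stride.

    Illustrative only: the returned strings stand in for the real enum.
    """
    mn_stride, k_stride = strides
    if mn_stride == 1:
        return "MN"
    if k_stride == 1:
        return "K"
    raise ValueError(f"no unit-stride mode in {strides}")


def offset(index, strides):
    """Linear element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


HEAD_DIM, SEQ = 64, 128

# A freshly allocated (head_dim, seq) tensor is contiguous along seq (the K mode).
default_strides = contiguous_strides((HEAD_DIM, SEQ))
print(default_strides, major_mode(default_strides))

# The layout the PV MMA needs: same shape, unit stride along head_dim.
mn_major_strides = (1, HEAD_DIM)
print(mn_major_strides, major_mode(mn_major_strides))

# The MN-major view aliases a contiguous (seq, head_dim) buffer element for
# element, so a strided view is enough and no copy is required.
seq_major_strides = contiguous_strides((SEQ, HEAD_DIM))
same_memory = all(
    offset((d, s), mn_major_strides) == offset((s, d), seq_major_strides)
    for d in range(HEAD_DIM)
    for s in range(SEQ)
)
print(same_memory)
```

Checking the strides on the host before launching the kernel is cheap, and it catches a K-major V before it turns into a silently wrong PV result.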
-| Test | Pattern | Result | Why | -|------|---------|--------|-----| -| test_tmem_pure_fp32 | FP32 ld→st, same C-fragment layout | ✅ cos=0.999999 | C-fragment addresses self-consistent | -| test_bf16_elemwise | FP32→BF16→FP32 elemwise, C-fragment st | ✅ cos=0.999999 | BF16 conversion works, C-fragment st works | -| test_recast_minimal | BF16 recast ld S0→st S1, C-fragment | ✅ cos=0.999999 | Recast works when writing to different region | -| test_bf16_recast_simple | BF16 recast ld/st same region S0 | ❌ zero | Can't overwrite MMA output in same region | -| test_stage_b_final | C-fragment st → A-fragment read (S1) | ❌ NaN | C-layout ≠ A-layout physical addresses | -| test_stage_b_afrag2 | A-fragment st (backward FMHA pattern) | ❌ cos=-0.02 | Store + PV MMA layout compatible, but register data flow wrong | +V must use `as_strided` — default PyTorch (64,128) gives strides (128,1) which is K-major. -#### Root Cause: C-fragment vs A-fragment Physical TMEM Layout +#### Bug 2 (Packing): C-Fragment Composition Store — ✅ APPLIED, ❌ PV OUTPUT WRONG -From the CUTLASS source (`mma_traits_sm100.hpp`): +FP32→BF16 packing via C-fragment composition store (FMHA pattern) runs without error. +The softmax packing overwrites part of S in TMEM (P at tmem_p0_offset=32 overlaps S at offset 0). +This is intentional — S is no longer needed after softmax. -**C-fragment (MMA accumulator, FP32):** -- Layout: `((128,128),1,1):((65536,1),0,0)` — **virtual** layout -- Physical TMEM addresses determined by the MMA hardware's accumulator write path -- St32x32bOp with C-fragment layout writes to C-fragment physical addresses +⛔ **FOOTGUN**: `St32x32bOp` MUST use Float32, NOT BFloat16. +⚠️ The recast view for P packing uses the LOAD layout (128 BF16 elements), not the store composition shape. 
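
To make the packing step concrete, here is a host-side sketch in plain Python of what "two BF16 values in one Float32 backing word" means at the bit level. It is not the kernel code: the helper names are hypothetical, the round-to-nearest-even conversion is an assumption about the intended rounding, and placing element 0 in the low 16 bits is an assumption about how the recast view orders the halves.

```python
import struct


def f32_bits(x):
    """Raw 32-bit pattern of a Python float rounded to FP32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def f32_to_bf16_bits(x):
    """FP32 -> BF16 bit pattern, round-to-nearest-even (NaN not handled)."""
    bits = f32_bits(x)
    rounding = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding) >> 16) & 0xFFFF


def bf16_bits_to_f32(b):
    """BF16 bit pattern -> float, by zero-filling the low 16 mantissa bits."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def pack_pair(lo, hi):
    """Pack two values into one 32-bit word; `lo` lands in the low half (assumed)."""
    return f32_to_bf16_bits(lo) | (f32_to_bf16_bits(hi) << 16)


def unpack_pair(word):
    """Inverse of pack_pair."""
    return bf16_bits_to_f32(word & 0xFFFF), bf16_bits_to_f32(word >> 16)


def pack_row(values):
    """Pack an even-length row of values into half as many 32-bit words."""
    if len(values) % 2:
        raise ValueError("row length must be even")
    return [pack_pair(values[i], values[i + 1]) for i in range(0, len(values), 2)]


word = pack_pair(1.0, -2.0)
print(hex(word))
print(unpack_pair(word))

# A 128-element BF16 row occupies 64 32-bit backing words.
row = [float(i) for i in range(128)]
print(len(pack_row(row)))
```

Under these assumptions the store atom only ever sees 32-bit words, which is consistent with the rule that `St32x32bOp` stays Float32 while the 16-bit type appears only in the recast view.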
-**A-fragment (MMA A-operand from TMEM, BF16, K-major, M=128):** -- Layout: `((128,16),1,4):((65536,1),0,16)` — **physical** TMEM layout -- A[m, k_inner] → `tmem[dp=m, col=base + 16*mma_k + k_inner]` -- BK=64 = 4 K=16 MMA atoms, NOT one K=64 atom -- The 4D fragment partition order is NOT the physical TMEM order +#### Bug 3 (NEW): PV MMA Output Is Garbage — 🔨 INVESTIGATING -**The St32x32bOp with C-fragment composition writes to C-layout physical addresses. The PV MMA reads from A-layout physical addresses. These are different physical locations.** - -#### Forward FMHA's Approach (FP16 Only!) - -Forward FMHA uses a recast pattern to pack 2×FP16 into 1×FP32 register, then St32x32bOp writes to a C-fragment composition subview. **But forward FMHA explicitly rejects BF16:** -```python -if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: - raise ValueError(in_dtype must be Float8E4M3FN or Float16) -``` -The recast softmax path is validated for FP16, NOT BF16. Our BF16 use is outside the tested path. - -#### Backward FMHA's Approach (BF16 Supported) - -Backward FMHA writes dV to TMEM using the A-fragment layout: -1. `tdVrP_iter = cute.recast_ptr(tSTtST.iterator, dtype=self.element_dtype)` — recast C-fragment iterator to BF16 -2. `tdVrP = cute.make_tensor(tdVrP_iter, tOrP.layout)` — A-fragment layout, C-fragment base -3. `tmem_store_atom = cute.make_copy_atom(St32x32bOp(Repetition(8)), self.element_dtype)` — BF16 store atom -4. Quantize via `make_rmem_tensor(input.shape, element_dtype)` + `.load()/.store(v.to(element_dtype))` — true BF16 register, NOT recast -5. Reshape: `cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape))` — match store partition shape - -This compiles and runs for us (no crash), but the output is still wrong (cosine -0.02). 
The remaining issue is the **register layout mismatch**: -- Load partition (C-fragment): 128 FP32 values per thread (full 128×128 QK tile) -- Store partition (A-fragment): 64 BF16 values per thread (128×64 P tile for PV MMA K=64) -- The backward FMHA uses `quantize()` + reshape, but our element counts differ because the QK tile is 128×128 while P only needs 128×64 - -#### Next Steps for Stage B - -1. **Fix the register data flow** — properly subselect the P-relevant 64 BF16 columns from the 128 FP32 load columns, or use the backward FMHA's PdO MMA tiler (M=128, N=64) instead of (M=128, N=128) -2. **Verify A-fragment store roundtrip** — write known BF16 values via A-fragment store, have PV MMA read them back via A-fragment, confirm the physical TMEM addresses match -3. **Once data flow is correct, add online softmax** (Stage C) +The PV MMA produces cosine ~0.01 against the reference. Suspected cause: TMEM layout mismatch between the softmax P store (C-fragment composition layout) and the PV MMA A-fragment read (`p_tmem_s` layout from `make_smem_layout_a`). These should alias the same physical TMEM columns by the sequential-flattening property, but the specific layout functions may compute different shapes/strides. ### 🔨 Stage C: Online Softmax — AFTER B -The hard part. Per the pseudocode: -- Epilogue warps tcgen05.ld scores from TMEM into register fragments -- Compute per-row: tile_max, new_max, rescale = exp(old_max - new_max) -- Apply rescale to tmem_output in place (tmem_output *= rescale) -- Compute exp(scores - new_max), tcgen05.st back to TMEM as P operand for MMA2 -- Update row_sum = row_sum * rescale + new_tile_sum - -**The register fragment layout from tcgen05.ld is NOT (row, col).** It's determined by the MMA instruction's partition of the accumulator. Need to figure out the mapping from fragment indices to logical (head, kv_pos) positions for per-row softmax operations. 
fmha.py uses `tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)` for the row max — a built-in reduction that handles the layout. +Per the pseudocode: epilogue warps compute per-row tile_max, rescale, exp, store P back to TMEM. ### 🔨 Stage D: FP8 Paged KV Gather — AFTER C -Replace BF16 TMA load of KV with: -- Indexed cp.async gather from paged KV cache (fp8) -- Per-position dequant scale (inv_scale) applied during or after gather -- Keep KV in fp8 in SMEM, let the MMA's per-row scale handle dequant (like blockscaled GEMM) +Replace BF16 TMA load with FP8 paged KV gather + per-position dequant. -### Architecture: Per-Tile Flow (from /root/fragile-kernel-example/README.md) +--- + +## Pipeline Deadlock — ✅ FIXED (May 21) + +v20-v25 all deadlocked on GPU. Three root causes found and fixed: + +### Fix 1: PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk + +FMHA's mma_s0/mma_s1 PipelineUmmaAsync calls do NOT pass cta_layout_vmnk. Removing it fixes the deadlock. + +### Fix 2: TMA Warp Must NOT Call tmem.wait_for_alloc() + +The tmem allocation barrier has `num_threads = 32 * (mma_warp + epilogue_warps)`. The TMA warp is NOT part of this barrier. Calling `wait_for_alloc()` from the TMA warp corrupts the barrier. + +### Fix 3: PipelineTmaStore (not TmaStorePipeline) + +`pipeline.TmaStorePipeline` does not exist. The correct name is `pipeline.PipelineTmaStore`. + +--- + +## ⛔ DEAD TEST: test_stage_b_v21.py — DELETED, DO NOT RECREATE + +v21 attempted both Bug 1 and Bug 2 fixes in a hand-rolled pipeline kernel. It deadlocks on GPU. Root cause: pipeline synchronization mismatch. **Do not recreate.** Write from scratch using fmha.py as the reference. + +--- + +## ⛔ FOOTGUNS — CUTLASS CuTeDSL Landmines + +### 1. St32x32bOp with 16-bit dtype → ILLEGAL MEMORY ACCESS + +`St32x32bOp(Repetition(N), BFloat16)` crashes at runtime. You MUST use `St32x32bOp(Repetition(N), Float32)` and pack 2×16-bit values into 1×Float32 backing words via `cute.recast_ptr`. 
The 16-bit type only appears in the recast view, never in the store atom itself. + +### 2. V B-Operand Major Mode ≠ K Major Mode + +FMHA requires `v_major_mode == OperandMajorMode.MN`. Passing K's K-major mode for V is WRONG. V must be shaped (head_dim, seq) with strides (1, head_dim) to produce MN-major. Standard PyTorch row-major (seq, head_dim) gives K-major. + +### 3. CuTe Nested Layout Modes Flatten Sequentially + +A layout like `((128,16),1,(4,2)):((65536,1),0,(16,64))` looks "non-sequential" but flattens to `addr = m*65536 + k` when k = k0 + 16*k1 + 64*k2 (CuTe row-major order). Do NOT assume nested modes imply non-sequential physical addressing. The C-fragment composition and A-fragment alias the same TMEM columns. + +### 4. PipelineUmmaAsync Consumer Group = Thread Count, NOT Warp Count + +```python +# WRONG: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4) +# CORRECT: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(warp_ids)) +``` + +### 5. PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk + +Passing `cta_layout_vmnk` to the mma_si PipelineUmmaAsync causes deadlock. FMHA does not pass it. Remove it. + +### 6. TMA Warp Must NOT Call tmem.wait_for_alloc() + +The tmem allocation barrier only includes MMA + epilogue warps. The TMA warp is excluded. Calling `wait_for_alloc()` from the TMA warp corrupts the barrier. + +--- + +## Architecture: Per-Tile Flow ``` For each KV tile: @@ -138,107 +111,57 @@ For each KV tile: a. tcgen05.ld scores from TMEM → register fragments b. Compute tile_max, new_max, rescale = exp(old_max - new_max) c. Apply rescale to tmem_output IN PLACE (tmem_output *= rescale) - d. tcgen05.st exp(scores - new_max) back to TMEM → now it's the P operand + d. tcgen05.st exp(scores - new_max) back to TMEM → P operand (via C-fragment composition) e. Release mma_si (softmax_done — MMA warp can re-acquire and issue PV MMA) - 4. 
MMA warp waits on mma_si acquire (softmax done), then MMA2: P @ sKV[stage] → tmem_output (accumulate=True) + 4. MMA warp waits on mma_si acquire (softmax done), MMA2: P @ sV → tmem_output (accumulate=True) 5. Stage released, load warp can refill it After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast to BF16, store to GMEM ``` -### ✅ NVFP4 MoE (CuTeDSL) — WORKING -- `nvfp4_cutedsl.py` + `moe_pipeline.py` -- CuTeDSL NVFP4 Linear (q_a, kv, q_b, o_b) — cosine 0.994+ -- CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) — cosine 0.988 -- Fused SwiGLU epilogue (granularity-8 weight interleave) — cosine 0.988 +--- -### ✅ FP8 KV Quantize/Dequant — WORKING -- FP8 KV: cosine 0.9997 -- NVFP4 KV: cosine 0.9943 (2x smaller than FP8) -- Paged KV cache read/write: cosine 1.0 +## Test Results -### ❌ Sparse Decode Attention — NOT YET REWRITTEN -`native_sparse_decode.py` still has the scalar FMA bug. Needs the same tcgen05.mma rewrite. +| File | Description | Cosine | Status | +|------|-------------|--------|--------| +| `test_stage_a_v2.py` | Q@K^T only | 0.999999 | ✅ PASS | +| `test_mma_si_only.py` | Q@K^T + mma_si pipeline (no PV) | 0.999999 | ✅ PASS | +| `test_softmax_only.py` | Q@K^T + softmax packing, output S | 0.52 | ❌ S overwritten by P (expected) | +| `test_mma_si_pv.py` | Q@K^T + softmax + P@V (V MN-major) | 0.01 | ❌ PV output garbage | +| `test_stage_b_v7.py` | Q@K^T + C-fragment softmax (V=K, wrong major) | -0.02 | ❌ wrong major + P packing | +| `test_stage_b_v20.py` | Q@K^T + softmax (V=K, PipelineTmaStore bug) | N/A | ❌ compile error | -### ✅ Full Attention Pipeline (standalone tests) — WORKING -- FP8 KV → full attention: cosine 0.9997 -- CSA sparse attention (cr=4): works -- HCA sparse attention (cr=128): works -- Merged CSA+SWA attention: works +--- ## Critical APIs & Lessons -### C-fragment ≠ A-fragment TMEM Physical Layout — THE MAY 20-21 FINDING - -**The St32x32bOp with C-fragment composition writes to C-layout physical TMEM addresses. 
The PV MMA reads from A-layout physical TMEM addresses. These are DIFFERENT physical locations for the same logical (M,K) position.** - -For the softmax to work, P must be written to TMEM using the A-fragment's physical layout, not the C-fragment's. The backward FMHA does this correctly by: -1. Creating the store destination with A-fragment layout + recast C-fragment iterator -2. Using a BF16 St32x32bOp atom -3. True BF16 register (not FP32 recast) via quantize() pattern - -### Forward FMHA Recast Pattern — FP16 ONLY - -The `cute.recast_ptr` + `.store(v.to(FP16))` pattern for packing 2×16-bit into 1×FP32 register is validated for FP16 only. BF16 is rejected in forward FMHA. The BF16 recast produces zero output when writing to the same TMEM region as the MMA output, and NaN when writing to a different region read via A-fragment. - -### PipelineUmmaAsync consumer group size — thread count, NOT warp count - -```python -# WRONG (caused CUDA_ERROR_LAUNCH_FAILED): -consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4) # warp count - -# CORRECT (matches fmha.py): -consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(softmax_warp_ids)) # thread count -``` - ### TMEM offset arithmetic +- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count +- QK accumulator: 128 TMEM columns +- A-fragment offset: `acc_dtype.width // q_dtype.width * tmem_p0_offset` (F32/BF16=2) -- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count (with 0x8000 tag for A-fragments) -- QK accumulator C fragment: 128 TMEM columns -- PV A-fragment: offset 0x8020 = tag(0x8000) + col(32) — the 0x8000 is a TMEM memory-space identifier -- `tOrP0 = cute.make_tensor(tOrP.iterator + acc_dtype.width // q_dtype.width * tmem_p0_offset, tOrP.layout)` — A-fragment offset scaled by dtype width ratio (F32/BF16 = 2) - -### A-fragment iterator must use recast C-fragment pointer - -When creating the P tensor for PV MMA's A-operand, the iterator 
must be the C-fragment's iterator recast to BF16: +### pv_mma_tiler — FMHA Convention ```python -tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) -tP = cute.make_tensor(tP_iter, p_tmem_s.outer) -tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] -``` -Without the recast, the A-fragment addresses are computed from an FP32 pointer base, giving wrong physical TMEM addresses (illegal memory access crash). - -### V SMEM aliasing (K and V share SMEM) - -```python -v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1) -sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) -sV = cute.make_tensor(sV_ptr, v_smem_s.outer) -tCrV = pv_mma.make_fragment_B(sV) +pv_mma_tiler = (qk_mma_tiler[0], qk_mma_tiler[2], qk_mma_tiler[1]) +# = (M, head_dim, QK_N) = (128, 64, 128) for head_dim=64 ``` -### `make_trivial_tiled_mma` has two overloads - +### make_trivial_tiled_mma — Use New Overload ```python -# New (preferred): make_trivial_tiled_mma(a_dtype, b_dtype, a_leading_mode, b_leading_mode, acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM) - -# Deprecated (still works, used by Stage A): -make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode, - acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM) ``` -### Other APIs discovered from Stage A +### 3D tensors required +Tensors must be 3D (M, K, L) for `cute.local_tile` — add L=1 dimension. -1. **`cute.Tensor` API** — `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)` -2. **3D tensors** — Tensors must be 3D (M, K, L) for `cute.local_tile` — add L=1 dimension -3. **`PipelineTmaUmma.create(...).make_participants()`** — returns `(producer, consumer)` pair -4. **`utils.gemm.sm100.epilogue_tma_store`** — handles transform + partition/dcopy. DO NOT hand-roll. -5. **`get_num_tmem_alloc_cols`** — correct TMEM allocation (accepts list of fragments, sums cols, rounds to power of 2) -6. **`smem.allocate_tensor()`** — for SMEM tensors (not SharedStorage struct for A/B/C) -7. 
**`LayoutEnum.from_tensor(a).mma_major_mode()`** — major mode from cute tensor -8. **Minimum valid N tile for tcgen05.mma BF16**: 32 (step 32, range 32-256) +### Other APIs +1. `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)` — CuTe tensor from PyTorch +2. `PipelineTmaUmma.create(...).make_participants()` — returns (producer, consumer) pair +3. `utils.gemm.sm100.epilogue_tma_store` — handles transform + partition/dcopy. DO NOT hand-roll. +4. `smem.allocate_tensor()` — for SMEM tensors +5. `LayoutEnum.from_tensor(a).mma_major_mode()` — major mode from cute tensor ## Environment @@ -247,15 +170,6 @@ make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode, - **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel` - **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4` - **vLLM repo**: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell) -- **Pseudocode**: `/root/fragile-kernel-example/README.md` — authoritative per-tile attention flow +- **Pseudocode**: `/root/fragile-kernel-example/README.md` - **fmha.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` - **fmha_bwd.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py` - -## 4-Stage Build Plan - -| Stage | Goal | Status | -|-------|------|--------| -| A | Bare Q@K^T via tcgen05.mma → TMEM → GMEM | ✅ COMPLETE | -| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 A-fragment store compiles, register data flow needs fixing | -| C | Online softmax between MMA1 and MMA2 (the hard part) | ⬜ TODO | -| D | FP8 paged KV gather + dequant (replace BF16 TMA load) | ⬜ TODO | diff --git a/tests/test_mma_si_only.py b/tests/test_mma_si_only.py new file mode 100644 index 00000000..83a86319 --- /dev/null +++ b/tests/test_mma_si_only.py @@ -0,0 +1,247 @@ +""" +Minimal test: Stage A + mma_si pipeline (no PV, no V). 
+If this deadlocks, the mma_si pipeline is broken. +If this passes, the deadlock is caused by adding V/PV. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class MmaSiTest: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, tiled_mma): + mma_inst_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + self.mma_tiler = (*self.mma_tiler_mn, mma_inst_k * 4) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (tiled_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], self.mma_tiler[2]) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(tiled_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(tiled_mma, self.mma_tiler, self.q_dtype, 1) + self.c_smem_s = 
utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(tiled_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = a.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + tiled_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + self._setup(tiled_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id), + a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id), + b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(tiled_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, 
self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, tiled_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(tiled_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] # ADDED: mma_si + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + # ADDED: mma_si pipeline (same as v27) + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = 
pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + thr_mma = tiled_mma.get_slice(0) + tCgA = thr_mma.partition_A(gA); tCgB = thr_mma.partition_B(gB); tCgC = thr_mma.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = tiled_mma.make_fragment_A(sA); tCrB = tiled_mma.make_fragment_B(sB) + acc_shape = thr_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # TMA WARP + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + 
for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_mma_si_pv.py b/tests/test_mma_si_pv.py new file mode 100644 index 00000000..26118bec --- /dev/null +++ b/tests/test_mma_si_pv.py @@ -0,0 +1,345 @@ +""" +Stage B test: MMA + mma_si + V TMA + PV MMA. +Built incrementally from working test_mma_si_only. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class MmaSiPvTest: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = 
self.qk_mma_tiler + + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + 
cute.size_in_bytes(self.q_dtype, b_smem) + + cute.size_in_bytes(self.q_dtype, v_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + + tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_q, tma_tq, 
tma_k, tma_tk, tma_v, tma_tv, + tma_c, tma_tc, self.cluster_layout_vmnk, + self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, + tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k) + cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + 
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None)) + gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gQ, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)] + + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), 
(None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ═══ TMA LOAD WARP ═══ + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_pv_mma_mn_major.py b/tests/test_pv_mma_mn_major.py new file mode 100644 index 00000000..066b8877 --- /dev/null +++ b/tests/test_pv_mma_mn_major.py @@ -0,0 +1,303 @@ +""" +Isolated test for Bug 1: PV MMA with V MN-major. 
+ +Only tests the PV MMA (P@V) with V as MN-major B-operand. +No QK MMA, no identity softmax, no pipeline complexity. +P comes from TMEM (a_source=TMEM), V comes from SMEM (b from TMA load). + +Architecture: + - TMA load V into SMEM + - P pre-populated in TMEM (via small QK MMA or direct write) + - PV MMA: P @ V → O in TMEM + - Epilogue: TMEM → GMEM + +For simplicity, P is computed via a QK MMA first (Q@K^T → P in TMEM), +then PV MMA uses P from TMEM. No softmax — identity pass-through. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class PvMmaTest: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.mma_warp_id = 0 + self.tma_warp_id = 1 + self.threads_per_cta = 64 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[pv_test] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[pv_test] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + + # Compute epilogue tile from PV output (not QK) + cta_tile_shape_mnk = ( + 
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 0 # P = S (identity softmax, same TMEM) + self.tmem_o0_offset = s_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = q.element_type; self.b_dtype = k.element_type; self.c_dtype = 
c.element_type + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + print(f"[pv_test] a_major (Q) = {self.a_major}") + print(f"[pv_test] b_major (K) = {self.b_major}") + print(f"[pv_test] v_major (V) = {self.v_major}") + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + # BUG 1 FIX: PV MMA uses V's MN-major mode + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + + tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv, + tma_c, tma_tc, self.cluster_layout_vmnk, + self.a_smem_s, self.b_smem_s, self.v_smem_s, 
self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, + tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=64) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=0, is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, 
byte_alignment=128, swizzle=c_smem_s.inner) + + gQ = cute.local_tile(mQ, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gK = cute.local_tile(mK, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gQ, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC) + + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)] + + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P from S TMEM — same location, MMA A-operand for PV + tP = cute.make_tensor(tStS.iterator, self.p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = tOrP # P is at same TMEM offset as S (identity softmax) + + tCtS_fake = 
qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # WARP 1: TMA load + if warp_idx == self.tma_warp_id: + tmem.wait_for_alloc() + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1 O[m, n] + # This is P @ V^T in matrix notation + # So reference: Q@K^T @ V^T where V^T is (128, 64) + ref = qf @ kf.T @ vf.T # (128,128) @ (128,64) = (128,64) + + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + kernel = PvMmaTest(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('PV MMA test (V MN-major, no softmax):') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git 
a/tests/test_softmax_only.py b/tests/test_softmax_only.py new file mode 100644 index 00000000..3de3f670 --- /dev/null +++ b/tests/test_softmax_only.py @@ -0,0 +1,288 @@ +""" +Test: QK + softmax packing only (no PV). +Output is the softmax-packed P (BF16) stored to GMEM via epilogue. +This tests whether the FP32→BF16 packing in TMEM works correctly. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class SoftmaxOnlyKernel: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, tiled_mma): + mma_inst_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + self.mma_tiler = (*self.mma_tiler_mn, mma_inst_k * 4) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (tiled_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], self.mma_tiler[2]) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; 
self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(tiled_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(tiled_mma, self.mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100") + self.tilePlikeFP32 = self.mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # BF16 P location in TMEM + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(tiled_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = a.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + tiled_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + self._setup(tiled_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id), + a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + 
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id), + b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(tiled_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, tiled_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(tiled_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + ).make_participants() + + acc_pipe 
= pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + thr_mma = tiled_mma.get_slice(0) + tCgA = thr_mma.partition_A(gA); tCgB = thr_mma.partition_B(gB); tCgC = thr_mma.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = 
tiled_mma.make_fragment_A(sA); tCrB = tiled_mma.make_fragment_B(sB) + acc_shape = thr_mma.partition_shape_C(self.mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # S in TMEM + tStS = thr_mma.make_fragment_C(acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # TMA WARP + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v13.py b/tests/test_stage_b_v13.py new file mode 100644 index 00000000..a6fd803b --- /dev/null +++ b/tests/test_stage_b_v13.py @@ -0,0 +1,401 @@ +""" +Stage B v13: Two MMAs + Identity Softmax using FMHA's C-fragment packing pattern. + +The key insight: the C→A "transform" is NOT a reordering — it's a PACKING. +When you write 128 BF16 values packed into 64 FP32 words via the C-fragment +composition (128, tilePlikeFP32) with St32x32bOp as FP32, the physical TMEM +locations used are exactly the ones the PV MMA's A-fragment reads from. + +FMHA pattern: +1. Load S from TMEM via C-fragment layout (FP32, 128×128) +2. Convert to BF16 and pack: FP32 backing tensor + BF16 recast view +3. Store packed FP32 backing to TMEM via C-fragment composition with St32x32bOp(FP32) +4. 
PV MMA reads from A-fragment TMEM — same physical locations as the packed BF16 + +Architecture: + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 packed → tcgen05.st C-fragment composition + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + # FMHA pattern: pv_mma_tiler = (M, QK_K, QK_N) — K=QK_N, N=head_dim + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + + 
self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # TMEM offsets (matching fmha.py) + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + + # FMHA TMEM layout: S0=0, P0=32, O0=256 + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + self.tmem_alloc_cols = s_cols + o_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols = {s_cols}, o_cols = {o_cols}") + print(f"[StageB] tmem_alloc_cols = {self.tmem_alloc_cols}") 
+ print(f"[StageB] num_tmem_alloc_cols = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, 
self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = 
pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares SMEM with B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + # ── TMEM tensors ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = 
qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # ── P A-fragment for PV MMA (matching fmha.py exactly) ── + # tP uses C-fragment iterator but A-fragment (p_tmem_s) layout + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # ── Softmax store uses C-fragment composition (FMHA pattern) ── + # tStS_P: C-fragment layout composed with (128, tilePlikeFP32) + # This is where we store the packed BF16 P values + tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tOrP.layout: {tOrP.layout}') + print(f'[DIAG] tOrP0.layout: {tOrP0.layout}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[DIAG] pv_mma_tiler: {self.pv_mma_tiler}') + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], 
tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v14.py b/tests/test_stage_b_v14.py new file mode 100644 index 00000000..ecd90772 --- /dev/null +++ b/tests/test_stage_b_v14.py @@ -0,0 +1,352 @@ +""" +Stage B v14: Two MMAs + Identity Softmax using backward FMHA's A-fragment store pattern. + +The backward FMHA writes P to TMEM using the A-fragment layout directly: +1. Load S from TMEM via C-fragment layout (FP32) +2. Quantize FP32 -> BF16 (make_rmem_tensor with BF16, load/store) +3. Reshape quantized to match store coordinate shape +4. Store via St32x32bOp(BF16) to A-fragment TMEM layout +5. PV MMA reads from the same A-fragment addresses +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = 
cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + self.tmem_alloc_cols = s_cols + o_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + 
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), 
block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + 
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + # TMEM tensors + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = 
pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P A-fragment (backward FMHA pattern) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tdVrP_base = pv_thr.make_fragment_A(tP) + tdVrP = tdVrP_base[(None, None, None, 0)] + tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) + tdVrP = cute.make_tensor(tdVrP_iter, tdVrP.layout) + tdVrP0 = cute.make_tensor( + tdVrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tdVrP.layout) + + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tdVrP.layout: {tdVrP.layout}') + print(f'[DIAG] tdVrP0.layout: {tdVrP0.layout}') + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # TMA WARP + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1 BF16 (backward FMHA pattern) + tRT_rST_bf16 = cute.make_rmem_tensor(tTR_rST.shape, self.q_dtype) + frg_cnt = 4 + frg_tile = cute.size(tTR_rST) // frg_cnt + tTR_rST_frg = cute.logical_divide(tTR_rST, cute.make_layout(frg_tile)) + tRT_rST_bf16_frg = cute.make_tensor(tRT_rST_bf16.iterator, tTR_rST_frg.layout) + for j in range(frg_cnt): + frg_vec = tTR_rST_frg[None, j].load() + tRT_rST_bf16_frg[None, j].store(frg_vec.to(self.q_dtype)) + + # 5. Reshape to match store coordinate shape + tRT_rST_reshaped = cute.make_tensor( + tRT_rST_bf16.iterator, cute.make_layout(tRT_cS.shape)) + + # 6. 
STORE to A-fragment TMEM + cute.copy(tiled_tmem_store, tRT_rST_reshaped, tRT_tP) + cute.arch.fence_view_async_tmem_store() + + si_handle.release() + + # Epilogue + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, + epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) + c_pipe.producer_tail() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +def test(): + torch.manual_seed(42) + m, n, k = 128, 128, 128 + q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float() + ref = qf @ kvf.T @ kvf + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v14: backward FMHA A-fragment store pattern (identity 
softmax)') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v16.py b/tests/test_stage_b_v16.py new file mode 100644 index 00000000..6b3a95be --- /dev/null +++ b/tests/test_stage_b_v16.py @@ -0,0 +1,453 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + 
self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = 
pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S: after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Original + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = 512 # FMHA: allocate max TMEM + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (hard-coded) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) 
* cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + # Introspect PV MMA atom (single quotes inside the f-strings keep them valid before Python 3.12) + print(f"[ATOM] PV MMA type: {type(pv_mma)}") + print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}") + print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}") + print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}") + print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}") + # Check a_src + print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}") + print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}") + print(f"[ATOM] PV MMA op: {pv_mma.op}") + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = 
cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + 
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = 
tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tdVrP0 = cute.make_tensor(tdVrP.iterator + dtype_width_ratio * tmem_p0_offset, tdVrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tdVrP_base = pv_thr.make_fragment_A(tP) + tdVrP = tdVrP_base[(None, None, None, 0)] + tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) + tdVrP = cute.make_tensor(tdVrP_iter, tdVrP.layout) + tdVrP0 = cute.make_tensor( + tdVrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tdVrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = 
cute.size(tdVrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tdVrP.layout: {tdVrP.layout}') + print(f'[LAYOUT] PV A-fragment tdVrP cosize: {cute.cosize(tdVrP.layout)}') + print(f'[LAYOUT] PV A-fragment tdVrP.size: {cute.size(tdVrP)}') + print(f'[LAYOUT] PV A-fragment tdVrP0.layout: {tdVrP0.layout}') + print(f'[LAYOUT] PV A-fragment tdVrP0 cosize: {cute.cosize(tdVrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + 
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, "element_type") else "N/A"}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, "element_type") else "N/A"}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, "iterator") else "cant compare"}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tdVrP0.size = {cute.size(tdVrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1 BF16 (backward FMHA pattern) + tRT_rST_bf16 = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTR_rST_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tRT_rST_bf16_frg = cute.make_tensor(tRT_rST_bf16.iterator, tTR_rST_frg.layout) + for j in range(frg_cnt): + frg_vec = tTR_rST_frg[None, j].load() + 
tRT_rST_bf16_frg[None, j].store(frg_vec.to(self.q_dtype)) + # 6. Reshape and store to A-fragment TMEM + tRT_rST_reshaped = cute.make_tensor( + tRT_rST_bf16.iterator, cute.make_layout(tRT_cS.shape)) + cute.copy(tiled_tmem_store, tRT_rST_reshaped, tRT_tP) + cute.arch.fence_view_async_tmem_store() + + # 7. Release back to MMA warp + si_handle.release() + + # ── Epilogue ── + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, + epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) + c_pipe.producer_tail() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +def test(): + torch.manual_seed(42) + m, n, k = 128, 128, 128 + q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float() + ref = qf @ kvf.T @ kvf + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = 
torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v17.py b/tests/test_stage_b_v17.py new file mode 100644 index 00000000..6bff0212 --- /dev/null +++ b/tests/test_stage_b_v17.py @@ -0,0 +1,450 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + 
self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = 
qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Original + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + 
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + # Introspect PV MMA atom + print(f"[ATOM] PV MMA type: {type(pv_mma)}") + print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}") + print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}") + print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}") + print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}") + # Check a_src + print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}") + print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}") + print(f"[ATOM] PV MMA op: {pv_mma.op}") + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, 
b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() 
+ + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + 
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, 
mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: 
{self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, "element_type") else "N/A"}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, "element_type") else "N/A"}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, "iterator") else "cant compare"}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v18.py b/tests/test_stage_b_v18.py new file mode 100644 index 00000000..faffa755 --- /dev/null +++ b/tests/test_stage_b_v18.py @@ -0,0 +1,452 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. 
+ +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + 
self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. 
+ # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Original + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = 512 # FMHA-style: allocate max TMEM (was s_cols + o_cols = 256) + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + 
self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + # Introspect PV MMA atom + print(f"[ATOM] PV MMA type: {type(pv_mma)}") + print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}") + print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}") + print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}") + print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}") + # Check a_src + print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}") + print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}") + print(f"[ATOM] PV MMA op: {pv_mma.op}") + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, 
c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], 
is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = 
qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = 
cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, "element_type") else "N/A"}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + 
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, "element_type") else "N/A"}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, "iterator") else "cannot compare"}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1 BF16 (backward FMHA pattern) + tRT_rST_bf16 = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype) + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTR_rST_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tRT_rST_bf16_frg = cute.make_tensor(tRT_rST_bf16.iterator, tTR_rST_frg.layout) + for j in range(frg_cnt): + frg_vec = tTR_rST_frg[None, j].load() + tRT_rST_bf16_frg[None, j].store(frg_vec.to(self.q_dtype)) + + # 6. Reshape and store to A-fragment TMEM + tRT_rST_reshaped = cute.make_tensor( + tRT_rST_bf16.iterator, cute.make_layout(tRT_cS.shape)) + cute.copy(tiled_tmem_store, tRT_rST_reshaped, tRT_tP) + cute.arch.fence_view_async_tmem_store() + + # 7. 
Release back to MMA warp + si_handle.release() + + # ── Epilogue ── + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, + epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) + c_pipe.producer_tail() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +def test(): + torch.manual_seed(42) + m, n, k = 128, 128, 128 + q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float() + ref = qf @ kvf.T @ kvf + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' 
{}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v19.py b/tests/test_stage_b_v19.py new file mode 100644 index 00000000..a523ee6e --- /dev/null +++ b/tests/test_stage_b_v19.py @@ -0,0 +1,450 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; 
self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 
for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Original + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = 
a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + # Introspect PV MMA atom + print(f"[ATOM] PV MMA type: {type(pv_mma)}") + print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}") + print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}") + print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}") + print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}") + # Check a_src + print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}") + print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}") + print(f"[ATOM] PV MMA op: {pv_mma.op}") + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, 
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + 
cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + 
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], 
self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = 
cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, "element_type") else "N/A"}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, "element_type") else "N/A"}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, "iterator") else "cannot compare"}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v20.py b/tests/test_stage_b_v20.py new file mode 100644 index 00000000..dbd704e3 --- /dev/null +++ b/tests/test_stage_b_v20.py @@ -0,0 +1,362 @@ +""" +Stage B v20: FMHA-matching test with head_dim=64, proper V layout. + +KEY INSIGHT: The A-fragment ((128,16),1,(4,2)):((65536,1),0,(16,64)) is SEQUENTIAL +when flattened in CuTe order: addr = m*65536 + k0 + 16*k1 + 64*k2 = m*65536 + k. +So the C-fragment composition store aliases the SAME TMEM as the A-fragment read. + +Previous -0.02 cosine was caused by V dimension mismatch: +pv_mma_tiler=(128,64,128) expects V with N=64 (head_dim), +but the square 128x128 test had V=K (N=128). 
+ +This test: Q=(128,64), K=(128,64), V=(64,128), O=(128,64) + +FOOTGUN: St32x32bOp MUST use Float32, NOT BFloat16! +The 16-bit values are packed via recast_ptr view. +St32x32bOp(BFloat16) causes ILLEGAL MEMORY ACCESS. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v20] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v20] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + 
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + self.tmem_alloc_cols = s_cols + o_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + 
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + # Separate TMA for V (uses pv_mma_tiler and v_smem layout) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_v, tma_tv, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_v, mV, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): 
+ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + 
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + # V partition (from mV with pv_mma_tiler and tma_v) + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + 
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1BF16 packing (EXACT FMHA pattern) + tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype) + tTMEM_STORErS_x4_e = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype), + tTMEM_LOADrS.layout) + + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tTMEM_STORErS_x4_e_frg = cute.logical_divide( + tTMEM_STORErS_x4_e, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + s_vec = tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + + cute.copy(tiled_tmem_store,
tTMEM_STORErS_x4, tTMEM_STOREtS_x4) + cute.arch.fence_view_async_tmem_store() + si_handle.release() + + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, + epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) + c_pipe.producer_tail() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +def test(): + torch.manual_seed(42) + # FMHA-matching dimensions: head_dim=64, seq=128 + # Q: (128, 64) K-major, K: (128, 64) K-major + # V: (64, 128) MN-major (transposed!); FMHA requires v_major_mode=OperandMajorMode.MN + m, n, head_dim = 128, 128, 64 + q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(head_dim, n, 1, dtype=torch.bfloat16, device='cuda') # Transposed!
+ c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float(); vf = v[:,:,0].float() + # Q@K^T = (128,128); V is stored as (head_dim, seq) = (64,128), so P@V = S @ V^T = (128,64) + ref = qf @ kvf.T @ vf.T + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v20: FMHA-matching head_dim=64, proper V layout') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v22.py b/tests/test_stage_b_v22.py new file mode 100644 index 00000000..bc711746 --- /dev/null +++ b/tests/test_stage_b_v22.py @@ -0,0 +1,314 @@ +""" +Stage B v22: Q@K^T → S in TMEM, P@V → O in TMEM (identity softmax = S used as P) + +Based on working Stage A v2 pattern. Two MMAs on the MMA warp, no softmax pipeline. +P stays in TMEM between QK and PV: no copy, no packing, just use S directly as P. + +Bug 1 fix: V is MN-major, PV MMA uses v_major (OperandMajorMode.MN). +The PV MMA's A-operand (P) comes from TMEM (same as S accumulator).
+""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentityKernel: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = self.o_dtype # alias for epilogue_tma_store + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v22] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v22] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = 
utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tmem_s0_offset = 0 + self.tmem_o0_offset = s_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, v_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = q.element_type; self.b_dtype = k.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() + self.c_layout = 
LayoutEnum.from_tensor(c) + + print(f"[v22] a_major (Q) = {self.a_major}") + print(f"[v22] b_major (K) = {self.b_major}") + print(f"[v22] v_major (V) = {self.v_major}") + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + # BUG 1 FIX: PV MMA b_leading_mode = v_major (MN), NOT b_major (K) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, # Bug 1 fix + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + + tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv, + tma_c, tma_tc, self.cluster_layout_vmnk, + self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, 
tma_v, mV, + tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k) + cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sK = 
smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None)) + gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gQ, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)] + + # V partition — pv_mma with pv_mma_tiler + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + # QK accumulator (S) in TMEM + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + # PV accumulator (O) in TMEM + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = 
pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P tensor for PV MMA A-operand (from TMEM, same as S) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + # P is at the same TMEM location as S (identity softmax: no copy/packing) + tOrP0 = tOrP + + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ═══ TMA LOAD WARP ═══ + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v22_bug1fix.py b/tests/test_stage_b_v22_bug1fix.py new file mode 100644 index 00000000..e6fdb8c3 --- /dev/null +++ b/tests/test_stage_b_v22_bug1fix.py @@ -0,0 +1,364 @@ +""" +Stage B v22: Bug 1 Fix — V B-Operand Must Be MN-Major + +Fix over v20: PV MMA uses V's major mode (MN) instead of K's major mode (K). +V is shaped (head_dim, seq) = (64, 128) with strides (1, 64) → OperandMajorMode.MN. +K is shaped (seq, head_dim) = (128, 64) with strides (64, 1) → OperandMajorMode.K. + +These are DIFFERENT. The PV MMA's b_leading_mode MUST come from V, not from K. + +Also: separate TMA descriptor for V (already in v20), separate SMEM layout for V. 
+""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v22] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v22] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s 
= utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + # V uses pv_mma with MN-major B — its own SMEM layout + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + self.tmem_alloc_cols = s_cols + o_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() # V's major mode — MN! 
+ self.c_layout = LayoutEnum.from_tensor(c) + + print(f"[v22] a_major (Q) = {self.a_major}") + print(f"[v22] b_major (K) = {self.b_major}") + print(f"[v22] v_major (V) = {self.v_major}") + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + # BUG 1 FIX: PV MMA uses V's major mode (MN), NOT K's major mode (K) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + # Separate TMA for V — uses pv_mma and pv_mma_tiler + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_v, tma_tv, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + 
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_v, mV, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * 
len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + # V partition — uses pv_mma, pv_mma_tiler, and tma_v + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = 
tVgV[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1BF16 packing (EXACT FMHA pattern) + tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype) + tTMEM_STORErS_x4_e = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype), + tTMEM_LOADrS.layout) + + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tTMEM_STORErS_x4_e_frg = cute.logical_divide( + tTMEM_STORErS_x4_e, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + s_vec = 
tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + + cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4) + cute.arch.fence_view_async_tmem_store() + si_handle.release() + + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, + epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) + c_pipe.producer_tail() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +def test(): + torch.manual_seed(42) + # FMHA-matching dimensions: head_dim=64, seq=128 + # Q: (128, 64) K-major, K: (128, 64) K-major + # V: (64, 128) MN-major — FMHA requires v_major_mode=OperandMajorMode.MN + m, n, head_dim = 128, 128, 64 + q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda') + v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda') + v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1) # MN-major: strides (1, 64) + c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float(); vf = v[:,:,0].float() + # Q@K^T = (128,128), P@V = (128,64) + ref = qf @ kvf.T @ vf.T + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = 
cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v22: Bug 1 fix — V MN-major') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v23.py b/tests/test_stage_b_v23.py new file mode 100644 index 00000000..6a9264ef --- /dev/null +++ b/tests/test_stage_b_v23.py @@ -0,0 +1,364 @@ +""" +Stage B v23: FMHA-matching test with head_dim=64, proper V layout. + +KEY INSIGHT: The A-fragment ((128,16),1,(4,2)):((65536,1),0,(16,64)) is SEQUENTIAL +when flattened in CuTe order: addr = m*65536 + k0 + 16*k1 + 64*k2 = m*65536 + k. +So the C-fragment composition store aliases the SAME TMEM as the A-fragment read. + +Previous -0.02 cosine was caused by V dimension mismatch: +pv_mma_tiler=(128,64,128) expects V with N=64 (head_dim), +but the square 128x128 test had V=K (N=128). + +This test: Q=(128,64), K=(128,64), V=(64,128), O=(128,64) + +FOOTGUN: St32x32bOp MUST use Float32, NOT BFloat16! +The 16-bit values are packed via recast_ptr view. +St32x32bOp(BFloat16) causes ILLEGAL MEMORY ACCESS. 
+""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v23] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v23] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s 
= utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + self.tmem_alloc_cols = s_cols + o_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() # Bug 1: V MN-major + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + 
self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, # Bug 1 fix: V MN-major + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + # Separate TMA for V (uses pv_mma_tiler and v_smem layout) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_v, tma_tv, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_v, mV, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = 
cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + 
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + # V partition (from mV with pv_mma_tiler and tma_v) + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = 
cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1BF16 packing (EXACT FMHA pattern) + tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype) + tTMEM_STORErS_x4_e = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype), + tTMEM_LOADrS.layout) + + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tTMEM_STORErS_x4_e_frg = cute.logical_divide( + tTMEM_STORErS_x4_e, cute.make_layout(frg_tile)) + for j in range(frg_cnt): + s_vec = tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + + cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4) + cute.arch.fence_view_async_tmem_store() + si_handle.release() + + tCtO_base =
cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, + epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) + c_pipe.producer_tail() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +def test(): + torch.manual_seed(42) + # FMHA-matching dimensions: head_dim=64, seq=128 + # Q: (128, 64) K-major, K: (128, 64) K-major + # V: (64, 128) MN-major (transposed!) — FMHA requires v_major_mode=OperandMajorMode.MN + m, n, head_dim = 128, 128, 64 + q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda') + v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda') + v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1) # MN-major: strides (1, 64) + c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float(); vf = v[:,:,0].float() + # Q@K^T = (128,128), P@V = (128,64) + ref = qf @ kvf.T @ vf.T # P@V with V MN-major = Q@K^T @ V^T + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled =
cute.compile(kernel, mQ, mK, mV, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v23: FMHA-matching head_dim=64, proper V layout') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v24.py b/tests/test_stage_b_v24.py new file mode 100644 index 00000000..ee2a94f6 --- /dev/null +++ b/tests/test_stage_b_v24.py @@ -0,0 +1,377 @@ +""" +Stage B v24: Q@K^T + identity softmax P packing + P@V +Bug 1 fix: V MN-major +Bug 2 fix: FP32->BF16 P packing (C-fragment composition store) +Pipeline fix: Use NamedBarrier instead of PipelineUmmaAsync for mma_si sync +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentityKernel: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 
2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + # Named barrier IDs for mma_si coordination + self.scores_ready_bar_id = 4 # MMA warp signals S ready + self.softmax_done_bar_id = 5 # Epilogue warps signal softmax done + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v24] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v24] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + 
self.tmem_p0_offset = 32 # P region in TMEM (for BF16 packing) + self.tmem_o0_offset = s_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = q.element_type; self.b_dtype = k.element_type + self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + print(f"[v24] a_major (Q) = {self.a_major}") + print(f"[v24] b_major (K) = {self.b_major}") + print(f"[v24] v_major (V) = {self.v_major}") + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + + tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), 
+ q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv, + tma_c, tma_tc, self.cluster_layout_vmnk, + self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, + tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k) + cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, 
cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + # Named barriers for MMA ↔ softmax coordination + scores_ready_bar = pipeline.NamedBarrier( + barrier_id=self.scores_ready_bar_id, + num_threads=32 * (1 + len(self.epilogue_warp_id))) # MMA warp + epilogue warps + softmax_done_bar = pipeline.NamedBarrier( + barrier_id=self.softmax_done_bar_id, + num_threads=32 * (1 + len(self.epilogue_warp_id))) # MMA warp + epilogue warps + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None)) + gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = 
cute.size(gQ, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)] + + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P from TMEM — at tmem_p0_offset, read by PV MMA as A-operand + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ═══ TMA LOAD WARP ═══ + if warp_idx == self.tma_warp_id: + 
tmem.wait_for_alloc() + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1 MMA computes P @ V^T + ref = qf @ kf.T @ vf.T # (128,64) + + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentityKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v24: Q@K^T + identity softmax (NamedBarrier) + P@V (V MN-major)') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v25.py b/tests/test_stage_b_v25.py new file mode 100644 index 00000000..da97566f --- /dev/null +++ b/tests/test_stage_b_v25.py @@ -0,0 +1,380 @@ +""" +Stage B v25: Q@K^T + identity softmax P packing + P@V +Uses PipelineAsync (not Umma) for mma_si sync — two separate one-shot pipelines. +Bug 1 fix: V MN-major. 
+Bug 2 fix: FP32→BF16 P packing (C-fragment composition store). +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentityKernel: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v25] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v25] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = 
utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, cute.slice_(self.v_smem_s, (None, None, None, 0))) # the V copy completes on the same ab barrier, so its bytes belong in tx_count + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = q.element_type; self.b_dtype = k.element_type + self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() + self.c_layout = 
LayoutEnum.from_tensor(c) + + print(f"[v25] a_major (Q) = {self.a_major}") + print(f"[v25] b_major (K) = {self.b_major}") + print(f"[v25] v_major (V) = {self.v_major}") + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + + tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv, + tma_c, tma_tc, self.cluster_layout_vmnk, + self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, + tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, 
epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k) + cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + # Two separate mbarriers for mma→softmax and softmax→mma signaling + scores_ready_bar: cute.struct.MemRange[cutlass.Int64, 2] # MMA→softmax + softmax_done_bar: cute.struct.MemRange[cutlass.Int64, 2] # softmax→MMA + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + # PipelineAsync for scores_ready (MMA → softmax) + scores_ready_prod, scores_ready_cons = 
pipeline.PipelineAsync.create( + barrier_storage=st.scores_ready_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + ).make_participants() + + # PipelineAsync for softmax_done (softmax → MMA) + softmax_done_prod, softmax_done_cons = pipeline.PipelineAsync.create( + barrier_storage=st.softmax_done_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + ).make_participants() + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None)) + gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gQ, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsK, tBgK = cpasync.tma_partition(tma_k, 0, 
b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)] + + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ═══ TMA LOAD WARP ═══ + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v26.py b/tests/test_stage_b_v26.py new file 
mode 100644 index 00000000..9d933015 --- /dev/null +++ b/tests/test_stage_b_v26.py @@ -0,0 +1,367 @@ +""" +Stage B v26: Q@K^T + identity softmax P packing + P@V +Uses SMEM spin-wait flags for mma↔softmax sync (no PipelineUmmaAsync). +Bug 1 fix: V MN-major. +Bug 2 fix: FP32→BF16 P packing (C-fragment composition store). +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentityKernel: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v26] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v26] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = 
cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, cute.slice_(self.v_smem_s, (None, None, None, 0))) # the V copy completes on the same ab barrier, so its bytes belong in tx_count + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = q.element_type; self.b_dtype = k.element_type + 
self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + + tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv, + tma_c, tma_tc, self.cluster_layout_vmnk, + self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, 
pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, + tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k) + cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + # Use acc_pipe's barrier for BOTH QK→softmax and PV→epilogue signaling + # The trick: re-acquire acc_pipe on the producer side after softmax + # This works because PipelineUmmaAsync with 1 stage blocks producer + # until consumer releases + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + 
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None)) + gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gQ, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)] + + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = 
cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ═══ TMA LOAD WARP (no tmem.wait_for_alloc!) ═══ + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v27.py b/tests/test_stage_b_v27.py new file mode 100644 index 00000000..c5292f5a --- /dev/null +++ b/tests/test_stage_b_v27.py @@ -0,0 +1,369 @@ +""" +Stage B v27: Based on v20 with Bug 1 fix (V MN-major). +Fixes over v20: + 1. V MN-major + pv_mma uses v_major + 2. PipelineTmaStore (not TmaStorePipeline) + 3. TMA warp does NOT call tmem.wait_for_alloc + 4. 
mma_si PipelineUmmaAsync without cta_layout_vmnk (like FMHA) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[v27] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[v27] pv_mma_tiler = {self.pv_mma_tiler}") + + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = 
utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, cute.slice_(self.v_smem_s, (None, None, None, 0))) # the V copy completes on the same ab barrier, so its bytes belong in tx_count + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.q_dtype = q.element_type; self.b_dtype = k.element_type + self.o_dtype = c.element_type; self.c_dtype = self.o_dtype + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() + self.c_layout = 
LayoutEnum.from_tensor(c) + + print(f"[v27] a_major (Q) = {self.a_major}") + print(f"[v27] b_major (K) = {self.b_major}") + print(f"[v27] v_major (V) = {self.v_major}") + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0)) + + tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id), + v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv, + tma_c, tma_tc, self.cluster_layout_vmnk, + self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, + tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, 
epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k) + cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + # mma_si pipeline — NO cta_layout_vmnk (matches FMHA) + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + 
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None)) + gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gQ, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + pv_thr = pv_mma.get_slice(0) + tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3)) + tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)] + + gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None)) + tCgV = pv_thr.partition_B(gV) + tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3)) + tVgV = tVgV[(None,0,None,0)] + + tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = 
qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ═══ TMA LOAD WARP (no tmem.wait_for_alloc) ═══ + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test()