diff --git a/README.md b/README.md index 7608f9fb..69180c5a 100644 --- a/README.md +++ b/README.md @@ -18,15 +18,17 @@ cutedsl/ └── dense_blockscaled_gemm_persistent.py # REFERENCE: Blackwell TMEM/tcgen05 GEMM tests/ -├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM -├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + identity softmax (runs, wrong output) -├── test_stage_b_minimal.py # ✅ Stage B minimal: two MMAs, no softmax (runs, NaN expected) -├── test_stage_b_pipeline_only.py # ✅ Stage B pipeline-only: PipelineUmmaAsync, no ld/st (runs, NaN expected) -├── diag_tmem.py # Diagnostic: TMEM layout inspection -├── test_stage_b_v6.py # ❌ Stage B v6 (hardcoded offsets, crashes) -├── test_stage_a_qk.py # ❌ Stage A v1 (broken, superseded by v2) -├── test_stage_a_minimal.py # ❌ Stage A minimal (broken, superseded by v2) -├── test_attention_path_b200.py # Full attention path test (uses naive BF16 attn) +├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM +├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + C-fragment softmax (runs, wrong output) +├── test_stage_b_afrag2.py # 🔨 Stage B: A-fragment store pattern (compiles, wrong output) +├── test_tmem_pure_fp32.py # ✅ FP32 ld→st roundtrip on C-fragment: cosine 0.999999 +├── test_bf16_elemwise.py # ✅ FP32→BF16→FP32 elemwise + FP32 st: cosine 0.999999 +├── test_recast_minimal.py # ✅ BF16 recast ld S0→st S1 via C-fragment: cosine 0.999999 +├── test_bf16_recast_simple.py # ❌ BF16 recast ld/st same region (S0): zero (can't overwrite MMA output) +├── test_tmem_copy_roundtrip.py # ❌ BF16 recast + C→A mismatch: zero +├── test_stage_b_final.py # ❌ C-fragment st + A-fragment read: NaN (physical layout mismatch) +├── test_afrag_roundtrip.py # ❌ A-frag st corrupts S0 (overlapping TMEM region) +├── diag_tmem.py # Diagnostic: TMEM layout inspection └── ... ``` @@ -46,68 +48,66 @@ Validates the full tcgen05.mma → TMEM → epilogue → GMEM path: - TmemAllocator for TMEM allocation/deallocation - utils.gemm.sm100.epilogue_tma_store for the TMEM→reg→SMEM→TMA→GMEM epilogue -### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20) +### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20-21) -**Latest**: `tests/test_stage_b_v7.py` -**Status**: Kernel compiles and runs without crashing. Identity softmax produces wrong output (cosine ≈ -0.02). +**Core Problem**: The C-fragment (MMA accumulator) and A-fragment (MMA A-operand from TMEM) use **different physical TMEM address mappings** for the same logical (M,K) position. The softmax writes P via one mapping, but the PV MMA reads via the other. This produces garbage. -**What was fixed today:** +#### What's Been Proven -1. **PipelineUmmaAsync consumer group size crash (THE bug):** - `PipelineUmmaAsync` with `Agent.Thread` requires **thread count** (128), NOT warp count (4), for the consumer group. fmha.py uses `32 * len(softmax_warp_ids) = 128`. Using 4 caused `CUDA_ERROR_LAUNCH_FAILED` (not a deadlock — the barrier reached wrong threshold causing illegal TMEM access). - - ```python - # WRONG (caused CUDA_ERROR_LAUNCH_FAILED): - consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4) - # CORRECT (matches fmha.py): - consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128) - ``` +| Test | Pattern | Result | Why | +|------|---------|--------|-----| +| test_tmem_pure_fp32 | FP32 ld→st, same C-fragment layout | ✅ cos=0.999999 | C-fragment addresses self-consistent | +| test_bf16_elemwise | FP32→BF16→FP32 elemwise, C-fragment st | ✅ cos=0.999999 | BF16 conversion works, C-fragment st works | +| test_recast_minimal | BF16 recast ld S0→st S1, C-fragment | ✅ cos=0.999999 | Recast works when writing to different region | +| test_bf16_recast_simple | BF16 recast ld/st same region S0 | ❌ zero | Can't overwrite MMA output in same region | +| test_stage_b_final | C-fragment st → A-fragment read (S1) | ❌ NaN | C-layout ≠ A-layout physical addresses | +| test_stage_b_afrag2 | A-fragment st (backward FMHA pattern) | ❌ cos=-0.02 | Store + PV MMA layout compatible, but register data flow wrong | -2. **TMEM offset computation (no more hardcoding):** - - `s_cols = find_tmem_tensor_col_offset(tStS) = 128` — QK accumulator physical TMEM columns - - `o_cols = find_tmem_tensor_col_offset(tOtO) = 128` — PV accumulator physical TMEM columns - - `tmem_s0_offset = 0, tmem_p0_offset = 32, tmem_o0_offset = 128` — matches fmha.py - - `find_tmem_tensor_col_offset(tOrP_sliced) = 32800 = 0x8020` — 0x8000 is TMEM space tag, column offset = 32 - - Total: 256 TMEM cols (verified by `get_num_tmem_alloc_cols`) +#### Root Cause: C-fragment vs A-fragment Physical TMEM Layout -3. **P fragment construction (matching fmha.py):** - ```python - tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) # A-layout from PV MMA - tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] - tOrP0 = cute.make_tensor(tOrP.iterator + 2 * tmem_p0_offset, tOrP.layout) - ``` - Previously used `cute.composition` on C-layout — wrong, must use PV MMA's A-layout. +From the CUTLASS source (`mma_traits_sm100.hpp`): -4. **V SMEM aliasing:** - V shares the same SMEM as K with a different layout interpretation: - ```python - sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) - sV = cute.make_tensor(sV_ptr, v_smem_s.outer) - tCrV = pv_mma.make_fragment_B(sV) # Uses MN-major V layout - ``` +**C-fragment (MMA accumulator, FP32):** +- Layout: `((128,128),1,1):((65536,1),0,0)` — **virtual** layout +- Physical TMEM addresses determined by the MMA hardware's accumulator write path +- St32x32bOp with C-fragment layout writes to C-fragment physical addresses -**What's still broken:** +**A-fragment (MMA A-operand from TMEM, BF16, K-major, M=128):** +- Layout: `((128,16),1,4):((65536,1),0,16)` — **physical** TMEM layout +- A[m, k_inner] → `tmem[dp=m, col=base + 16*mma_k + k_inner]` +- BK=64 = 4 K=16 MMA atoms, NOT one K=64 atom +- The 4D fragment partition order is NOT the physical TMEM order -The identity softmax C→A layout transform produces garbage output (cosine ≈ -0.02). The kernel runs, Stage A (Q@K^T) gives cosine 0.999999, but the full (Q@K^T)@V pipeline is wrong. The issue is in the tcgen05.ld/st identity softmax path — either the ld/st copy atoms, the register conversion, or the A-layout write positions are incorrect. +**The St32x32bOp with C-fragment composition writes to C-layout physical addresses. The PV MMA reads from A-layout physical addresses. These are different physical locations.** -**Bisection results:** -- ✅ Stage B minimal (no pipeline, no softmax): runs, NaN (expected — no C→A transform) -- ✅ Stage B pipeline-only (PipelineUmmaAsync, no ld/st): runs, NaN (expected) -- 🔨 Stage B full (identity softmax): runs, cosine -0.02 (wrong — softmax transform is broken) -- All three crash with consumer_group=4, all run with consumer_group=128 +#### Forward FMHA's Approach (FP16 Only!) -**TMEM layout diagnostic data:** +Forward FMHA uses a recast pattern to pack 2×FP16 into 1×FP32 register, then St32x32bOp writes to a C-fragment composition subview. **But forward FMHA explicitly rejects BF16:** +```python +if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: + raise ValueError(in_dtype must be Float8E4M3FN or Float16) ``` -QK accumulator C fragment: - tStS.layout = ((128,128),1,1):((65536,1),0,0) - cute.size = 16384, cute.cosize = 8323200 - find_tmem_tensor_col_offset = 128 +The recast softmax path is validated for FP16, NOT BF16. Our BF16 use is outside the tested path. -PV A-fragment (P operand): - tOrP_sliced.layout = ((128,16),1,4):((65536,1),0,16) - cute.size = 8192, cute.cosize = 8323136 - find_tmem_tensor_col_offset = 32800 = 0x8020 (0x8000 tag + col 32) -``` +#### Backward FMHA's Approach (BF16 Supported) + +Backward FMHA writes dV to TMEM using the A-fragment layout: +1. `tdVrP_iter = cute.recast_ptr(tSTtST.iterator, dtype=self.element_dtype)` — recast C-fragment iterator to BF16 +2. `tdVrP = cute.make_tensor(tdVrP_iter, tOrP.layout)` — A-fragment layout, C-fragment base +3. `tmem_store_atom = cute.make_copy_atom(St32x32bOp(Repetition(8)), self.element_dtype)` — BF16 store atom +4. Quantize via `make_rmem_tensor(input.shape, element_dtype)` + `.load()/.store(v.to(element_dtype))` — true BF16 register, NOT recast +5. Reshape: `cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape))` — match store partition shape + +This compiles and runs for us (no crash), but the output is still wrong (cosine -0.02). The remaining issue is the **register layout mismatch**: +- Load partition (C-fragment): 128 FP32 values per thread (full 128×128 QK tile) +- Store partition (A-fragment): 64 BF16 values per thread (128×64 P tile for PV MMA K=64) +- The backward FMHA uses `quantize()` + reshape, but our element counts differ because the QK tile is 128×128 while P only needs 128×64 + +#### Next Steps for Stage B + +1. **Fix the register data flow** — properly subselect the P-relevant 64 BF16 columns from the 128 FP32 load columns, or use the backward FMHA's PdO MMA tiler (M=128, N=64) instead of (M=128, N=128) +2. **Verify A-fragment store roundtrip** — write known BF16 values via A-fragment store, have PV MMA read them back via A-fragment, confirm the physical TMEM addresses match +3. **Once data flow is correct, add online softmax** (Stage C) ### 🔨 Stage C: Online Softmax — AFTER B @@ -168,9 +168,20 @@ After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast ## Critical APIs & Lessons -### PipelineUmmaAsync consumer group size — THE MAY 20 BUG +### C-fragment ≠ A-fragment TMEM Physical Layout — THE MAY 20-21 FINDING -**For `Agent.Thread` groups in `PipelineUmmaAsync`: use thread count, NOT warp count.** +**The St32x32bOp with C-fragment composition writes to C-layout physical TMEM addresses. The PV MMA reads from A-layout physical TMEM addresses. These are DIFFERENT physical locations for the same logical (M,K) position.** + +For the softmax to work, P must be written to TMEM using the A-fragment's physical layout, not the C-fragment's. The backward FMHA does this correctly by: +1. Creating the store destination with A-fragment layout + recast C-fragment iterator +2. Using a BF16 St32x32bOp atom +3. True BF16 register (not FP32 recast) via quantize() pattern + +### Forward FMHA Recast Pattern — FP16 ONLY + +The `cute.recast_ptr` + `.store(v.to(FP16))` pattern for packing 2×16-bit into 1×FP32 register is validated for FP16 only. BF16 is rejected in forward FMHA. The BF16 recast produces zero output when writing to the same TMEM region as the MMA output, and NaN when writing to a different region read via A-fragment. + +### PipelineUmmaAsync consumer group size — thread count, NOT warp count ```python # WRONG (caused CUDA_ERROR_LAUNCH_FAILED): @@ -180,18 +191,32 @@ consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4) # warp count consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(softmax_warp_ids)) # thread count ``` -This applies to ALL PipelineUmmaAsync consumers where the consumer is multiple warps. fmha.py line 671: `self.threads_per_warp * len(self.softmax0_warp_ids) = 32 * 4 = 128`. - -**Note:** The earlier README incorrectly stated that warp count was correct. That was wrong. The `Agent.Thread` agent type measures group size in threads. - ### TMEM offset arithmetic - `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count (with 0x8000 tag for A-fragments) - QK accumulator C fragment: 128 TMEM columns - PV A-fragment: offset 0x8020 = tag(0x8000) + col(32) — the 0x8000 is a TMEM memory-space identifier -- P OVERLAPS S in TMEM — P is written at column 32 within the S region (C-layout columns 0..127) - `tOrP0 = cute.make_tensor(tOrP.iterator + acc_dtype.width // q_dtype.width * tmem_p0_offset, tOrP.layout)` — A-fragment offset scaled by dtype width ratio (F32/BF16 = 2) +### A-fragment iterator must use recast C-fragment pointer + +When creating the P tensor for PV MMA's A-operand, the iterator must be the C-fragment's iterator recast to BF16: +```python +tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) +tP = cute.make_tensor(tP_iter, p_tmem_s.outer) +tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] +``` +Without the recast, the A-fragment addresses are computed from an FP32 pointer base, giving wrong physical TMEM addresses (illegal memory access crash). + +### V SMEM aliasing (K and V share SMEM) + +```python +v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1) +sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) +sV = cute.make_tensor(sV_ptr, v_smem_s.outer) +tCrV = pv_mma.make_fragment_B(sV) +``` + ### `make_trivial_tiled_mma` has two overloads ```python @@ -204,16 +229,6 @@ make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode, acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM) ``` -### V SMEM aliasing (K and V share SMEM) - -```python -# K and V share the same SMEM buffer, but with different layouts: -v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1) -sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) -sV = cute.make_tensor(sV_ptr, v_smem_s.outer) -tCrV = pv_mma.make_fragment_B(sV) -``` - ### Other APIs discovered from Stage A 1. **`cute.Tensor` API** — `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)` @@ -234,12 +249,13 @@ tCrV = pv_mma.make_fragment_B(sV) - **vLLM repo**: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell) - **Pseudocode**: `/root/fragile-kernel-example/README.md` — authoritative per-tile attention flow - **fmha.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` +- **fmha_bwd.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py` ## 4-Stage Build Plan | Stage | Goal | Status | |-------|------|--------| | A | Bare Q@K^T via tcgen05.mma → TMEM → GMEM | ✅ COMPLETE | -| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 Runs without crash, identity softmax produces wrong output | +| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 A-fragment store compiles, register data flow needs fixing | | C | Online softmax between MMA1 and MMA2 (the hard part) | ⬜ TODO | | D | FP8 paged KV gather + dequant (replace BF16 TMA load) | ⬜ TODO | diff --git a/tests/diag_layouts.py b/tests/diag_layouts.py new file mode 100644 index 00000000..1f7eaa1b --- /dev/null +++ b/tests/diag_layouts.py @@ -0,0 +1,85 @@ +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils +from cutlass.cute.nvgpu import tcgen05 +from cutlass import Float32, BFloat16 +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset + +a_dtype = BFloat16; b_dtype = BFloat16 +from cutlass.cute.nvgpu import OperandMajorMode +a_major = OperandMajorMode.K +b_major = OperandMajorMode.K +mma_tiler_mn = (128, 128) + +qk_mma = utils.sm100.make_trivial_tiled_mma( + a_dtype, b_dtype, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM) +pv_mma = utils.sm100.make_trivial_tiled_mma( + a_dtype, b_dtype, OperandMajorMode.K, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM) + +qk_thr = qk_mma.get_slice(0) +qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn) +tStS = qk_thr.make_fragment_C(qk_acc_shape) + +pv_thr = pv_mma.get_slice(0) +pv_acc_shape = pv_thr.partition_shape_C(mma_tiler_mn) +tOtO = pv_thr.make_fragment_C(pv_acc_shape) + +qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) +qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4) +pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) +pv_mma_tiler = (*mma_tiler_mn, pv_inst_k * 4) + +p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) +tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) +tOrP_base = pv_thr.make_fragment_A(tP) +tOrP = tOrP_base[(None, None, None, 0)] +tOrP0 = cute.make_tensor( + tOrP.iterator + Float32.width // BFloat16.width * 32, + tOrP.layout) + +print('=== Layout diagnostics ===') +print('tStS.layout:', tStS.layout) +print('tStS.size:', cute.size(tStS)) +print('tStS s_cols:', find_tmem_tensor_col_offset(tStS)) +print() +print('tOtO.layout:', tOtO.layout) +print('tOtO.size:', cute.size(tOtO)) +print('tOtO o_cols:', find_tmem_tensor_col_offset(tOtO)) +print() +print('tOrP.layout:', tOrP.layout) +print('tOrP.size:', cute.size(tOrP)) +print('tOrP0.layout:', tOrP0.layout) +print('tOrP0.size:', cute.size(tOrP0)) +print() + +tilePlikeFP32 = 128 * 16 // 32 +tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) +print('tStS_P_layout:', tStS_P_layout) +print() + +# LOAD +tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32) +tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS) +thr_load = tiled_tmem_load.get_slice(0) +cS = cute.make_identity_tensor((128, 128)) +tScS = qk_thr.partition_C(cS) +tTMEM_LOADcS = thr_load.partition_D(tScS) +print('LOAD tTMEM_LOADcS.shape:', tTMEM_LOADcS.shape) +print('LOAD per-thread elements:', cute.size(tTMEM_LOADcS)) + +# STORE (composition) +tStS_P = cute.make_tensor(tStS.iterator + 32, tStS_P_layout) +tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32) +tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) +thr_store = tiled_tmem_store.get_slice(0) +tTMEM_STOREcS = thr_store.partition_S(cute.make_identity_tensor(tStS_P.shape)) +print('STORE tTMEM_STOREcS.shape:', tTMEM_STOREcS.shape) +print('STORE per-thread elements:', cute.size(tTMEM_STOREcS)) + +# What about the tOrP0 shape for store? +print() +print('tOrP0.shape:', tOrP0.shape if hasattr(tOrP0, 'shape') else 'N/A') +# tOrP0 is BF16 so we'd need a BF16 store atom - but cute.copy requires equal bit widths +# The F32 store to a BF16 target doesn't work either +# This is the fundamental tension diff --git a/tests/test_afrag_roundtrip.py b/tests/test_afrag_roundtrip.py new file mode 100644 index 00000000..12237293 --- /dev/null +++ b/tests/test_afrag_roundtrip.py @@ -0,0 +1,175 @@ +"""Test: ld FP32 from S0, st BF16 to A-fragment layout tdVrP, +ld BF16 back from tdVrP, epi the result. +If this works, the A-fragment store is correct and the issue is in the PV MMA.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class AFragRoundtrip: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols # Only need S region + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + # A-fragment for pv_mma + pv_thr = pv_mma.get_slice(0) + tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) + tP = cute.make_tensor(tP_iter, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + tdVrP = cute.make_tensor(tOrP.iterator, tOrP.layout) + # TMEM ld (C-fragment) + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + # TMEM st (A-fragment layout, BF16) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP) + thr_st = tiled_st.get_slice(sfw) + tStP = thr_st.partition_D(tdVrP) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + # TMA + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_bf16_elemwise.py b/tests/test_bf16_elemwise.py new file mode 100644 index 00000000..8899958c --- /dev/null +++ b/tests/test_bf16_elemwise.py @@ -0,0 +1,237 @@ +"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1. +No recast, no BF16, no packing. Pure FP32 copy between TMEM regions.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class BF16Elemwise: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols * 2 + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) + + # LD and ST on same layout + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + tStS = thr_st.partition_D(tStS1) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + tStcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_bf16_pack.py b/tests/test_bf16_pack.py new file mode 100644 index 00000000..e12596ac --- /dev/null +++ b/tests/test_bf16_pack.py @@ -0,0 +1,237 @@ +"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1. +No recast, no BF16, no packing. Pure FP32 copy between TMEM regions.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class BF16PackTest: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols * 2 + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) + + # LD and ST on same layout + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + tStS = thr_st.partition_D(tStS1) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + tStcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_bf16_recast_full.py b/tests/test_bf16_recast_full.py new file mode 100644 index 00000000..1ca6ce8e --- /dev/null +++ b/tests/test_bf16_recast_full.py @@ -0,0 +1,242 @@ +"""Test BF16 recast pattern with FULL C-fragment layout (no subview). +ld from S0 (full 128x128), recast BF16, st to S1 (full 128x128 at offset 128). +Since both ld and st use the same layout, the recast should work (shapes match). +Then epi reads S1. If this works, the recast pattern IS correct for same-layout cases. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class BF16RecastFull: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols * 2 + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) + + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + tStS1_dst = thr_st.partition_D(tStS1) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + tStcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_bf16_recast_simple.py b/tests/test_bf16_recast_simple.py new file mode 100644 index 00000000..b6588c30 --- /dev/null +++ b/tests/test_bf16_recast_simple.py @@ -0,0 +1,239 @@ +"""Simplest BF16 recast test: ld FP32 from S0, use recast to convert, +st back to S0, then epi reads S0. Single region, no subview.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class SimpleBF16Recast: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + + # ld and st on the SAME tensor (S0), same layout + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS0) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + tStS_dst = thr_st.partition_D(tStS0) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + tStcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_error_pattern.py b/tests/test_error_pattern.py new file mode 100644 index 00000000..3c7632df --- /dev/null +++ b/tests/test_error_pattern.py @@ -0,0 +1,85 @@ +""" +BF16 Packing Diagnostic: Run identity softmax with K=V=randn, +compare output vs reference to identify the error pattern. +""" +import torch, sys +sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests') + +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from test_stage_b_v7 import StageBIdentitySoftmax + +torch.manual_seed(42) +m, n, k = 128, 128, 128 +q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') +kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') +c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + +qf = q[:,:,0].float() +kvf = kv[:,:,0].float() +ref = qf @ kvf.T @ kvf # identity softmax: (Q @ K^T) @ V + +mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) +mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) +mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) +stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + +kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) +print('Compiling...', flush=True) +compiled = cute.compile(kernel, mQ, mK, mC, stream) +print('Running...', flush=True) +compiled(mQ, mK, mC, stream) +torch.cuda.synchronize() + +out = c[:,:,0].float() +cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + +print(f'\nCosine: {cos:.6f}') +print(f'Output row 0[:8]: {out[0,:8].tolist()}') +print(f'Ref row 0[:8]: {ref[0,:8].tolist()}') + +# Key diagnostic: compare Q@K^T stage (which we know is correct) vs PV stage +# If Q@K^T is correct but PV is wrong, the output will show the PV error pattern +# With identity softmax, P = Q@K^T. So output = P @ V = (Q@K^T) @ V +# If V is being read wrong, the output will be P @ V_permuted + +# Check: is the output a permutation of the reference? +# Sort both and compare +out_sorted = out.flatten().sort()[0] +ref_sorted = ref.flatten().sort()[0] +cos_sorted = torch.nn.functional.cosine_similarity(out_sorted.unsqueeze(0), ref_sorted.unsqueeze(0)).item() +print(f'\nCosine after sorting: {cos_sorted:.6f}') +print(f'(If ~1.0, output is a permutation of reference)') + +# Check: does output match P @ V^T? (V transposed) +ref_vt = qf @ kvf.T @ kvf.T # (Q@K^T) @ V^T +cos_vt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_vt.flatten().unsqueeze(0)).item() +print(f'Cosine with P @ V^T: {cos_vt:.6f}') + +# Check: does output match P^T @ V? (P transposed) +# P = Q@K^T, so P^T = K@Q^T +ref_pt = kvf @ qf.T @ kvf +cos_pt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_pt.flatten().unsqueeze(0)).item() +print(f'Cosine with P^T @ V: {cos_pt:.6f}') + +# Check: is output simply the Q@K^T scores (MMA2 produced identity)? +# If MMA2 didn't run or produced P unchanged, output = P = Q@K^T +qkt = qf @ kvf.T +cos_qkt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), qkt.flatten().unsqueeze(0)).item() +print(f'Cosine with just Q@K^T: {cos_qkt:.6f}') + +# Check: all output rows identical? (means P has identical rows, like all-ones) +all_same = torch.allclose(out[0], out[1], atol=1e-3) +print(f'All output rows identical: {all_same}') +if not all_same: + cos_r01 = torch.nn.functional.cosine_similarity(out[0].unsqueeze(0), out[1].unsqueeze(0)).item() + print(f'Cosine between row 0 and row 1: {cos_r01:.6f}') + +# Check: is output just V (P=I case)? +cos_v = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), kvf.flatten().unsqueeze(0)).item() +print(f'Cosine with V alone: {cos_v:.6f}') + +# Print some output statistics +print(f'\nOutput stats: min={out.min().item():.4f}, max={out.max().item():.4f}, mean={out.mean().item():.4f}') +print(f'Ref stats: min={ref.min().item():.4f}, max={ref.max().item():.4f}, mean={ref.mean().item():.4f}') diff --git a/tests/test_packing_diag.py b/tests/test_packing_diag.py new file mode 100644 index 00000000..d9db61e5 --- /dev/null +++ b/tests/test_packing_diag.py @@ -0,0 +1,133 @@ +""" +BF16 Packing Diagnostic: Write specific F32 bit patterns to P TMEM via St32x32bOp. +Then run PV MMA with V=identity. Output = P (as MMA reads it). + +This reveals the BF16 packing order within F32 TMEM words. + +Strategy: +- Skip MMA1 entirely (no QK computation) +- Write packed F32 values where low BF16 ≠ high BF16 +- Use K=V=identity so output[i,j] = P[i,j] +- Compare output against both packing orders +""" +import torch, struct, sys +sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests') + +# Use the existing diag test that writes P=all-ones, but modify to write specific patterns +# The diag test is test_stage_b_diag.py - let me copy and modify v7 instead + +# Actually, let me just use the EXISTING kernel (test_stage_b_v7) with: +# 1. K = identity matrix +# 2. The identity softmax running normally (it writes softmax(Q@K^T) as P) +# 3. If Q is also identity, then Q@K^T = I@I = I, softmax(I) = softmax of identity +# 4. That's not what we want either. + +# Better: use the diag test (writes P=all-ones) with K=identity +# Output should be all-ones * I = all-ones. That gives no new info. + +# The REAL test: modify the kernel to write specific F32 values to the BF16 recast view +# Instead of writing 1.0 BF16, write alternating different values + +# Simplest approach: modify test_stage_b_diag.py to write j+1 in BF16 for each position +# Then with V=identity, output[i,j] = j+1 + +# But modifying the kernel requires JIT changes. Let me use the existing v7 kernel +# with a clever input choice instead. + +# Approach: Use the identity softmax (which writes Q@K^T scores as P) +# With Q=randn and K=identity: Q@K^T = Q (since K=I) +# Then P = softmax(Q) and output = softmax(Q) @ V +# With V=I: output = softmax(Q) +# This tests the FULL pipeline but doesn't isolate packing. + +# Let me just write the packing test kernel from scratch, minimally. +# It only needs: TMA load V, write F32 to P TMEM, PV MMA, epilogue. + +# Actually, the simplest isolation test: +# Use the existing test with P=all-ones and V=K (K=randn) +# We already know cos=0.08 for this. The 0.08 is not 0 or 1. +# If I use V where V[j] = 1 if j even, 0 if j odd, then: +# With correct packing: output = (number of ones in even K positions) per row +# This doesn't help because P=all-ones means all K positions are 1.0 + +# I think the key issue is that with P=all-ones (cos=0.08), the output should be +# EXACTLY sum(V, dim=0) for each row. Let me compare more carefully. + +# Let me run the EXISTING P=all-ones diag test and compare the output values +# against the reference. The PATTERN of errors will tell us about packing. + +print("Running existing P=all-ones diag with K=V=randn") +print("Comparing output vs reference to identify the error pattern") + +import cutlass.cute as cute +import cutlass.utils as utils +from cutlass.cute.nvgpu import tcgen05 +from cutlass import Float32, BFloat16 +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cutlass.torch as ct +import cuda.bindings.driver as cuda + +torch.manual_seed(42) +m, n, k = 128, 128, 128 +q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') +kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') +c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + +kvf = kv[:,:,0].float() +ref = torch.ones(128, 128, dtype=torch.float32, device='cuda') @ kvf + +mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) +mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) +mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) +stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + +from test_stage_b_diag import StageBDiag +kernel = StageBDiag(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) +print('Compiling diag test...', flush=True) +compiled = cute.compile(kernel, mQ, mK, mC, stream) +print('Running...', flush=True) +compiled(mQ, mK, mC, stream) +torch.cuda.synchronize() + +out = c[:,:,0].float() +cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + +print(f'\nCosine: {cos:.6f}') +print(f'Output row 0[:8]: {out[0,:8].tolist()}') +print(f'Ref row 0[:8]: {ref[0,:8].tolist()}') +print(f'Ratio out/ref row 0[:8]: {(out[0,:8] / ref[0,:8]).tolist()}') + +# Check if output is a scaled version of the reference +ratio = out[0] / ref[0] +ratio_clean = ratio[torch.isfinite(ratio)] +print(f'Ratio mean: {ratio_clean.mean().item():.6f}, std: {ratio_clean.std().item():.6f}') + +# Check if odd/even columns have different ratios +ratio_even = (out[0, 0::2] / ref[0, 0::2])[torch.isfinite(out[0, 0::2] / ref[0, 0::2])] +ratio_odd = (out[0, 1::2] / ref[0, 1::2])[torch.isfinite(out[0, 1::2] / ref[0, 1::2])] +print(f'Even col ratio mean: {ratio_even.mean().item():.6f}') +print(f'Odd col ratio mean: {ratio_odd.mean().item():.6f}') + +# Check if output matches a different V interpretation +# What if the V SMEM is being read in a different column order? +# Compute: what if V columns were permuted? +# With P=all-ones, output[i] = sum(V[:,j]) for each j +# This is the column sum of V, broadcast to all rows +# If V is read in a different order, the output would be the sum of +# differently-ordered V columns, but still all the same sum +# So output should be uniform across columns = sum of all V values per column +# But the output IS uniform (all rows identical). The VALUES are wrong. +# That means the SUM is wrong, which means the V values being read are wrong. + +# Let me check: does the output column structure match any known transform of V? +print(f'\nV (K) row 0[:8]: {kvf[0,:8].tolist()}') +print(f'V col sums (reference): {ref[0,:4].tolist()}') +print(f'Output col 0[:4]: {out[0,:4].tolist()}') + +# What if V is being read transposed? +# sum(V[j,:]) per column j = row sums of V +row_sums = kvf.sum(dim=1) +col_sums = kvf.sum(dim=0) +print(f'V row sums[:4]: {row_sums[:4].tolist()}') +print(f'V col sums[:4]: {col_sums[:4].tolist()}') diff --git a/tests/test_pair_swap.py b/tests/test_pair_swap.py new file mode 100644 index 00000000..68a2041a --- /dev/null +++ b/tests/test_pair_swap.py @@ -0,0 +1,119 @@ +""" +BF16 Pair-Swap Test: compare kernel output against reference with V rows +swapped in even/odd pairs (simulating BF16 packing swap within F32 words). +""" +import torch, sys +sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests') + +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from test_stage_b_v7 import StageBIdentitySoftmax + +torch.manual_seed(42) +m, n, k = 128, 128, 128 +q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') +kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') +c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + +qf = q[:,:,0].float() +kvf = kv[:,:,0].float() + +# Standard reference: P @ V where P = Q @ K^T +ref = qf @ kvf.T @ kvf + +# Pair-swapped reference: if BF16 within each F32 are swapped, +# MMA reads K[2j] and K[2j+1] swapped, which means V rows are swapped in pairs +# Output = P @ V_with_row_pairs_swapped +V_swap = kvf.clone() +V_swap[0::2], V_swap[1::2] = kvf[1::2].clone(), kvf[0::2].clone() +ref_swap = qf @ kvf.T @ V_swap + +# Also try: just the P@K^T with V where every 2 consecutive rows are swapped +# but starting from different offsets +# Try 1-based offset: swap rows (1,2), (3,4), ... +V_swap1 = kvf.clone() +V_swap1[1::2], V_swap1[2::2] = kvf[2::2].clone(), kvf[1::2].clone() +V_swap1[0] = kvf[0] # row 0 unchanged +ref_swap1 = qf @ kvf.T @ V_swap1 + +mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) +mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) +mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) +stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + +kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) +print('Compiling...', flush=True) +compiled = cute.compile(kernel, mQ, mK, mC, stream) +print('Running...', flush=True) +compiled(mQ, mK, mC, stream) +torch.cuda.synchronize() + +out = c[:,:,0].float() +cos_ref = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() +cos_swap = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap.flatten().unsqueeze(0)).item() +cos_swap1 = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap1.flatten().unsqueeze(0)).item() + +print(f'\nCosine with standard ref: {cos_ref:.6f}') +print(f'Cosine with pair-swapped ref: {cos_swap:.6f}') +print(f'Cosine with offset-1 swapped: {cos_swap1:.6f}') + +# Also try: what if the entire P row is shifted by some amount? +# Or what if P columns are reordered according to the A-fragment's K partitioning? +# The A-fragment has K laid out as (16 inner, 4 outer chunks of 16) +# If the store writes columns sequentially but the MMA reads them in the +# chunked order, we'd get a specific permutation + +# The A-fragment K layout: (col, k0, k1) where col=0..15, k0=0..3, k1=0..1 +# K position = col + 16*k0 + 64*k1 +# If the store writes K sequentially (0,1,2,...,127) but the MMA reads +# in the fragment order, the mapping would be: +# Fragment index (col, k0, k1) -> K position (col + 16*k0 + 64*k1) +# But this is sequential (0,1,2,...), so no permutation. +# UNLESS the store doesn't fill the fragment's K blocks correctly. + +# Try: what if only the FIRST 64 K values (not 128) are filled? +# The store writes 64 F32 = 128 BF16. But what if the MMA's A-fragment +# K dimension is 64 (not 128)? Then only 64 K values have P, rest are garbage. +# But we already verified nblk_pv=4 and the fragment covers 128 BF16. + +# Try: what if the 64 F32 columns in the store map to 128 BF16 K positions +# but with the packing where BF16[K] and BF16[K+1] come from the same F32 word +# and K is determined by the A-fragment's partition order? + +# The A-fragment reads K in groups of 16 (inner K), then 4 outer groups, then 2 BF16 per F32 +# So K values are accessed as: [group0_bf16_0, group0_bf16_1, group1_bf16_0, group1_bf16_1, ...] +# where group = (col, k0, k1) and bf16_0/bf16_1 are the two BF16 in each F32 word + +# This is getting complex. Let me just check the specific permutation by +# matching output columns to reference columns for row 0. + +# For each output column j, find which reference column it matches +# (using the entire row as a signature) +print('\nTrying to identify the column permutation for row 0...') +# Since all values might not be unique, use the dot product with a known vector +# Actually, the simplest: for output[0, j], check if it equals ref[0, k] for any k +# But we need a unique signature. Let's use the output of the ENTIRE row. + +# Compare output row i with reference rows +for i in [0, 1, 2]: + best_cos = 0 + best_j = -1 + for j in range(128): + c = torch.nn.functional.cosine_similarity(out[i].unsqueeze(0), ref[j].unsqueeze(0)).item() + if c > best_cos: + best_cos = c + best_j = j + print(f' Output row {i} best matches ref row {best_j} (cos={best_cos:.4f})') + +# Also: compare output column j with reference columns +print('\nColumn matching (output col j vs ref col k):') +for j in [0, 1, 2, 3, 4, 5, 6, 7]: + best_cos = 0 + best_k = -1 + for k in range(128): + c = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item() + if c > best_cos: + best_cos = c + best_k = k + print(f' Output col {j} best matches ref col {best_k} (cos={best_cos:.4f})') diff --git a/tests/test_pair_swap2.py b/tests/test_pair_swap2.py new file mode 100644 index 00000000..2f4d7a5d --- /dev/null +++ b/tests/test_pair_swap2.py @@ -0,0 +1,95 @@ +""" +BF16 Pair-Swap Test: compare kernel output against reference with V rows +swapped in even/odd pairs (simulating BF16 packing swap within F32 words). +Also tries to identify the exact column permutation. +""" +import torch, sys +sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests') + +import cutlass.cute as cute +import cutlass.torch as ct +import cuda.bindings.driver as cuda +from test_stage_b_v7 import StageBIdentitySoftmax + +torch.manual_seed(42) +m, n, k = 128, 128, 128 +q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') +kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') +c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + +qf = q[:,:,0].float() +kvf = kv[:,:,0].float() +ref = qf @ kvf.T @ kvf + +# Pair-swapped: V rows 0↔1, 2↔3, 4↔5, ... +V_swap = kvf.clone() +V_swap[0::2], V_swap[1::2] = kvf[1::2].clone(), kvf[0::2].clone() +ref_swap = qf @ kvf.T @ V_swap + +mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) +mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) +mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) +stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + +kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) +print('Compiling...', flush=True) +compiled = cute.compile(kernel, mQ, mK, mC, stream) +print('Running...', flush=True) +compiled(mQ, mK, mC, stream) +torch.cuda.synchronize() + +out = c[:,:,0].float() +cos_ref = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() +cos_swap = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap.flatten().unsqueeze(0)).item() + +print(f'\nCosine with standard ref: {cos_ref:.6f}') +print(f'Cosine with pair-swapped ref: {cos_swap:.6f}') + +# Column permutation identification +print('\nColumn permutation (output col j matches ref col k):') +perm = [] +for j in range(min(16, 128)): + best_cos = 0 + best_k = -1 + for k in range(128): + c_val = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item() + if c_val > best_cos: + best_cos = c_val + best_k = k + perm.append(best_k) + if j < 16: + print(f' out col {j} -> ref col {best_k} (cos={best_cos:.4f})') + +# Check if permutation has a pattern +print(f'\nPermutation (first 16): {perm}') +diffs = [perm[j+1] - perm[j] for j in range(len(perm)-1)] +print(f'Permutation diffs: {diffs}') + +# Check: is perm[j] = j with every pair swapped? (0→1, 1→0, 2→3, 3→2, ...) +expected_pair_swap = [j^1 for j in range(128)] # XOR with 1 swaps even/odd +matches_pair_swap = all(perm[j] == expected_pair_swap[j] for j in range(len(perm))) +print(f'Matches pair-swap (0↔1, 2↔3, ...): {matches_pair_swap}') + +# Also check the full permutation for all 128 columns +full_perm = [] +for j in range(128): + best_cos = 0 + best_k = -1 + for k in range(128): + c_val = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item() + if c_val > best_cos: + best_cos = c_val + best_k = k + full_perm.append(best_k) + +print(f'\nFull permutation: {full_perm}') +# Check if it's pair-swap for all columns +all_pair_swap = all(full_perm[j] == (j^1) for j in range(128)) +print(f'All columns match pair-swap: {all_pair_swap}') + +# If not pair-swap, what's the pattern? +if not all_pair_swap: + # Check if it's a 16-element block permutation + for block in range(8): + block_perm = full_perm[block*16:(block+1)*16] + print(f' Block {block}: {block_perm}') diff --git a/tests/test_recast_minimal.py b/tests/test_recast_minimal.py new file mode 100644 index 00000000..3a20d711 --- /dev/null +++ b/tests/test_recast_minimal.py @@ -0,0 +1,237 @@ +"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1. +No recast, no BF16, no packing. Pure FP32 copy between TMEM regions.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class RecastMinimal: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols * 2 + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) + + # LD and ST on same layout + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + tStS = thr_st.partition_D(tStS1) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + tStcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_afrag.py b/tests/test_stage_b_afrag.py new file mode 100644 index 00000000..8cce4ad8 --- /dev/null +++ b/tests/test_stage_b_afrag.py @@ -0,0 +1,210 @@ +"""Stage B: Store P via A-fragment layout, not C-fragment. + +Key insight from Mike: the C-fragment store and A-fragment read use different +physical TMEM address mappings. The fix is to construct the TMEM store +using the A-fragment layout (from p_tmem_s / make_fragment_A). + +Steps: +1. Q@K^T → TMEM (C-fragment, offset 0) +2. ld scores from TMEM (C-fragment) +3. Convert FP32→BF16 +4. st P to TMEM using A-fragment layout (p_tmem_s / tOrP0) +5. PV MMA reads from TMEM using A-fragment (same layout = same physical addresses) +6. Epilogue writes output +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StageBAfrag: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO) + # TMEM layout: S0 (scores) at offset 0, P (A-fragment) at offset 32, O at offset s_cols + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = self.s_cols + self.tmem_alloc_cols = self.s_cols + self.o_cols + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + # TMEM tensors + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + # P A-fragment (matching fmha exactly) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + # ── TMEM LOAD from C-fragment ── + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + # ── TMEM STORE to A-fragment layout ── + # Use St16x128bOp with BF16 for the A-fragment layout + # (St32x32bOp only works with C-fragment layout, not A-fragment) + tmem_st = cute.make_copy_atom(tcgen05.copy.St16x128bOp(tcgen05.copy.Repetition(16)), self.q_dtype) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tP) + thr_st = tiled_st.get_slice(sfw) + tStP = thr_st.partition_D(tP) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + # TMA + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_afrag2.py b/tests/test_stage_b_afrag2.py new file mode 100644 index 00000000..74a2a39f --- /dev/null +++ b/tests/test_stage_b_afrag2.py @@ -0,0 +1,217 @@ +"""Stage B: Store P via A-fragment layout with recast C-fragment iterator. + +Matching the backward FMHA pattern exactly: +1. tOrP = pv_thr.make_fragment_A(tP)[None,None,None,0] (A-fragment layout) +2. tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=BF16) (C-fragment base, recast to BF16) +3. tdVrP = cute.make_tensor(tdVrP_iter + offset, tOrP.layout) +4. make_tmem_copy(St32x32bOp(Repetition(8)), BF16, tdVrP) +5. Store BF16 registers to tdVrP +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StageBAfrag2: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO) + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 0 + self.tmem_o0_offset = self.s_cols * 2 + self.tmem_alloc_cols = 512 + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + # ── P A-fragment (backward FMHA pattern) ── + # 1. Get A-fragment layout from pv_mma + tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) + tP = cute.make_tensor(tP_iter, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + # 2. Recast C-fragment iterator to BF16 (matching backward FMHA line 962) + tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) + # 3. Create store target with A-fragment layout + recast iterator + # The offset for P within TMEM: qk_acc_dtype.width / q_dtype.width * tmem_p0_offset + # But since we recast to BF16, the offset should be in BF16 units + tdVrP = cute.make_tensor( + tdVrP_iter + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + # PV MMA's A-fragment (for reading) + tOrP0 = cute.make_tensor(tOrP.iterator, tOrP.layout) + # ── TMEM LOAD from C-fragment ── + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + # ── TMEM STORE via A-fragment layout (backward FMHA pattern) ── + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP) + thr_st = tiled_st.get_slice(sfw) + tStP = thr_st.partition_D(tdVrP) + # Source identity for store (A-fragment shape) + cS_P = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[2])) + tScS_P = pv_thr.partition_A(cS_P) + tStcS = thr_st.partition_S(tScS_P) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + print(f'[A2] tdVrP.layout: {tdVrP.layout}') + print(f'[A2] tOrP0.layout: {tOrP0.layout}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + # TMA + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_diag.py b/tests/test_stage_b_diag.py new file mode 100644 index 00000000..a2fb975d --- /dev/null +++ b/tests/test_stage_b_diag.py @@ -0,0 +1,384 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + print(' Output row 0[:8]:', out[0,:8].tolist()) + print(' Output row 1[:8]:', out[1,:8].tolist()) + print(' Output row 63[:8]:', out[63,:8].tolist()) + print(' Ref row 0[:8]:', ref[0,:8].tolist()) + print(' Output col 0[:8]:', out[:8,0].tolist()) + print(' Any zeros:', (out==0).sum().item(), 'of', out.numel()) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_final.py b/tests/test_stage_b_final.py new file mode 100644 index 00000000..01f7ca54 --- /dev/null +++ b/tests/test_stage_b_final.py @@ -0,0 +1,207 @@ +"""Stage B final: ld FP32 from S0, BF16 recast, st to S1 (offset s_cols), +PV MMA reads A-fragment from S1 at the appropriate offset. + +The recast pattern WORKS when writing to a different TMEM region. +PV MMA A-fragment reads from S1 with tmem_p0_offset adjustment. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StageBFinal: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO) + # S0 at offset 0, P0 (BF16 view) at offset 32 within S0, O0 at offset s_cols + # But for the recast test, we write the FULL C-fragment to S1 (offset s_cols) + # PV MMA A-fragment needs to read from S1. The A-fragment offset is: + # tOrP0 starts at tStS.iterator + (F32_width/BF16_width) * tmem_p0_offset + # For S1, we need A-fragment pointing to the start of S1. + # A-fragment offset for column c: width_ratio * c + # S1 starts at s_cols (128 FP32 columns). A-fragment for P at S1 start: + # offset = (F32_width / BF16_width) * s_cols = 2 * 128 = 256 + # But tOrP layout has stride 64 for M, so base offset is in BF16 column units + # Actually, tOrP0.iterator is the base of the A-fragment in the TMEM address space. + # The A-fragment offset is (qk_acc_dtype.width / q_dtype.width) * tmem_p0_offset. + # For our case, we want P at S1 start, so tmem_p0_offset should map to s_cols. + # But tmem_p0_offset is in FP32 columns, and the A-fragment width ratio converts it. + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 0 # P starts at the beginning of S1 + self.tmem_o0_offset = self.s_cols # O region after S1 + self.tmem_alloc_cols = self.s_cols + self.o_cols # S1 + O + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]); tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # S0: MMA output + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) # S1: BF16 copy of scores (P) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + # P A-fragment pointing to S1 (offset s_cols) + tP = cute.make_tensor(tStS.iterator + self.s_cols, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + # Since P is at the START of S1 (offset s_cols), no additional offset needed + tOrP0 = cute.make_tensor(tOrP.iterator, tOrP.layout) + # TMEM ld/st atoms + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0); tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1) + sfw = tidx % (32 * len(self.epilogue_warp_id)); thr_ld = tiled_ld.get_slice(sfw); thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0); tStS1_dst = thr_st.partition_D(tStS1) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])); tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS); tStcS = thr_st.partition_S(tScS) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + # TMA + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v10.py b/tests/test_stage_b_v10.py new file mode 100644 index 00000000..83169757 --- /dev/null +++ b/tests/test_stage_b_v10.py @@ -0,0 +1,321 @@ +"""Stage B v10: Use FP32 ld→st on FULL (128,128) layout, not subview. +The ld reads S0 (full C-fragment), the st writes S1 (full C-fragment at offset 128), +then PV MMA reads S1 via A-fragment. + +Key insight: The FP32 ld→st on the SAME layout works (cosine 0.999999). +The BF16 recast pattern with DIFFERENT layout sizes is broken. +For identity softmax, we need to write the data back to TMEM in a format +that the PV MMA's A-fragment can read. Since the C-fragment and A-fragment +both access the same physical TMEM, writing via C-fragment (store) and +reading via A-fragment (PV MMA) should work IF the data is in the right place. + +The pv_mma_tiler has N=128 (from mma_tiler_mn[1]) but P's K dimension = 128. +Wait, pv_mma_tiler = (M, K, K_inner*4). For the PV MMA, A=P with K=128 (full KV dim). +But p_tmem_s is make_smem_layout_a(pv_mma, pv_mma_tiler, BF16, 1) which gives the +A-fragment layout for (128, 128, 64). The A-fragment covers K=64 at CTA level... +Hmm, pv_mma_tiler[1] should be the V dimension (N in MMA terms), not K. + +Actually, for PV MMA: P (A operand, M×K) × V (B operand, K×N) → O (C operand, M×N) +With a_source=TMEM, the P operand's K dimension matches the CTA tile K. +pv_mma_tiler = (128, 128, 64): M=128, N=128, K=64. +So P is M×K = 128×64. That's the A-fragment shape. + +So the A-fragment only reads 64 columns of TMEM (K=64 in 4 blocks of 16). +The C-fragment for S has 128 columns. The P region starts at offset 32 (tmem_p0_offset). +With s_cols=128 and tmem_p0_offset=32, the P data starts at column 32 and spans 64 columns +(columns 32-95), which is within the S region (columns 0-127). + +So the store should write to columns 32-95 of the S region. The C-fragment composition +(128, tilePlikeFP32=64) with offset 32 does exactly this. + +But we showed the subview store + recast doesn't work. Let me try: write the FULL +C-fragment (128×128) at offset 128 (tStS1), and adjust the PV MMA's A-fragment +offset so it reads from the right columns within that region. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StageBv10: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + self.o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = self.s_cols + self.tmem_alloc_cols = self.s_cols + self.o_cols # 256 + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + cute.nvgpu.OperandMajorMode.K, + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P A-fragment + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # LD and ST copy atoms on the SAME layout (full C-fragment) + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS0) # SAME layout for ld and st + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + tStS_dst = thr_st.partition_D(tStS0) # St writes to S0 (same as MMA output) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + tStcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # TMA + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v11.py b/tests/test_stage_b_v11.py new file mode 100644 index 00000000..25cb4445 --- /dev/null +++ b/tests/test_stage_b_v11.py @@ -0,0 +1,210 @@ +"""Stage B v11: Backward FMHA pattern exactly. + +1. ld FP32 from S0 (C-fragment) +2. Quantize FP32→BF16 (same register shape, .load()/.store()) +3. Reshape BF16 register to store partition shape +4. St32x32bOp BF16 store to tdVrP (A-fragment layout, offset s_cols) +5. PV MMA reads from tOrP0 (A-fragment, offset s_cols) +6. Epilogue writes output +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StageBv11: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO) + self.tmem_s0_offset = 0 + self.tmem_o0_offset = self.s_cols * 2 # After S0 (128) + P (128 BF16 = 64 FP32, but A-frag offset uses width ratio) + self.tmem_alloc_cols = 512 # Enough for S0 + P + O + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + # ── P A-fragment (backward FMHA pattern) ── + tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype) + tP = cute.make_tensor(tP_iter, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + # Apply offset for P region (after S0) + p_bf16_offset = self.qk_acc_dtype.width // self.q_dtype.width * self.s_cols + tdVrP = cute.make_tensor(tOrP.iterator + p_bf16_offset, tOrP.layout) + tOrP0 = cute.make_tensor(tOrP.iterator + p_bf16_offset, tOrP.layout) + # ── TMEM LOAD from C-fragment ── + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + # ── TMEM STORE via A-fragment layout ── + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP) + thr_st = tiled_st.get_slice(sfw) + tStP = thr_st.partition_D(tdVrP) + tdVcST = pv_thr.partition_A(cS_id) # A-operand partition of identity + tStcS = thr_st.partition_S(tdVcST) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + # TMA + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v11b.py b/tests/test_stage_b_v11b.py new file mode 100644 index 00000000..91cd25ae --- /dev/null +++ b/tests/test_stage_b_v11b.py @@ -0,0 +1,398 @@ +""" +Stage B v11b: Identity Softmax - store to composition(tP, (128,128)) + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + # FIX: PV tiler swaps N and K from QK (fmha.py: pv_mma_tiler = (M, QK_K, QK_N)) + # P is (M, QK_N) and V is (QK_N, D), so PV MMA has K=QK_N, N=D + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}') + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v12.py b/tests/test_stage_b_v12.py new file mode 100644 index 00000000..e392fa3d --- /dev/null +++ b/tests/test_stage_b_v12.py @@ -0,0 +1,408 @@ +""" +Stage B v12: BF16 St16x128bOp store to tP A-layout, P=all-ones + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + # FIX: PV tiler swaps N and K from QK (fmha.py: pv_mma_tiler = (M, QK_K, QK_N)) + # P is (M, QK_N) and V is (QK_N, D), so PV MMA has K=QK_N, N=D + self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1]) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}') + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v7.py b/tests/test_stage_b_v7.py index ae41ac6d..a523ee6e 100644 --- a/tests/test_stage_b_v7.py +++ b/tests/test_stage_b_v7.py @@ -41,6 +41,11 @@ class StageBIdentitySoftmax: pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") self.cta_tile_shape_mnk = ( self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], @@ -80,7 +85,7 @@ class StageBIdentitySoftmax: # the first 32 C-layout columns are "dead space" in the A-layout mapping. # self.tmem_s0_offset = 0 - self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_p0_offset = 32 # Original self.tmem_o0_offset = s_cols # 128 self.tmem_alloc_cols = s_cols + o_cols # 256 @@ -117,6 +122,16 @@ class StageBIdentitySoftmax: pv_mma = utils.sm100.make_trivial_tiled_mma( self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + # Introspect PV MMA atom + print(f"[ATOM] PV MMA type: {type(pv_mma)}") + print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, "op") else "no op"}") + print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, "_trait") else "no trait"}") + print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}") + print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}") + # Check a_src + print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}") + print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}") + print(f"[ATOM] PV MMA op: {pv_mma.op}") self._setup(qk_mma, pv_mma) a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) @@ -205,6 +220,10 @@ class StageBIdentitySoftmax: tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) @@ -220,16 +239,73 @@ class StageBIdentitySoftmax: # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') tOrP_base = pv_thr.make_fragment_A(tP) tOrP = tOrP_base[(None, None, None, 0)] tOrP0 = cute.make_tensor( tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, tOrP.layout) + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) # ── TMA WARP ── diff --git a/tests/test_stage_b_v7_rep128.py b/tests/test_stage_b_v7_rep128.py new file mode 100644 index 00000000..8f0d3e15 --- /dev/null +++ b/tests/test_stage_b_v7_rep128.py @@ -0,0 +1,445 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + # Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s) + print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}') + print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}') + print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}') + print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v7_rep16.py b/tests/test_stage_b_v7_rep16.py new file mode 100644 index 00000000..8f0d3e15 --- /dev/null +++ b/tests/test_stage_b_v7_rep16.py @@ -0,0 +1,445 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + # Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s) + print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}') + print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}') + print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}') + print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v7_rep64.py b/tests/test_stage_b_v7_rep64.py new file mode 100644 index 00000000..8f0d3e15 --- /dev/null +++ b/tests/test_stage_b_v7_rep64.py @@ -0,0 +1,445 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + # Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s) + print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}') + print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}') + print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}') + print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v7_rep8.py b/tests/test_stage_b_v7_rep8.py new file mode 100644 index 00000000..8f0d3e15 --- /dev/null +++ b/tests/test_stage_b_v7_rep8.py @@ -0,0 +1,445 @@ +""" +Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets. + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}") + print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}") + print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}") + print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}") + print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}") + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + print(f"[DIAG] tCrV.size = {cute.size(tCrV)}") + print(f"[DIAG] tCrA.size = {cute.size(tCrA)}") + print(f"[DIAG] tCrB.size = {cute.size(tCrB)}") + print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}") + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + print(f'[TMEM] p_tmem_s: {p_tmem_s}') + print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}') + # Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s) + print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}') + print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}') + print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}') + print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}') + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + print(f'[DIAG] tStS.layout: {tStS.layout}') + print(f'[DIAG] tStS.size: {cute.size(tStS)}') + print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}') + print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}') + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # Compute nblk_pv for diagnostics + nblk_pv = cute.size(tOrP0, mode=[2]) + nblk_qk = cute.size(tCrA, mode=[2]) + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + # COMPREHENSIVE LAYOUT DUMP + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32))) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + + print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}') + print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}') + print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}') + print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}') + print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}') + print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}') + print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}') + print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}') + print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}') + print(f'[LAYOUT] tP.layout: {tP.layout}') + print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}') + print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}') + print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}') + print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}') + print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}') + print(f'[LAYOUT] tOtO.layout: {tOtO.layout}') + print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}') + print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}') + print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}') + print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}') + + # DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition) + tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout) + print(f'[DIAG] tP.layout: {tP.layout}') + print(f'[DIAG] tP.size: {cute.size(tP)}') + print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}') + print(f'[DIAG] tStS_P.layout: {tStS_P.layout}') + print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}') + print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}') + print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}') + print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}') + + print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}') + print(f'[DIAG] tCrV.size = {cute.size(tCrV)}') + print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}') + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v8.py b/tests/test_stage_b_v8.py new file mode 100644 index 00000000..e3c7f372 --- /dev/null +++ b/tests/test_stage_b_v8.py @@ -0,0 +1,403 @@ +""" +Stage B v8: Fix the identity softmax by using the A-fragment layout +for the TMEM store target instead of the C-fragment composition. + +The bug in v7: tStS_P uses composition(tStS.layout, (128, tilePlikeFP32)) +which gives layout (128,64):(65536,1) — C-fragment strides. +But the PV MMA reads from TMEM using the A-fragment layout +((128,16),1,4):((64,1),0,16) — physical TMEM strides. + +For the store to be read correctly by the PV MMA, the store target +must use the same physical TMEM addressing as the A-fragment. + +Key insight from CUTLASS source: + Physical TMEM for M=128, BK=64, BF16 A from TMEM (K-major): + tmem[dp=m, col=base_col + 16*mma_k + k_inner] for mma_k in 0..3, k_inner in 0..15 + + This means: A-fragment address = 64*m + k_inner + 16*mma_k + C-fragment address for (m, col) = ??? (virtual layout, not physical) + + The St32x32b copy atom with tStS_P (C-composition) writes to C-layout addresses. + The PV MMA reads from A-layout addresses. These are different physical locations. + +Fix: Use tP (from p_tmem_s, the A-fragment source layout) as the store target +instead of tStS_P (the C-fragment composition). This ensures the store writes +to physical TMEM addresses that the PV MMA's A-fragment will read correctly. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmaxV8: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + # TMEM tensors + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P tensor for PV MMA A-fragment + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # ── KEY FIX: Store target uses A-fragment layout (tP) not C-fragment composition ── + # The store must write to physical TMEM addresses that the PV MMA reads via A-fragment. + # tP has layout ((128,16),1,4,1):((64,1),0,16,0) — the A-fragment's physical TMEM layout. + # We need the store target at tmem_p0_offset = 32 columns into the S region. + # tP's iterator starts at tStS.iterator (base of S region). + # tOrP0 starts at tStS.iterator + 2*32 (scaled by F32/BF16 width ratio for the A-fragment). + # The store target should use the SAME layout as tP but with the p0 offset applied. + tP_store = cute.make_tensor( + tStS.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + p_tmem_s.outer) + + print(f'[v8] tP.layout: {tP.layout}') + print(f'[v8] tP_store.layout: {tP_store.layout}') + print(f'[v8] tOrP0.layout: {tOrP0.layout}') + print(f'[v8] tStS.layout: {tStS.layout}') + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1 BF16 + tTMEM_STORErP = cute.make_rmem_tensor(tScS_A.shape, self.qk_acc_dtype) + tTMEM_STORErP_e = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErP.iterator, dtype=self.q_dtype), + tTMEM_LOADrS.layout) + s_vec = tTMEM_LOADrS.load() + tTMEM_STORErP_e.store(s_vec.to(self.q_dtype)) + + # Store to A-layout TMEM + cute.copy(tiled_tmem_store, tTMEM_STORErP, tTMEM_STOREtP) + cute.arch.fence_view_async_tmem_store() + + si_handle.release() + + # ── Epilogue ── + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage) + c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) + c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) + acc_cons_st = utils.gemm.sm100.epilogue_tma_store( + self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, + epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe) + c_pipe.producer_tail() + tmem.relinquish_alloc_permit() + tmem.free(tmem_ptr) + + +def test(): + torch.manual_seed(42) + m, n, k = 128, 128, 128 + q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float() + ref = qf @ kvf.T @ kvf + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = StageBIdentitySoftmaxV8(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True) + print('Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mC, stream) + print('Running...', flush=True) + compiled(mQ, mK, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print('Stage B v8: (Q @ K^T) @ V with identity softmax (A-layout store)') + print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err)) + print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v8b.py b/tests/test_stage_b_v8b.py new file mode 100644 index 00000000..67b8fd89 --- /dev/null +++ b/tests/test_stage_b_v8b.py @@ -0,0 +1,354 @@ +""" +Stage B v8b: BF16 store directly to tOrP0 (diagnostic) + +Key fixes over v6: + - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols) + - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly) + - tilePlikeFP32 computed from qk_mma_tiler and dtype widths + - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments + - JIT-time diagnostic prints of all TMEM sizes + +Architecture (matches fmha.py exactly): + MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False) + Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout + MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True) +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + + +class StageBIdentitySoftmax: + def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True): + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3 + self.num_c_stage = 2 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], + self.qk_mma_tiler[2], + ) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2) + + # ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ── + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + # tilePlikeFP32 for the store-side composition + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + + # ── TMEM LAYOUT (matching fmha.py) ── + # P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout) + # in the same TMEM region. The A-layout view starts partway through the S region. + # fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered) + # The P offset of 32 means the A-layout starts at column 32 within the S region. + # This is because the C-layout and A-layout partition TMEM differently per-thread; + # the first 32 C-layout columns are "dead space" in the A-layout mapping. + # + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 # Same as fmha.py + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + # Also compute via get_num_tmem_alloc_cols for the full allocation + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + print(f"[StageB] s_cols (QK accumulator) = {s_cols}") + print(f"[StageB] o_cols (PV accumulator) = {o_cols}") + print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}") + print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}") + print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}") + print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}") + print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}") + print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type + self.a_major = LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major = LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, self.a_major, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major, + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + # V shares the same SMEM as B (same data, different layout for PV MMA) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout + + # ── TMEM tensors with computed offsets (matching fmha.py pattern) ── + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + # P fragment: construct from p_tmem_s layout (matching fmha.py exactly) + # fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + # tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + # tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP_base = pv_thr.make_fragment_A(tP) + tOrP = tOrP_base[(None, None, None, 0)] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_stage_b_v9.py b/tests/test_stage_b_v9.py new file mode 100644 index 00000000..e8ea6e94 --- /dev/null +++ b/tests/test_stage_b_v9.py @@ -0,0 +1,348 @@ +""" +Stage B v9: Identity softmax matching fmha.py softmax_step EXACTLY. + +Line by line match of the fmha softmax_step, but with identity softmax: + - No masking + - scale = 1 (no log2 scaling) + - row_max = 0 (skip max, just do exp(x) = 1 for identity) + - Actually for identity softmax, P = S (scores). So we just ld S, st as P. + - But we need to go through the C→A layout transform properly. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StageBv9: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2 + self.acc_dtype = Float32; self.epilog_sync_bar_id = 1 + self.use_2cta_instrs = False + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = 128 + + def _setup(self, qk_mma, pv_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2]) + self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.c_layout = c_layout + self.num_ab_stage = 1; self.num_acc_stage = 1 + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + o_cols = find_tmem_tensor_col_offset(tOtO) + + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + self.tmem_o0_offset = s_cols # 128 + self.tmem_alloc_cols = s_cols + o_cols # 256 + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1)) + self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100") + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + cute.nvgpu.OperandMajorMode.K, + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.TMEM) + self._setup(qk_mma, pv_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + use_2cta = cute.size(qk_mma.thr_id.shape) == 2 + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner) + sV = cute.make_tensor(sV_ptr, v_smem_s.outer) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + tCrV = pv_mma.make_fragment_B(sV) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + pv_thr = pv_mma.get_slice(0) + pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2]) + tOtO = pv_thr.make_fragment_C(pv_acc_shape) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + tOrP0 = cute.make_tensor( + tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout) + + # ── TMEM copy atoms (matching fmha softmax_step exactly) ── + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + thr_tmem_load = tiled_tmem_load.get_slice(tidx % (32 * len(self.epilogue_warp_id))) + tTMEM_LOADtS = thr_tmem_load.partition_S(tStS0) + + # Store target: composition of C-fragment (matching fmha) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, + cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) + thr_tmem_store = tiled_tmem_store.get_slice(tidx % (32 * len(self.epilogue_warp_id))) + tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) + + # Identity tensors + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tScS_P = cute.make_tensor(tScS.iterator, + cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))) + + tTMEM_LOADcS = thr_tmem_load.partition_D(tScS) + tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P) + + print(f'[v9] tTMEM_LOADcS.shape: {tTMEM_LOADcS.shape}') + print(f'[v9] tTMEM_STOREcS.shape: {tTMEM_STOREcS.shape}') + print(f'[v9] LOAD size: {cute.size(tTMEM_LOADcS)}') + print(f'[v9] STORE size: {cute.size(tTMEM_STOREcS)}') + print(f'[v9] tilePlikeFP32: {self.tilePlikeFP32}') + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage)) + tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_store_verify.py b/tests/test_store_verify.py new file mode 100644 index 00000000..85a0b41c --- /dev/null +++ b/tests/test_store_verify.py @@ -0,0 +1,285 @@ +""" +Test: Q@K^T → TMEM (scores), ld scores, st to P region, +then epilogue reads P region as C-fragment (not PV MMA). + +If the ld/st roundtrip preserves the data, the epilogue should +output the same as Stage A (Q@K^T result). +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StoreVerify: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + self.tmem_s0_offset = 0 + self.tmem_p0_offset = 32 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32 + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR + self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = s_cols # Only need scores region, no O region + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + # TMEM copy atoms for ld/st + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) + thr_tmem_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS = thr_tmem_load.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_tmem_load.partition_D(tScS) + + # Store target: P region (composition of C-fragment, offset 32) + tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, + cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) + thr_tmem_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtP = thr_tmem_store.partition_D(tStS_P) + tScS_P = cute.make_tensor(tScS.iterator, + cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))) + tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_store_verify2.py b/tests/test_store_verify2.py new file mode 100644 index 00000000..fe5a11a3 --- /dev/null +++ b/tests/test_store_verify2.py @@ -0,0 +1,272 @@ +"""Store verify v2: ld S (full 128x128), st P (full 128x128 at offset 128), +then epilogue reads P region. No subview, no composition.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class StoreVerify2: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = s_cols * 2 # Two full regions + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + s_cols = find_tmem_tensor_col_offset(tStS) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # offset 0 + tStS1 = cute.make_tensor(tStS.iterator + s_cols, tStS.layout) # offset 128 (second region) + + # Load from S0, Store to S1 (same layout, just different offset) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) + thr_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS = thr_load.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_load.partition_D(tScS) + + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS1) + thr_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtS1 = thr_store.partition_D(tStS1) + tTMEM_STOREcS = thr_store.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_tmem_copy_roundtrip.py b/tests/test_tmem_copy_roundtrip.py new file mode 100644 index 00000000..3885f16c --- /dev/null +++ b/tests/test_tmem_copy_roundtrip.py @@ -0,0 +1,271 @@ +"""Minimal TMEM ld→st roundtrip. FP32 only, no BF16 cast. +Uses the fmha pattern: load to register, store via make_rmem_tensor + recast + load/store.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class TMEMCopyRoundtrip: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols * 2 + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) + + # Copy atoms + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0) + sfw_idx = tidx % (32 * len(self.epilogue_warp_id)) + thr_load = tiled_tmem_load.get_slice(sfw_idx) + tTMEM_LOADtS0 = thr_load.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LOADcS = thr_load.partition_D(tScS) + + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS1) + thr_store = tiled_tmem_store.get_slice(sfw_idx) + tTMEM_STOREtS1 = thr_store.partition_D(tStS1) + tTMEM_STOREcS = thr_store.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # ── TMA WARP ── + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_tmem_fp32_roundtrip.py b/tests/test_tmem_fp32_roundtrip.py new file mode 100644 index 00000000..0c2dd017 --- /dev/null +++ b/tests/test_tmem_fp32_roundtrip.py @@ -0,0 +1,255 @@ +"""Minimal: Q@K^T → TMEM, ld FP32 from S0, st FP32 to S1, epi reads S1. +NO bf16 cast at all. Pure FP32 ld→st roundtrip.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class FP32Roundtrip: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols * 2 + + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) + + # Single tiled copy that can both ld and st + tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st_atom, tStS1) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + + tTMEM_LDtS = thr_ld.partition_S(tStS0) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) + tTMEM_LDcS = thr_ld.partition_D(tScS) + tTMEM_STtS = thr_st.partition_D(tStS1) + tTMEM_STcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + # TMA + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test() diff --git a/tests/test_tmem_layout_diag.py b/tests/test_tmem_layout_diag.py new file mode 100644 index 00000000..aa6dc6e4 --- /dev/null +++ b/tests/test_tmem_layout_diag.py @@ -0,0 +1,168 @@ +""" +Diagnostic: understand the C-layout vs A-layout TMEM address mapping. + +After Stage A writes Q@K^T results to TMEM via C-fragment (tStS0), +read the same TMEM data back using the A-fragment (tOrP0) layout +and compare against reading it via the C-fragment layout. + +This will tell us whether the composition-based store target +produces the right physical TMEM addresses for the A-fragment to read. +""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +def test(): + torch.manual_seed(42) + m, n, k = 128, 128, 128 + q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda') + kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda') + qf = q[:,:,0].float(); kvf = kv[:,:,0].float() + ref_qk = qf @ kvf.T # (128, 128) + + import cutlass.torch as ct + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + a_major = LayoutEnum.from_tensor(mQ).mma_major_mode() + b_major = LayoutEnum.from_tensor(mK).mma_major_mode() + c_layout = LayoutEnum.from_tensor(mC) + mma_tiler_mn = (128, 128) + mma_tiler = (*mma_tiler_mn, 1) + + qk_mma = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, a_major, b_major, + Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma( + BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, b_major, + Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM) + + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4) + pv_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4) + + # Build the fragments + qk_thr = qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + + pv_thr = pv_mma.get_slice(0) + p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1) + tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) + tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0] + + tilePlikeFP32 = qk_mma_tiler[1] * BFloat16.width // 32 + tStS_P = cute.make_tensor(tStS.iterator + 32, cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))) + + # Print the key layouts + print(f'tStS.layout: {tStS.layout}') + print(f'tStS.size: {cute.size(tStS)}') + print(f'tOrP.layout: {tOrP.layout}') + print(f'tOrP.size: {cute.size(tOrP)}') + print(f'tStS_P.layout: {tStS_P.layout}') + print(f'tP.layout: {tP.layout}') + print(f'tilePlikeFP32: {tilePlikeFP32}') + + # For the C-fragment layout ((128,128),1,1):((65536,1),0,0) + # element at logical (m, n) maps to address 65536*m + n + # For the A-fragment layout ((128,16),1,4):((64,1),0,16) + # element at logical ((m_inner, k_inner), 1, mma_k) maps to address 64*m_inner + k_inner + 16*mma_k + + # In the C-fragment, logical (m, col) = (m, col) with stride (65536, 1) + # In the A-fragment, logical ((m, k16), 1, block4) with stride ((64, 1), 0, 16) + # So A[m, k16, block4] -> address 64*m + k16 + 16*block4 + # C[m, col] -> address 65536*m + col + + # The KEY question: for the same physical TMEM column c and row m, + # what does C-layout say the logical index is, vs A-layout? + + # C-layout: (m, col) -> addr = 65536*m + col + # A-layout: ((m, k16), 1, blk) -> addr = 64*m + k16 + 16*blk + + # If both refer to the same physical TMEM, then: + # 65536*m_c + col = 64*m_a + k16 + 16*blk + # These can only be equal if the address spaces are different (C uses a different base) + # OR if the C-layout addresses wrap around modulo the TMEM size + + # C-fragment cosize = 8323200 which is >> 16384 (actual data size) + # The C-fragment uses a STRIDED layout where stride-0 (M) = 65536 + # This means the C-fragment is NOT a dense layout in TMEM address space + # The actual TMEM addresses used by the MMA hardware follow a different mapping + + # Let's check: for M=128, N=128 C-fragment with ((128,128),1,1):((65536,1),0,0) + # The max address = 65536*127 + 127 = 8,321,023 + # But TMEM only has 1MB = 256 columns * 4096 rows + # So addresses 65536*m must be modded or the layout is virtual + + # AH. The C-fragment layout is VIRTUAL. cute.compile maps it to physical TMEM + # addresses when the actual MMA operation runs. The strides in the fragment layout + # are logical, not physical. The MMA hardware knows the physical column mapping. + + # For the store, make_tmem_copy(St32x32bOp, tStS_P) creates a copy plan. + # The copy writes to tStS_P's layout addresses. These ARE the physical TMEM addresses + # that the MMA hardware will read from for the A-fragment. + + # So the question becomes: is the COMPOSITION tStS_P correct? + # tStS.layout = ((128,128),1,1):((65536,1),0,0) + # composition with (128, 64) produces (128,64):(65536,1) + # This takes the first 64 columns of the C-layout. + + # tOrP.layout = ((128,16),1,4):((64,1),0,16) + # This is the A-layout for 4 K=16 blocks. + + # Let me compute what TMEM addresses each layout generates for the same (m, k) pair. + # For m in [0..127], for k in [0..63]: + # C-layout: logical (m, k) -> addr = 65536*m + k + # A-layout: k = 16*blk + k_inner + # logical ((m, k_inner), 1, blk) -> addr = 64*m + k_inner + 16*blk = 64*m + k + + # So C-layout says addr = 65536*m + k + # A-layout says addr = 64*m + k + # These are DIFFERENT for m > 0. + + # The C-fragment stride of 65536 is NOT the physical TMEM stride. + # Physical TMEM for M=128 has stride 64 (one 32B row per datapath, 64 FP32 columns) + # Wait, 128 BF16 = 256 bytes = 8 columns of 32B? No... + + # TMEM is organized as columns of 128 rows of 32 bits each. + # For FP32 accumulator: each column holds 128 FP32 values (one per DP lane) + # For BF16 view: each FP32 column = 2 BF16 values packed? Or separate columns? + + # The 64 stride in A-layout: 128 BF16 values need 64 FP32 columns (2 BF16 per FP32) + # stride-1 in the (128,16) sub-shape: (64,1) means m has stride 64, k_inner has stride 1 + # 64 FP32 columns for 128 BF16 rows makes sense: each column holds 2 BF16, so 128/2=64 columns + + # But wait, TMEM is 32-bit per entry per DP lane. For FP32, one column = 128 FP32 values. + # For BF16 A-fragment, each column stores 2 BF16 (high+low), so 128 BF16 = 64 columns. + # That matches the stride of 64 for m in the A-layout. + + # For the C-fragment with FP32 accumulator, each column = 128 FP32 values. + # The C-layout for (128,128) FP32: stride (65536, 1) + # 65536 = 512 * 128. That's weird. + # Actually 65536 = 2^16. For 128 rows in FP32, each row in a column needs... hmm. + # The C-fragment stores FP32 values. stride-0 = 65536 means elements in the M direction + # are 65536 addresses apart. But we have 128*128 = 16384 elements total. + # cosize = 8323200 >> 16384. The layout is very sparse. + + # I think the C-fragment layout addresses are NOT physical TMEM addresses. + # The MMA hardware and the ld/st copy atoms know how to map between logical and physical. + # The copy atom's partition_S/partition_D functions create the right physical mappings. + + # The real question for Stage B is simpler: + # After Q@K^T writes to tStS0 (C-fragment), and we ld the scores, + # apply identity softmax, and st them to tStS_P (composed C-fragment)... + # does the PV MMA read the same data from tOrP0 (A-fragment)? + + # fmha.py proves this works for F16. The composition pattern is identical. + # So either: (1) something about BF16 vs F16, or (2) something else in my code is wrong. + + print('\nDiagnostic complete. Need to check if BF16 vs F16 affects the TMEM layout mapping.') + +if __name__ == '__main__': + test() diff --git a/tests/test_tmem_pure_fp32.py b/tests/test_tmem_pure_fp32.py new file mode 100644 index 00000000..9bc9e98a --- /dev/null +++ b/tests/test_tmem_pure_fp32.py @@ -0,0 +1,237 @@ +"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1. +No recast, no BF16, no packing. Pure FP32 copy between TMEM regions.""" +import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import cuda.bindings.driver as cuda + +class PureFP32Copy: + def __init__(self, mma_tiler_mn): + self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16 + self.c_dtype = BFloat16; self.acc_dtype = Float32 + self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1) + self.cluster_shape_mn = (1, 1) + self.cta_group = tcgen05.CtaGroup.ONE + self.epilogue_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4; self.tma_warp_id = 5 + self.threads_per_cta = 192 + self.num_c_stage = 2; self.use_2cta_instrs = False + self.epilog_sync_bar_id = 1 + + def _setup(self, qk_mma): + qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2]) + self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4) + self.mma_tiler = self.qk_mma_tiler + self.cta_tile_shape_mnk = ( + self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), + self.qk_mma_tiler[1], self.qk_mma_tiler[2]) + self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) + self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1) + self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1) + c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout + self.epi_tile = utils.sm100.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, False, c_layout, self.o_dtype) + self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2) + self.num_ab_stage = 1; self.num_acc_stage = 1 + qk_thr = qk_mma.get_slice(0) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + self.s_cols = find_tmem_tensor_col_offset(tStS) + self.tmem_alloc_cols = self.s_cols * 2 + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + self.num_tma_load_bytes = ( + cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) + ) * cute.size(qk_mma.thr_id.shape) + + @cute.jit + def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream): + qk_mma = utils.sm100.make_trivial_tiled_mma( + self.q_dtype, self.q_dtype, + LayoutEnum.from_tensor(a).mma_major_mode(), + LayoutEnum.from_tensor(b).mma_major_mode(), + self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, + tcgen05.OperandSource.SMEM) + self._setup(qk_mma) + a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)) + b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0)) + tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A( + utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), + a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B( + utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), + b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape) + epi_smem = cute.select(self.c_smem_s, mode=[0, 1]) + tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile) + self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, + self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile + ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream) + + @cute.kernel + def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, + a_smem_s, b_smem_s, c_smem_s, epi_tile): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c) + @cute.struct + class SS: + ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] + acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc: cutlass.Int64 + holding: cutlass.Int32 + smem = utils.SmemAllocator(); st = smem.allocate(SS) + ab_p, ab_c = pipeline.PipelineTmaUmma.create( + barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True + ).make_participants() + acc_pipe = pipeline.PipelineUmmaAsync.create( + barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), + cta_layout_vmnk=cl_vmnk, defer_sync=True) + tmem_bar = pipeline.NamedBarrier(barrier_id=2, + num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id))) + tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, + allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, + two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr) + pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True) + sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner) + sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner) + sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner) + gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None)) + gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None)) + gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None)) + k_cnt = cute.size(gA, mode=[3]) + qk_thr = qk_mma.get_slice(0) + tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC) + a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape) + tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3)) + b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape) + tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3)) + tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)] + tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB) + qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]) + tStS = qk_thr.make_fragment_C(qk_acc_shape) + tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) + + # LD and ST on same layout + tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype) + tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0) + tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1) + sfw = tidx % (32 * len(self.epilogue_warp_id)) + thr_ld = tiled_ld.get_slice(sfw) + thr_st = tiled_st.get_slice(sfw) + tLdS = thr_ld.partition_S(tStS0) + tStS = thr_st.partition_D(tStS1) + cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS_id) + tLdcS = thr_ld.partition_D(tScS) + tStcS = thr_st.partition_S(tScS) + + tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1)) + pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk) + + if warp_idx == self.tma_warp_id: + ab_p.reset(); peek = ab_p.try_acquire() + for kt in cutlass.range(k_cnt, unroll=1): + h = ab_p.acquire_and_advance(peek) + cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier) + cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier) + peek = cutlass.Boolean(1) + if h.count+1= 0.99 else 'FAIL')) + +if __name__ == '__main__': + test()