Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed

Key finding: C-fragment and A-fragment use different physical TMEM address
mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA
reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16.

Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999)
Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02)
Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
This commit is contained in:
2026-05-21 00:12:47 +00:00
parent 97656a5cd1
commit 467ade37b2
35 changed files with 8998 additions and 78 deletions

170
README.md
View File

@@ -18,15 +18,17 @@ cutedsl/
└── dense_blockscaled_gemm_persistent.py # REFERENCE: Blackwell TMEM/tcgen05 GEMM
tests/
├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM
├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + identity softmax (runs, wrong output)
├── test_stage_b_minimal.py # Stage B minimal: two MMAs, no softmax (runs, NaN expected)
├── test_stage_b_pipeline_only.py # ✅ Stage B pipeline-only: PipelineUmmaAsync, no ld/st (runs, NaN expected)
├── diag_tmem.py # Diagnostic: TMEM layout inspection
├── test_stage_b_v6.py # ❌ Stage B v6 (hardcoded offsets, crashes)
├── test_stage_a_qk.py # ❌ Stage A v1 (broken, superseded by v2)
├── test_stage_a_minimal.py # ❌ Stage A minimal (broken, superseded by v2)
├── test_attention_path_b200.py # Full attention path test (uses naive BF16 attn)
├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM
├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + C-fragment softmax (runs, wrong output)
├── test_stage_b_afrag2.py # 🔨 Stage B: A-fragment store pattern (compiles, wrong output)
├── test_tmem_pure_fp32.py # ✅ FP32 ldst roundtrip on C-fragment: cosine 0.999999
├── test_bf16_elemwise.py # ✅ FP32→BF16→FP32 elemwise + FP32 st: cosine 0.999999
├── test_recast_minimal.py # ✅ BF16 recast ld S0→st S1 via C-fragment: cosine 0.999999
├── test_bf16_recast_simple.py # ❌ BF16 recast ld/st same region (S0): zero (can't overwrite MMA output)
├── test_tmem_copy_roundtrip.py # ❌ BF16 recast + C→A mismatch: zero
├── test_stage_b_final.py # ❌ C-fragment st + A-fragment read: NaN (physical layout mismatch)
├── test_afrag_roundtrip.py # ❌ A-frag st corrupts S0 (overlapping TMEM region)
├── diag_tmem.py # Diagnostic: TMEM layout inspection
└── ...
```
@@ -46,68 +48,66 @@ Validates the full tcgen05.mma → TMEM → epilogue → GMEM path:
- TmemAllocator for TMEM allocation/deallocation
- utils.gemm.sm100.epilogue_tma_store for the TMEM→reg→SMEM→TMA→GMEM epilogue
### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20)
### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20-21)
**Latest**: `tests/test_stage_b_v7.py`
**Status**: Kernel compiles and runs without crashing. Identity softmax produces wrong output (cosine ≈ -0.02).
**Core Problem**: The C-fragment (MMA accumulator) and A-fragment (MMA A-operand from TMEM) use **different physical TMEM address mappings** for the same logical (M,K) position. The softmax writes P via one mapping, but the PV MMA reads via the other. This produces garbage.
**What was fixed today:**
#### What's Been Proven
1. **PipelineUmmaAsync consumer group size crash (THE bug):**
`PipelineUmmaAsync` with `Agent.Thread` requires **thread count** (128), NOT warp count (4), for the consumer group. fmha.py uses `32 * len(softmax_warp_ids) = 128`. Using 4 caused `CUDA_ERROR_LAUNCH_FAILED` (not a deadlock — the barrier reached wrong threshold causing illegal TMEM access).
```python
# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)
# CORRECT (matches fmha.py):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
```
| Test | Pattern | Result | Why |
|------|---------|--------|-----|
| test_tmem_pure_fp32 | FP32 ld→st, same C-fragment layout | ✅ cos=0.999999 | C-fragment addresses self-consistent |
| test_bf16_elemwise | FP32→BF16→FP32 elemwise, C-fragment st | ✅ cos=0.999999 | BF16 conversion works, C-fragment st works |
| test_recast_minimal | BF16 recast ld S0→st S1, C-fragment | ✅ cos=0.999999 | Recast works when writing to different region |
| test_bf16_recast_simple | BF16 recast ld/st same region S0 | ❌ zero | Can't overwrite MMA output in same region |
| test_stage_b_final | C-fragment st → A-fragment read (S1) | ❌ NaN | C-layout ≠ A-layout physical addresses |
| test_stage_b_afrag2 | A-fragment st (backward FMHA pattern) | ❌ cos=-0.02 | Store + PV MMA layout compatible, but register data flow wrong |
2. **TMEM offset computation (no more hardcoding):**
- `s_cols = find_tmem_tensor_col_offset(tStS) = 128` — QK accumulator physical TMEM columns
- `o_cols = find_tmem_tensor_col_offset(tOtO) = 128` — PV accumulator physical TMEM columns
- `tmem_s0_offset = 0, tmem_p0_offset = 32, tmem_o0_offset = 128` — matches fmha.py
- `find_tmem_tensor_col_offset(tOrP_sliced) = 32800 = 0x8020` — 0x8000 is TMEM space tag, column offset = 32
- Total: 256 TMEM cols (verified by `get_num_tmem_alloc_cols`)
#### Root Cause: C-fragment vs A-fragment Physical TMEM Layout
3. **P fragment construction (matching fmha.py):**
```python
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) # A-layout from PV MMA
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tOrP0 = cute.make_tensor(tOrP.iterator + 2 * tmem_p0_offset, tOrP.layout)
```
Previously used `cute.composition` on C-layout — wrong, must use PV MMA's A-layout.
From the CUTLASS source (`mma_traits_sm100.hpp`):
4. **V SMEM aliasing:**
V shares the same SMEM as K with a different layout interpretation:
```python
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
tCrV = pv_mma.make_fragment_B(sV) # Uses MN-major V layout
```
**C-fragment (MMA accumulator, FP32):**
- Layout: `((128,128),1,1):((65536,1),0,0)`**virtual** layout
- Physical TMEM addresses determined by the MMA hardware's accumulator write path
- St32x32bOp with C-fragment layout writes to C-fragment physical addresses
**What's still broken:**
**A-fragment (MMA A-operand from TMEM, BF16, K-major, M=128):**
- Layout: `((128,16),1,4):((65536,1),0,16)`**physical** TMEM layout
- A[m, k_inner] → `tmem[dp=m, col=base + 16*mma_k + k_inner]`
- BK=64 = 4 K=16 MMA atoms, NOT one K=64 atom
- The 4D fragment partition order is NOT the physical TMEM order
The identity softmax C→A layout transform produces garbage output (cosine ≈ -0.02). The kernel runs, Stage A (Q@K^T) gives cosine 0.999999, but the full (Q@K^T)@V pipeline is wrong. The issue is in the tcgen05.ld/st identity softmax path — either the ld/st copy atoms, the register conversion, or the A-layout write positions are incorrect.
**The St32x32bOp with C-fragment composition writes to C-layout physical addresses. The PV MMA reads from A-layout physical addresses. These are different physical locations.**
**Bisection results:**
- ✅ Stage B minimal (no pipeline, no softmax): runs, NaN (expected — no C→A transform)
- ✅ Stage B pipeline-only (PipelineUmmaAsync, no ld/st): runs, NaN (expected)
- 🔨 Stage B full (identity softmax): runs, cosine -0.02 (wrong — softmax transform is broken)
- All three crash with consumer_group=4, all run with consumer_group=128
#### Forward FMHA's Approach (FP16 Only!)
**TMEM layout diagnostic data:**
Forward FMHA uses a recast pattern to pack 2×FP16 into 1×FP32 register, then St32x32bOp writes to a C-fragment composition subview. **But forward FMHA explicitly rejects BF16:**
```python
if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}:
raise ValueError(in_dtype must be Float8E4M3FN or Float16)
```
QK accumulator C fragment:
tStS.layout = ((128,128),1,1):((65536,1),0,0)
cute.size = 16384, cute.cosize = 8323200
find_tmem_tensor_col_offset = 128
The recast softmax path is validated for FP16, NOT BF16. Our BF16 use is outside the tested path.
PV A-fragment (P operand):
tOrP_sliced.layout = ((128,16),1,4):((65536,1),0,16)
cute.size = 8192, cute.cosize = 8323136
find_tmem_tensor_col_offset = 32800 = 0x8020 (0x8000 tag + col 32)
```
#### Backward FMHA's Approach (BF16 Supported)
Backward FMHA writes dV to TMEM using the A-fragment layout:
1. `tdVrP_iter = cute.recast_ptr(tSTtST.iterator, dtype=self.element_dtype)` — recast C-fragment iterator to BF16
2. `tdVrP = cute.make_tensor(tdVrP_iter, tOrP.layout)` — A-fragment layout, C-fragment base
3. `tmem_store_atom = cute.make_copy_atom(St32x32bOp(Repetition(8)), self.element_dtype)` — BF16 store atom
4. Quantize via `make_rmem_tensor(input.shape, element_dtype)` + `.load()/.store(v.to(element_dtype))` — true BF16 register, NOT recast
5. Reshape: `cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape))` — match store partition shape
This compiles and runs for us (no crash), but the output is still wrong (cosine -0.02). The remaining issue is the **register layout mismatch**:
- Load partition (C-fragment): 128 FP32 values per thread (full 128×128 QK tile)
- Store partition (A-fragment): 64 BF16 values per thread (128×64 P tile for PV MMA K=64)
- The backward FMHA uses `quantize()` + reshape, but our element counts differ because the QK tile is 128×128 while P only needs 128×64
#### Next Steps for Stage B
1. **Fix the register data flow** — properly subselect the P-relevant 64 BF16 columns from the 128 FP32 load columns, or use the backward FMHA's PdO MMA tiler (M=128, N=64) instead of (M=128, N=128)
2. **Verify A-fragment store roundtrip** — write known BF16 values via A-fragment store, have PV MMA read them back via A-fragment, confirm the physical TMEM addresses match
3. **Once data flow is correct, add online softmax** (Stage C)
### 🔨 Stage C: Online Softmax — AFTER B
@@ -168,9 +168,20 @@ After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast
## Critical APIs & Lessons
### PipelineUmmaAsync consumer group size — THE MAY 20 BUG
### C-fragment ≠ A-fragment TMEM Physical Layout — THE MAY 20-21 FINDING
**For `Agent.Thread` groups in `PipelineUmmaAsync`: use thread count, NOT warp count.**
**The St32x32bOp with C-fragment composition writes to C-layout physical TMEM addresses. The PV MMA reads from A-layout physical TMEM addresses. These are DIFFERENT physical locations for the same logical (M,K) position.**
For the softmax to work, P must be written to TMEM using the A-fragment's physical layout, not the C-fragment's. The backward FMHA does this correctly by:
1. Creating the store destination with A-fragment layout + recast C-fragment iterator
2. Using a BF16 St32x32bOp atom
3. True BF16 register (not FP32 recast) via quantize() pattern
### Forward FMHA Recast Pattern — FP16 ONLY
The `cute.recast_ptr` + `.store(v.to(FP16))` pattern for packing 2×16-bit into 1×FP32 register is validated for FP16 only. BF16 is rejected in forward FMHA. The BF16 recast produces zero output when writing to the same TMEM region as the MMA output, and NaN when writing to a different region read via A-fragment.
### PipelineUmmaAsync consumer group size — thread count, NOT warp count
```python
# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
@@ -180,18 +191,32 @@ consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4) # warp count
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(softmax_warp_ids)) # thread count
```
This applies to ALL PipelineUmmaAsync consumers where the consumer is multiple warps. fmha.py line 671: `self.threads_per_warp * len(self.softmax0_warp_ids) = 32 * 4 = 128`.
**Note:** The earlier README incorrectly stated that warp count was correct. That was wrong. The `Agent.Thread` agent type measures group size in threads.
### TMEM offset arithmetic
- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count (with 0x8000 tag for A-fragments)
- QK accumulator C fragment: 128 TMEM columns
- PV A-fragment: offset 0x8020 = tag(0x8000) + col(32) — the 0x8000 is a TMEM memory-space identifier
- P OVERLAPS S in TMEM — P is written at column 32 within the S region (C-layout columns 0..127)
- `tOrP0 = cute.make_tensor(tOrP.iterator + acc_dtype.width // q_dtype.width * tmem_p0_offset, tOrP.layout)` — A-fragment offset scaled by dtype width ratio (F32/BF16 = 2)
### A-fragment iterator must use recast C-fragment pointer
When creating the P tensor for PV MMA's A-operand, the iterator must be the C-fragment's iterator recast to BF16:
```python
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
```
Without the recast, the A-fragment addresses are computed from an FP32 pointer base, giving wrong physical TMEM addresses (illegal memory access crash).
### V SMEM aliasing (K and V share SMEM)
```python
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
tCrV = pv_mma.make_fragment_B(sV)
```
### `make_trivial_tiled_mma` has two overloads
```python
@@ -204,16 +229,6 @@ make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode,
acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)
```
### V SMEM aliasing (K and V share SMEM)
```python
# K and V share the same SMEM buffer, but with different layouts:
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
tCrV = pv_mma.make_fragment_B(sV)
```
### Other APIs discovered from Stage A
1. **`cute.Tensor` API** — `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)`
@@ -234,12 +249,13 @@ tCrV = pv_mma.make_fragment_B(sV)
- **vLLM repo**: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
- **Pseudocode**: `/root/fragile-kernel-example/README.md` — authoritative per-tile attention flow
- **fmha.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
- **fmha_bwd.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py`
## 4-Stage Build Plan
| Stage | Goal | Status |
|-------|------|--------|
| A | Bare Q@K^T via tcgen05.mma → TMEM → GMEM | ✅ COMPLETE |
| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 Runs without crash, identity softmax produces wrong output |
| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 A-fragment store compiles, register data flow needs fixing |
| C | Online softmax between MMA1 and MMA2 (the hard part) | ⬜ TODO |
| D | FP8 paged KV gather + dequant (replace BF16 TMA load) | ⬜ TODO |

85
tests/diag_layouts.py Normal file
View File

@@ -0,0 +1,85 @@
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
a_dtype = BFloat16; b_dtype = BFloat16
from cutlass.cute.nvgpu import OperandMajorMode
a_major = OperandMajorMode.K
b_major = OperandMajorMode.K
mma_tiler_mn = (128, 128)
qk_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
a_dtype, b_dtype, OperandMajorMode.K, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM)
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
tStS = qk_thr.make_fragment_C(qk_acc_shape)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(mma_tiler_mn)
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (*mma_tiler_mn, pv_inst_k * 4)
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + Float32.width // BFloat16.width * 32,
tOrP.layout)
print('=== Layout diagnostics ===')
print('tStS.layout:', tStS.layout)
print('tStS.size:', cute.size(tStS))
print('tStS s_cols:', find_tmem_tensor_col_offset(tStS))
print()
print('tOtO.layout:', tOtO.layout)
print('tOtO.size:', cute.size(tOtO))
print('tOtO o_cols:', find_tmem_tensor_col_offset(tOtO))
print()
print('tOrP.layout:', tOrP.layout)
print('tOrP.size:', cute.size(tOrP))
print('tOrP0.layout:', tOrP0.layout)
print('tOrP0.size:', cute.size(tOrP0))
print()
tilePlikeFP32 = 128 * 16 // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
print('tStS_P_layout:', tStS_P_layout)
print()
# LOAD
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
thr_load = tiled_tmem_load.get_slice(0)
cS = cute.make_identity_tensor((128, 128))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
print('LOAD tTMEM_LOADcS.shape:', tTMEM_LOADcS.shape)
print('LOAD per-thread elements:', cute.size(tTMEM_LOADcS))
# STORE (composition)
tStS_P = cute.make_tensor(tStS.iterator + 32, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(0)
tTMEM_STOREcS = thr_store.partition_S(cute.make_identity_tensor(tStS_P.shape))
print('STORE tTMEM_STOREcS.shape:', tTMEM_STOREcS.shape)
print('STORE per-thread elements:', cute.size(tTMEM_STOREcS))
# What about the tOrP0 shape for store?
print()
print('tOrP0.shape:', tOrP0.shape if hasattr(tOrP0, 'shape') else 'N/A')
# tOrP0 is BF16 so we'd need a BF16 store atom - but cute.copy requires equal bit widths
# The F32 store to a BF16 target doesn't work either
# This is the fundamental tension

View File

@@ -0,0 +1,175 @@
"""Test: ld FP32 from S0, st BF16 to A-fragment layout tdVrP,
ld BF16 back from tdVrP, epi the result.
If this works, the A-fragment store is correct and the issue is in the PV MMA."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class AFragRoundtrip:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols # Only need S region
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
# A-fragment for pv_mma
pv_thr = pv_mma.get_slice(0)
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tdVrP = cute.make_tensor(tOrP.iterator, tOrP.layout)
# TMEM ld (C-fragment)
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
# TMEM st (A-fragment layout, BF16)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
thr_st = tiled_st.get_slice(sfw)
tStP = thr_st.partition_D(tdVrP)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
# EPILOGUE WARPS: ld FP32 → BF16 → st A-frag → ld A-frag BF16 → FP32 → epi
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# 1. ld FP32 from S0
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# 2. Convert FP32 → BF16
rBf16 = cute.make_rmem_tensor(tLdcS.shape, self.q_dtype)
for i in cutlass.range(cute.size(rLd), vectorize=True):
rBf16[i] = rLd[i].to(self.q_dtype)
# 3. st BF16 to A-fragment layout
cute.copy(tiled_st, rBf16, tStP)
cute.arch.fence_view_async_tmem_store()
# 4. Store to A-frag done. Check if S0 epi still works.
si_handle.release()
tCtS0 = cute.make_tensor(tmem_ptr, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtS0, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
def test():
torch.manual_seed(42); m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = AFragRoundtrip(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('A-frag roundtrip: cos={:.6f} (expect 0.999 from Stage A)'.format(cos))
if __name__ == '__main__':
test()

216
tests/test_b_afrag2.py Normal file
View File

@@ -0,0 +1,216 @@
"""Stage B: Store P via A-fragment layout with recast C-fragment iterator.
Matching the backward FMHA pattern exactly:
1. tOrP = pv_thr.make_fragment_A(tP)[None,None,None,0] (A-fragment layout)
2. tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=BF16) (C-fragment base, recast to BF16)
3. tdVrP = cute.make_tensor(tdVrP_iter + offset, tOrP.layout)
4. make_tmem_copy(St32x32bOp(Repetition(8)), BF16, tdVrP)
5. Store BF16 registers to tdVrP
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBAfrag2:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
self.tmem_s0_offset = 0
self.tmem_p0_offset = 0
self.tmem_o0_offset = self.s_cols * 2
self.tmem_alloc_cols = 512
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# ── P A-fragment (backward FMHA pattern) ──
# 1. Get A-fragment layout from pv_mma
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
# 2. Recast C-fragment iterator to BF16 (matching backward FMHA line 962)
tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
# 3. Create store target with A-fragment layout + recast iterator
# The offset for P within TMEM: qk_acc_dtype.width / q_dtype.width * tmem_p0_offset
# But since we recast to BF16, the offset should be in BF16 units
tdVrP = cute.make_tensor(
tdVrP_iter + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# PV MMA's A-fragment (for reading)
tOrP0 = cute.make_tensor(tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.s_cols, tOrP.layout)
# ── TMEM LOAD from C-fragment ──
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
# ── TMEM STORE via A-fragment layout (backward FMHA pattern) ──
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
thr_st = tiled_st.get_slice(sfw)
tStP = thr_st.partition_D(tdVrP)
# Source identity for store (A-fragment shape)
cS_P = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[2]))
tScS_P = pv_thr.partition_A(cS_P)
tStcS = thr_st.partition_S(tScS_P)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
print(f'[A2] tdVrP.layout: {tdVrP.layout}')
print(f'[A2] tOrP0.layout: {tOrP0.layout}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
# PV MMA
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
# SOFTMAX/EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# ld FP32 from S0
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# Convert FP32 → BF16 (backward-style: true BF16 register, not recast)
rBf16 = cute.make_rmem_tensor(tStcS.shape, self.q_dtype)
for i in cutlass.range(cute.size(rLd), vectorize=True):
rBf16[i] = rLd[i].to(self.q_dtype)
# Store BF16 to TMEM via A-fragment layout
cute.copy(tiled_st, rBf16, tStP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
def test():
torch.manual_seed(42); m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBAfrag2(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Stage B A-frag2 (backward FMHA pattern): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

237
tests/test_bf16_elemwise.py Normal file
View File

@@ -0,0 +1,237 @@
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class BF16Elemwise:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols * 2
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
# LD and ST on same layout
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
tStS = thr_st.partition_D(tStS1)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
tStcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld → FP32 st, NO recast
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# Direct copy: ld register → st register (same shape since same layout)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
# Since ld and st have the same C-fragment layout and same identity tensor,
# the register shapes should match. Copy element by element.
for i in cutlass.range(cute.size(rLd), vectorize=True):
rSt[i] = rLd[i].to(self.q_dtype).to(self.qk_acc_dtype)
cute.copy(tiled_st, rSt, tStS)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# epi reads S1
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = BF16Elemwise(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('BF16 elemwise ld→st copy roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

237
tests/test_bf16_pack.py Normal file
View File

@@ -0,0 +1,237 @@
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class BF16PackTest:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols * 2
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
# LD and ST on same layout
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
tStS = thr_st.partition_D(tStS1)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
tStcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld → FP32 st, NO recast
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# Direct copy: ld register → st register (same shape since same layout)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
# Since ld and st have the same C-fragment layout and same identity tensor,
# the register shapes should match. Copy element by element.
for i in cutlass.range(cute.size(rLd), vectorize=True):
rSt[i] = rLd[i].to(self.q_dtype).to(self.qk_acc_dtype)
cute.copy(tiled_st, rSt, tStS)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# epi reads S1
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = BF16PackTest(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('BF16 elemwise ld→st to S1: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,242 @@
"""Test BF16 recast pattern with FULL C-fragment layout (no subview).
ld from S0 (full 128x128), recast BF16, st to S1 (full 128x128 at offset 128).
Since both ld and st use the same layout, the recast should work (shapes match).
Then epi reads S1. If this works, the recast pattern IS correct for same-layout cases.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class BF16RecastFull:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols * 2
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
tStS1_dst = thr_st.partition_D(tStS1)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
tStcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# BF16 recast pattern (same layout for ld and st, shapes match)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
rSt_e = cute.make_tensor(cute.recast_ptr(rSt.iterator, dtype=self.q_dtype), rLd.layout)
frg_cnt = 4
frg_tile = cute.size(rLd) // frg_cnt
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
rSt_e_frg = cute.logical_divide(rSt_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
v = rLd_frg[None, j].load()
rSt_e_frg[None, j].store(v.to(self.q_dtype))
cute.copy(tiled_st, rSt, tStS1_dst)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# epi reads S1
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = BF16RecastFull(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('BF16 recast full layout (ld S0, st S1, epi S1): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,239 @@
"""Simplest BF16 recast test: ld FP32 from S0, use recast to convert,
st back to S0, then epi reads S0. Single region, no subview."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class SimpleBF16Recast:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
# ld and st on the SAME tensor (S0), same layout
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS0)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
tStS_dst = thr_st.partition_D(tStS0)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
tStcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# BF16 recast pattern (identity, no math)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
rSt_e = cute.make_tensor(cute.recast_ptr(rSt.iterator, dtype=self.q_dtype), rLd.layout)
frg_cnt = 4
frg_tile = cute.size(rLd) // frg_cnt
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
rSt_e_frg = cute.logical_divide(rSt_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
v = rLd_frg[None, j].load()
rSt_e_frg[None, j].store(v.to(self.q_dtype))
cute.copy(tiled_st, rSt, tStS_dst)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# epi reads S0
tCtS0 = cute.make_tensor(tmem_ptr, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS0, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = SimpleBF16Recast(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('BF16 recast same region (ld/st S0, epi S0): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,85 @@
"""
BF16 Packing Diagnostic: Run identity softmax with K=V=randn,
compare output vs reference to identify the error pattern.
"""
import torch, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from test_stage_b_v7 import StageBIdentitySoftmax
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float()
kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf # identity softmax: (Q @ K^T) @ V
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f'\nCosine: {cos:.6f}')
print(f'Output row 0[:8]: {out[0,:8].tolist()}')
print(f'Ref row 0[:8]: {ref[0,:8].tolist()}')
# Key diagnostic: compare Q@K^T stage (which we know is correct) vs PV stage
# If Q@K^T is correct but PV is wrong, the output will show the PV error pattern
# With identity softmax, P = Q@K^T. So output = P @ V = (Q@K^T) @ V
# If V is being read wrong, the output will be P @ V_permuted
# Check: is the output a permutation of the reference?
# Sort both and compare
out_sorted = out.flatten().sort()[0]
ref_sorted = ref.flatten().sort()[0]
cos_sorted = torch.nn.functional.cosine_similarity(out_sorted.unsqueeze(0), ref_sorted.unsqueeze(0)).item()
print(f'\nCosine after sorting: {cos_sorted:.6f}')
print(f'(If ~1.0, output is a permutation of reference)')
# Check: does output match P @ V^T? (V transposed)
ref_vt = qf @ kvf.T @ kvf.T # (Q@K^T) @ V^T
cos_vt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_vt.flatten().unsqueeze(0)).item()
print(f'Cosine with P @ V^T: {cos_vt:.6f}')
# Check: does output match P^T @ V? (P transposed)
# P = Q@K^T, so P^T = K@Q^T
ref_pt = kvf @ qf.T @ kvf
cos_pt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_pt.flatten().unsqueeze(0)).item()
print(f'Cosine with P^T @ V: {cos_pt:.6f}')
# Check: is output simply the Q@K^T scores (MMA2 produced identity)?
# If MMA2 didn't run or produced P unchanged, output = P = Q@K^T
qkt = qf @ kvf.T
cos_qkt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), qkt.flatten().unsqueeze(0)).item()
print(f'Cosine with just Q@K^T: {cos_qkt:.6f}')
# Check: all output rows identical? (means P has identical rows, like all-ones)
all_same = torch.allclose(out[0], out[1], atol=1e-3)
print(f'All output rows identical: {all_same}')
if not all_same:
cos_r01 = torch.nn.functional.cosine_similarity(out[0].unsqueeze(0), out[1].unsqueeze(0)).item()
print(f'Cosine between row 0 and row 1: {cos_r01:.6f}')
# Check: is output just V (P=I case)?
cos_v = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), kvf.flatten().unsqueeze(0)).item()
print(f'Cosine with V alone: {cos_v:.6f}')
# Print some output statistics
print(f'\nOutput stats: min={out.min().item():.4f}, max={out.max().item():.4f}, mean={out.mean().item():.4f}')
print(f'Ref stats: min={ref.min().item():.4f}, max={ref.max().item():.4f}, mean={ref.mean().item():.4f}')

133
tests/test_packing_diag.py Normal file
View File

@@ -0,0 +1,133 @@
"""
BF16 Packing Diagnostic: Write specific F32 bit patterns to P TMEM via St32x32bOp.
Then run PV MMA with V=identity. Output = P (as MMA reads it).
This reveals the BF16 packing order within F32 TMEM words.
Strategy:
- Skip MMA1 entirely (no QK computation)
- Write packed F32 values where low BF16 ≠ high BF16
- Use K=V=identity so output[i,j] = P[i,j]
- Compare output against both packing orders
"""
import torch, struct, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
# Use the existing diag test that writes P=all-ones, but modify to write specific patterns
# The diag test is test_stage_b_diag.py - let me copy and modify v7 instead
# Actually, let me just use the EXISTING kernel (test_stage_b_v7) with:
# 1. K = identity matrix
# 2. The identity softmax running normally (it writes softmax(Q@K^T) as P)
# 3. If Q is also identity, then Q@K^T = I@I = I, softmax(I) = softmax of identity
# 4. That's not what we want either.
# Better: use the diag test (writes P=all-ones) with K=identity
# Output should be all-ones * I = all-ones. That gives no new info.
# The REAL test: modify the kernel to write specific F32 values to the BF16 recast view
# Instead of writing 1.0 BF16, write alternating different values
# Simplest approach: modify test_stage_b_diag.py to write j+1 in BF16 for each position
# Then with V=identity, output[i,j] = j+1
# But modifying the kernel requires JIT changes. Let me use the existing v7 kernel
# with a clever input choice instead.
# Approach: Use the identity softmax (which writes Q@K^T scores as P)
# With Q=randn and K=identity: Q@K^T = Q (since K=I)
# Then P = softmax(Q) and output = softmax(Q) @ V
# With V=I: output = softmax(Q)
# This tests the FULL pipeline but doesn't isolate packing.
# Let me just write the packing test kernel from scratch, minimally.
# It only needs: TMA load V, write F32 to P TMEM, PV MMA, epilogue.
# Actually, the simplest isolation test:
# Use the existing test with P=all-ones and V=K (K=randn)
# We already know cos=0.08 for this. The 0.08 is not 0 or 1.
# If I use V where V[j] = 1 if j even, 0 if j odd, then:
# With correct packing: output = (number of ones in even K positions) per row
# This doesn't help because P=all-ones means all K positions are 1.0
# I think the key issue is that with P=all-ones (cos=0.08), the output should be
# EXACTLY sum(V, dim=0) for each row. Let me compare more carefully.
# Let me run the EXISTING P=all-ones diag test and compare the output values
# against the reference. The PATTERN of errors will tell us about packing.
print("Running existing P=all-ones diag with K=V=randn")
print("Comparing output vs reference to identify the error pattern")
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
import cuda.bindings.driver as cuda
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
kvf = kv[:,:,0].float()
ref = torch.ones(128, 128, dtype=torch.float32, device='cuda') @ kvf
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
from test_stage_b_diag import StageBDiag
kernel = StageBDiag(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling diag test...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f'\nCosine: {cos:.6f}')
print(f'Output row 0[:8]: {out[0,:8].tolist()}')
print(f'Ref row 0[:8]: {ref[0,:8].tolist()}')
print(f'Ratio out/ref row 0[:8]: {(out[0,:8] / ref[0,:8]).tolist()}')
# Check if output is a scaled version of the reference
ratio = out[0] / ref[0]
ratio_clean = ratio[torch.isfinite(ratio)]
print(f'Ratio mean: {ratio_clean.mean().item():.6f}, std: {ratio_clean.std().item():.6f}')
# Check if odd/even columns have different ratios
ratio_even = (out[0, 0::2] / ref[0, 0::2])[torch.isfinite(out[0, 0::2] / ref[0, 0::2])]
ratio_odd = (out[0, 1::2] / ref[0, 1::2])[torch.isfinite(out[0, 1::2] / ref[0, 1::2])]
print(f'Even col ratio mean: {ratio_even.mean().item():.6f}')
print(f'Odd col ratio mean: {ratio_odd.mean().item():.6f}')
# Check if output matches a different V interpretation
# What if the V SMEM is being read in a different column order?
# Compute: what if V columns were permuted?
# With P=all-ones, output[i] = sum(V[:,j]) for each j
# This is the column sum of V, broadcast to all rows
# If V is read in a different order, the output would be the sum of
# differently-ordered V columns, but still all the same sum
# So output should be uniform across columns = sum of all V values per column
# But the output IS uniform (all rows identical). The VALUES are wrong.
# That means the SUM is wrong, which means the V values being read are wrong.
# Let me check: does the output column structure match any known transform of V?
print(f'\nV (K) row 0[:8]: {kvf[0,:8].tolist()}')
print(f'V col sums (reference): {ref[0,:4].tolist()}')
print(f'Output col 0[:4]: {out[0,:4].tolist()}')
# What if V is being read transposed?
# sum(V[j,:]) per column j = row sums of V
row_sums = kvf.sum(dim=1)
col_sums = kvf.sum(dim=0)
print(f'V row sums[:4]: {row_sums[:4].tolist()}')
print(f'V col sums[:4]: {col_sums[:4].tolist()}')

119
tests/test_pair_swap.py Normal file
View File

@@ -0,0 +1,119 @@
"""
BF16 Pair-Swap Test: compare kernel output against reference with V rows
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
"""
import torch, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from test_stage_b_v7 import StageBIdentitySoftmax
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float()
kvf = kv[:,:,0].float()
# Standard reference: P @ V where P = Q @ K^T
ref = qf @ kvf.T @ kvf
# Pair-swapped reference: if BF16 within each F32 are swapped,
# MMA reads K[2j] and K[2j+1] swapped, which means V rows are swapped in pairs
# Output = P @ V_with_row_pairs_swapped
V_swap = kvf.clone()
V_swap[0::2], V_swap[1::2] = kvf[1::2].clone(), kvf[0::2].clone()
ref_swap = qf @ kvf.T @ V_swap
# Also try: just the P@K^T with V where every 2 consecutive rows are swapped
# but starting from different offsets
# Try 1-based offset: swap rows (1,2), (3,4), ...
V_swap1 = kvf.clone()
V_swap1[1::2], V_swap1[2::2] = kvf[2::2].clone(), kvf[1::2].clone()
V_swap1[0] = kvf[0] # row 0 unchanged
ref_swap1 = qf @ kvf.T @ V_swap1
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos_ref = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
cos_swap = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap.flatten().unsqueeze(0)).item()
cos_swap1 = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap1.flatten().unsqueeze(0)).item()
print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
print(f'Cosine with offset-1 swapped: {cos_swap1:.6f}')
# Also try: what if the entire P row is shifted by some amount?
# Or what if P columns are reordered according to the A-fragment's K partitioning?
# The A-fragment has K laid out as (16 inner, 4 outer chunks of 16)
# If the store writes columns sequentially but the MMA reads them in the
# chunked order, we'd get a specific permutation
# The A-fragment K layout: (col, k0, k1) where col=0..15, k0=0..3, k1=0..1
# K position = col + 16*k0 + 64*k1
# If the store writes K sequentially (0,1,2,...,127) but the MMA reads
# in the fragment order, the mapping would be:
# Fragment index (col, k0, k1) -> K position (col + 16*k0 + 64*k1)
# But this is sequential (0,1,2,...), so no permutation.
# UNLESS the store doesn't fill the fragment's K blocks correctly.
# Try: what if only the FIRST 64 K values (not 128) are filled?
# The store writes 64 F32 = 128 BF16. But what if the MMA's A-fragment
# K dimension is 64 (not 128)? Then only 64 K values have P, rest are garbage.
# But we already verified nblk_pv=4 and the fragment covers 128 BF16.
# Try: what if the 64 F32 columns in the store map to 128 BF16 K positions
# but with the packing where BF16[K] and BF16[K+1] come from the same F32 word
# and K is determined by the A-fragment's partition order?
# The A-fragment reads K in groups of 16 (inner K), then 4 outer groups, then 2 BF16 per F32
# So K values are accessed as: [group0_bf16_0, group0_bf16_1, group1_bf16_0, group1_bf16_1, ...]
# where group = (col, k0, k1) and bf16_0/bf16_1 are the two BF16 in each F32 word
# This is getting complex. Let me just check the specific permutation by
# matching output columns to reference columns for row 0.
# For each output column j, find which reference column it matches
# (using the entire row as a signature)
print('\nTrying to identify the column permutation for row 0...')
# Since all values might not be unique, use the dot product with a known vector
# Actually, the simplest: for output[0, j], check if it equals ref[0, k] for any k
# But we need a unique signature. Let's use the output of the ENTIRE row.
# Compare output row i with reference rows
for i in [0, 1, 2]:
best_cos = 0
best_j = -1
for j in range(128):
c = torch.nn.functional.cosine_similarity(out[i].unsqueeze(0), ref[j].unsqueeze(0)).item()
if c > best_cos:
best_cos = c
best_j = j
print(f' Output row {i} best matches ref row {best_j} (cos={best_cos:.4f})')
# Also: compare output column j with reference columns
print('\nColumn matching (output col j vs ref col k):')
for j in [0, 1, 2, 3, 4, 5, 6, 7]:
best_cos = 0
best_k = -1
for k in range(128):
c = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item()
if c > best_cos:
best_cos = c
best_k = k
print(f' Output col {j} best matches ref col {best_k} (cos={best_cos:.4f})')

95
tests/test_pair_swap2.py Normal file
View File

@@ -0,0 +1,95 @@
"""
BF16 Pair-Swap Test: compare kernel output against reference with V rows
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
Also tries to identify the exact column permutation.
"""
import torch, sys
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from test_stage_b_v7 import StageBIdentitySoftmax
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float()
kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
# Pair-swapped: V rows 0↔1, 2↔3, 4↔5, ...
V_swap = kvf.clone()
V_swap[0::2], V_swap[1::2] = kvf[1::2].clone(), kvf[0::2].clone()
ref_swap = qf @ kvf.T @ V_swap
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos_ref = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
cos_swap = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap.flatten().unsqueeze(0)).item()
print(f'\nCosine with standard ref: {cos_ref:.6f}')
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
# Column permutation identification
print('\nColumn permutation (output col j matches ref col k):')
perm = []
for j in range(min(16, 128)):
best_cos = 0
best_k = -1
for k in range(128):
c_val = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item()
if c_val > best_cos:
best_cos = c_val
best_k = k
perm.append(best_k)
if j < 16:
print(f' out col {j} -> ref col {best_k} (cos={best_cos:.4f})')
# Check if permutation has a pattern
print(f'\nPermutation (first 16): {perm}')
diffs = [perm[j+1] - perm[j] for j in range(len(perm)-1)]
print(f'Permutation diffs: {diffs}')
# Check: is perm[j] = j with every pair swapped? (0→1, 1→0, 2→3, 3→2, ...)
expected_pair_swap = [j^1 for j in range(128)] # XOR with 1 swaps even/odd
matches_pair_swap = all(perm[j] == expected_pair_swap[j] for j in range(len(perm)))
print(f'Matches pair-swap (0↔1, 2↔3, ...): {matches_pair_swap}')
# Also check the full permutation for all 128 columns
full_perm = []
for j in range(128):
best_cos = 0
best_k = -1
for k in range(128):
c_val = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item()
if c_val > best_cos:
best_cos = c_val
best_k = k
full_perm.append(best_k)
print(f'\nFull permutation: {full_perm}')
# Check if it's pair-swap for all columns
all_pair_swap = all(full_perm[j] == (j^1) for j in range(128))
print(f'All columns match pair-swap: {all_pair_swap}')
# If not pair-swap, what's the pattern?
if not all_pair_swap:
# Check if it's a 16-element block permutation
for block in range(8):
block_perm = full_perm[block*16:(block+1)*16]
print(f' Block {block}: {block_perm}')

View File

@@ -0,0 +1,237 @@
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class RecastMinimal:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols * 2
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
# LD and ST on same layout
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
tStS = thr_st.partition_D(tStS1)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
tStcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld → FP32 st, NO recast
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# Direct copy: ld register → st register (same shape since same layout)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
# Since ld and st have the same C-fragment layout and same identity tensor,
# the register shapes should match. Copy element by element.
for i in cutlass.range(cute.size(rLd), vectorize=True):
rSt[i] = rLd[i]
cute.copy(tiled_st, rSt, tStS)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# epi reads S1
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = RecastMinimal(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('BF16 recast minimal→st copy roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

210
tests/test_stage_b_afrag.py Normal file
View File

@@ -0,0 +1,210 @@
"""Stage B: Store P via A-fragment layout, not C-fragment.
Key insight from Mike: the C-fragment store and A-fragment read use different
physical TMEM address mappings. The fix is to construct the TMEM store
using the A-fragment layout (from p_tmem_s / make_fragment_A).
Steps:
1. Q@K^T → TMEM (C-fragment, offset 0)
2. ld scores from TMEM (C-fragment)
3. Convert FP32→BF16
4. st P to TMEM using A-fragment layout (p_tmem_s / tOrP0)
5. PV MMA reads from TMEM using A-fragment (same layout = same physical addresses)
6. Epilogue writes output
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBAfrag:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
# TMEM layout: S0 (scores) at offset 0, P (A-fragment) at offset 32, O at offset s_cols
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = self.s_cols
self.tmem_alloc_cols = self.s_cols + self.o_cols
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
# TMEM tensors
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P A-fragment (matching fmha exactly)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# ── TMEM LOAD from C-fragment ──
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
# ── TMEM STORE to A-fragment layout ──
# Use St16x128bOp with BF16 for the A-fragment layout
# (St32x32bOp only works with C-fragment layout, not A-fragment)
tmem_st = cute.make_copy_atom(tcgen05.copy.St16x128bOp(tcgen05.copy.Repetition(16)), self.q_dtype)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tP)
thr_st = tiled_st.get_slice(sfw)
tStP = thr_st.partition_D(tP)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
# PV MMA
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
# SOFTMAX/EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# ld FP32 from S0
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# Convert FP32 → BF16 into a true BF16 register (backward-style)
rBf16 = cute.make_rmem_tensor(tLdcS.shape, self.q_dtype)
for i in cutlass.range(cute.size(rLd), vectorize=True):
rBf16[i] = rLd[i].to(self.q_dtype)
# Store BF16 to TMEM via A-fragment layout
cute.copy(tiled_st, rBf16, tStP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
def test():
torch.manual_seed(42); m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBAfrag(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Stage B A-frag store: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,217 @@
"""Stage B: Store P via A-fragment layout with recast C-fragment iterator.
Matching the backward FMHA pattern exactly:
1. tOrP = pv_thr.make_fragment_A(tP)[None,None,None,0] (A-fragment layout)
2. tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=BF16) (C-fragment base, recast to BF16)
3. tdVrP = cute.make_tensor(tdVrP_iter + offset, tOrP.layout)
4. make_tmem_copy(St32x32bOp(Repetition(8)), BF16, tdVrP)
5. Store BF16 registers to tdVrP
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBAfrag2:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
self.tmem_s0_offset = 0
self.tmem_p0_offset = 0
self.tmem_o0_offset = self.s_cols * 2
self.tmem_alloc_cols = 512
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# ── P A-fragment (backward FMHA pattern) ──
# 1. Get A-fragment layout from pv_mma
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
# 2. Recast C-fragment iterator to BF16 (matching backward FMHA line 962)
tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
# 3. Create store target with A-fragment layout + recast iterator
# The offset for P within TMEM: qk_acc_dtype.width / q_dtype.width * tmem_p0_offset
# But since we recast to BF16, the offset should be in BF16 units
tdVrP = cute.make_tensor(
tdVrP_iter + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# PV MMA's A-fragment (for reading)
tOrP0 = cute.make_tensor(tOrP.iterator, tOrP.layout)
# ── TMEM LOAD from C-fragment ──
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
# ── TMEM STORE via A-fragment layout (backward FMHA pattern) ──
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
thr_st = tiled_st.get_slice(sfw)
tStP = thr_st.partition_D(tdVrP)
# Source identity for store (A-fragment shape)
cS_P = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[2]))
tScS_P = pv_thr.partition_A(cS_P)
tStcS = thr_st.partition_S(tScS_P)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
print(f'[A2] tdVrP.layout: {tdVrP.layout}')
print(f'[A2] tOrP0.layout: {tOrP0.layout}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
# PV MMA
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
# SOFTMAX/EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# ld FP32 from S0
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# Convert FP32 → BF16 (backward-style: true BF16 register, not recast)
rBf16 = cute.make_rmem_tensor(tStcS.shape, self.q_dtype)
for i in cutlass.range(cute.size(rLd), vectorize=True):
rBf16[i] = cutlass.BFloat16(1.0)
# Store BF16 to TMEM via A-fragment layout
# SKIP STORE
#cute.copy(tiled_st, rBf16, tStP)
#cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
def test():
torch.manual_seed(42); m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBAfrag2(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Stage B A-frag2 (backward FMHA pattern): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

384
tests/test_stage_b_diag.py Normal file
View File

@@ -0,0 +1,384 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# 3. DIAGNOSTIC: Print per-thread element counts
if tidx == 0:
print(f"[DIAG] LOAD per-thread elements: {cute.size(tTMEM_LOADcS)}")
print(f"[DIAG] STORE per-thread elements: {cute.size(tTMEM_STOREcS)}")
# 4. Wait for scores (MMA1 must complete before we touch TMEM)
si_handle = mma_si_cons.wait_and_advance()
# 5. DIAGNOSTIC: Skip the load entirely. Write 1.0 to P via store atom.
# If store atom works: output should be correlated with P@V.
# P=all-ones means MMA2 computes sum of V columns per row.
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
# Need load rmem tensor for its layout (fmha.py pattern: recast uses load layout)
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
for j in cutlass.range(cute.size(tTMEM_STORErS_x4_e), unroll_full=True):
tTMEM_STORErS_x4_e[j] = self.q_dtype(1.0)
# 6. Store P=all-ones into A-layout TMEM
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = torch.ones(128, 128, dtype=torch.float32, device="cuda") @ kvf # P=all-ones @ V
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B DIAG: P=all-ones (store atom test)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
print(' Output row 0[:8]:', out[0,:8].tolist())
print(' Output row 1[:8]:', out[1,:8].tolist())
print(' Output row 63[:8]:', out[63,:8].tolist())
print(' Ref row 0[:8]:', ref[0,:8].tolist())
print(' Output col 0[:8]:', out[:8,0].tolist())
print(' Any zeros:', (out==0).sum().item(), 'of', out.numel())
if __name__ == '__main__':
test()

207
tests/test_stage_b_final.py Normal file
View File

@@ -0,0 +1,207 @@
"""Stage B final: ld FP32 from S0, BF16 recast, st to S1 (offset s_cols),
PV MMA reads A-fragment from S1 at the appropriate offset.
The recast pattern WORKS when writing to a different TMEM region.
PV MMA A-fragment reads from S1 with tmem_p0_offset adjustment.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBFinal:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
# S0 at offset 0, P0 (BF16 view) at offset 32 within S0, O0 at offset s_cols
# But for the recast test, we write the FULL C-fragment to S1 (offset s_cols)
# PV MMA A-fragment needs to read from S1. The A-fragment offset is:
# tOrP0 starts at tStS.iterator + (F32_width/BF16_width) * tmem_p0_offset
# For S1, we need A-fragment pointing to the start of S1.
# A-fragment offset for column c: width_ratio * c
# S1 starts at s_cols (128 FP32 columns). A-fragment for P at S1 start:
# offset = (F32_width / BF16_width) * s_cols = 2 * 128 = 256
# But tOrP layout has stride 64 for M, so base offset is in BF16 column units
# Actually, tOrP0.iterator is the base of the A-fragment in the TMEM address space.
# The A-fragment offset is (qk_acc_dtype.width / q_dtype.width) * tmem_p0_offset.
# For our case, we want P at S1 start, so tmem_p0_offset should map to s_cols.
# But tmem_p0_offset is in FP32 columns, and the A-fragment width ratio converts it.
self.tmem_s0_offset = 0
self.tmem_p0_offset = 0 # P starts at the beginning of S1
self.tmem_o0_offset = self.s_cols # O region after S1
self.tmem_alloc_cols = self.s_cols + self.o_cols # S1 + O
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]); tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # S0: MMA output
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) # S1: BF16 copy of scores (P)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P A-fragment pointing to S1 (offset s_cols)
tP = cute.make_tensor(tStS.iterator + self.s_cols, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
# Since P is at the START of S1 (offset s_cols), no additional offset needed
tOrP0 = cute.make_tensor(tOrP.iterator, tOrP.layout)
# TMEM ld/st atoms
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0); tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
sfw = tidx % (32 * len(self.epilogue_warp_id)); thr_ld = tiled_ld.get_slice(sfw); thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0); tStS1_dst = thr_st.partition_D(tStS1)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])); tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS); tStcS = thr_st.partition_S(tScS)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
# PV MMA
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
# SOFTMAX/EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# ld FP32 from S0
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# BF16 recast pattern (identity softmax: just FP32→BF16 packing)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
rSt_e = cute.make_tensor(cute.recast_ptr(rSt.iterator, dtype=self.q_dtype), rLd.layout)
frg_cnt = 4; frg_tile = cute.size(rLd) // frg_cnt
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
rSt_e_frg = cute.logical_divide(rSt_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
v = rLd_frg[None, j].load()
rSt_e_frg[None, j].store(v.to(self.q_dtype))
cute.copy(tiled_st, rSt, tStS1_dst); cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
def test():
torch.manual_seed(42); m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBFinal(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Stage B Final (ld S0, BF16 recast, st S1, PV MMA): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

321
tests/test_stage_b_v10.py Normal file
View File

@@ -0,0 +1,321 @@
"""Stage B v10: Use FP32 ld→st on FULL (128,128) layout, not subview.
The ld reads S0 (full C-fragment), the st writes S1 (full C-fragment at offset 128),
then PV MMA reads S1 via A-fragment.
Key insight: The FP32 ld→st on the SAME layout works (cosine 0.999999).
The BF16 recast pattern with DIFFERENT layout sizes is broken.
For identity softmax, we need to write the data back to TMEM in a format
that the PV MMA's A-fragment can read. Since the C-fragment and A-fragment
both access the same physical TMEM, writing via C-fragment (store) and
reading via A-fragment (PV MMA) should work IF the data is in the right place.
The pv_mma_tiler has N=128 (from mma_tiler_mn[1]) but P's K dimension = 128.
Wait, pv_mma_tiler = (M, K, K_inner*4). For the PV MMA, A=P with K=128 (full KV dim).
But p_tmem_s is make_smem_layout_a(pv_mma, pv_mma_tiler, BF16, 1) which gives the
A-fragment layout for (128, 128, 64). The A-fragment covers K=64 at CTA level...
Hmm, pv_mma_tiler[1] should be the V dimension (N in MMA terms), not K.
Actually, for PV MMA: P (A operand, M×K) × V (B operand, K×N) → O (C operand, M×N)
With a_source=TMEM, the P operand's K dimension matches the CTA tile K.
pv_mma_tiler = (128, 128, 64): M=128, N=128, K=64.
So P is M×K = 128×64. That's the A-fragment shape.
So the A-fragment only reads 64 columns of TMEM (K=64 in 4 blocks of 16).
The C-fragment for S has 128 columns. The P region starts at offset 32 (tmem_p0_offset).
With s_cols=128 and tmem_p0_offset=32, the P data starts at column 32 and spans 64 columns
(columns 32-95), which is within the S region (columns 0-127).
So the store should write to columns 32-95 of the S region. The C-fragment composition
(128, tilePlikeFP32=64) with offset 32 does exactly this.
But we showed the subview store + recast doesn't work. Let me try: write the FULL
C-fragment (128×128) at offset 128 (tStS1), and adjust the PV MMA's A-fragment
offset so it reads from the right columns within that region.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBv10:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
self.o_cols = find_tmem_tensor_col_offset(tOtO)
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = self.s_cols
self.tmem_alloc_cols = self.s_cols + self.o_cols # 256
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
cute.nvgpu.OperandMajorMode.K,
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P A-fragment
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# LD and ST copy atoms on the SAME layout (full C-fragment)
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS0) # SAME layout for ld and st
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
tStS_dst = thr_st.partition_D(tStS0) # St writes to S0 (same as MMA output)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
tStcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
# PV MMA
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# EPILOGUE WARPS: ld from S0, FP32→BF16 elementwise, st back to S0, then MMA warp uses it for PV
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld from S0
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# FP32 → BF16 → FP32 elementwise (identity softmax, no math, just dtype roundtrip)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
for i in cutlass.range(cute.size(rLd), vectorize=True):
rSt[i] = rLd[i].to(self.q_dtype).to(self.qk_acc_dtype)
# FP32 st back to S0 (overwrite scores with identity-softmaxed values)
cute.copy(tiled_st, rSt, tStS_dst)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(ing_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBv10(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Stage B v10 (FP32 ld→BF16 elementwise→st, PV MMA): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

210
tests/test_stage_b_v11.py Normal file
View File

@@ -0,0 +1,210 @@
"""Stage B v11: Backward FMHA pattern exactly.
1. ld FP32 from S0 (C-fragment)
2. Quantize FP32→BF16 (same register shape, .load()/.store())
3. Reshape BF16 register to store partition shape
4. St32x32bOp BF16 store to tdVrP (A-fragment layout, offset s_cols)
5. PV MMA reads from tOrP0 (A-fragment, offset s_cols)
6. Epilogue writes output
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBv11:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
self.tmem_s0_offset = 0
self.tmem_o0_offset = self.s_cols * 2 # After S0 (128) + P (128 BF16 = 64 FP32, but A-frag offset uses width ratio)
self.tmem_alloc_cols = 512 # Enough for S0 + P + O
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# ── P A-fragment (backward FMHA pattern) ──
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
# Apply offset for P region (after S0)
p_bf16_offset = self.qk_acc_dtype.width // self.q_dtype.width * self.s_cols
tdVrP = cute.make_tensor(tOrP.iterator + p_bf16_offset, tOrP.layout)
tOrP0 = cute.make_tensor(tOrP.iterator + p_bf16_offset, tOrP.layout)
# ── TMEM LOAD from C-fragment ──
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
# ── TMEM STORE via A-fragment layout ──
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
thr_st = tiled_st.get_slice(sfw)
tStP = thr_st.partition_D(tdVrP)
tdVcST = pv_thr.partition_A(cS_id) # A-operand partition of identity
tStcS = thr_st.partition_S(tdVcST)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
# PV MMA
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
# SOFTMAX/EPILOGUE WARPS (backward FMHA quantize pattern)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# 1. ld FP32 from S0
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# 2. Quantize FP32→BF16 (backward FMHA pattern)
rBf16 = cute.make_rmem_tensor(rLd.shape, self.q_dtype)
frg_cnt = 4; frg_tile = cute.size(rLd) // frg_cnt
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
rBf16_frg = cute.make_tensor(rBf16.iterator, rLd_frg.layout)
for i in cutlass.range(frg_tile, unroll_full=True):
v = rLd_frg[None, i].load()
rBf16_frg[None, i].store(v.to(self.q_dtype))
# 3. Reshape BF16 to store partition shape
rBf16_reshaped = cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape))
# 4. Store BF16 to TMEM via A-fragment layout
cute.copy(tiled_st, rBf16_reshaped, tStP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
def test():
torch.manual_seed(42); m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBv11(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Stage B v11 (backward FMHA pattern): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

398
tests/test_stage_b_v11b.py Normal file
View File

@@ -0,0 +1,398 @@
"""
Stage B v11b: Identity Softmax - store to composition(tP, (128,128))
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
# FIX: PV tiler swaps N and K from QK (fmha.py: pv_mma_tiler = (M, QK_K, QK_N))
# P is (M, QK_N) and V is (QK_N, D), so PV MMA has K=QK_N, N=D
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
self.mma_tiler = self.qk_mma_tiler
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
print(f'[DIAG] tP.layout: {tP.layout}')
print(f'[DIAG] tP.size: {cute.size(tP)}')
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline: write to PV MMA A-layout (tP) composed to 2D
# tP has correct A-fragment addresses. Compose to (128,128) for the store atom.
tP_2d_layout = cute.composition(tP.layout, cute.make_layout((128, 128)))
tP_2d = cute.make_tensor(tP.iterator + self.tmem_p0_offset, tP_2d_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tP_2d)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tP_2d)
cP_2d = cute.make_identity_tensor(tP_2d.shape)
tTMEM_STOREcP = thr_store.partition_S(cP_2d)
# 3. Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# 4. Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: F32 → BF16
tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_STOREtP.shape, self.qk_acc_dtype)
tTMEM_STORErP_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErP.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErP_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout TMEM
cute.copy(tiled_tmem_store, tTMEM_STORErP, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# 7. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v11b: store to tP composed (128,128)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

408
tests/test_stage_b_v12.py Normal file
View File

@@ -0,0 +1,408 @@
"""
Stage B v12: BF16 St16x128bOp store to tP A-layout, P=all-ones
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
# FIX: PV tiler swaps N and K from QK (fmha.py: pv_mma_tiler = (M, QK_K, QK_N))
# P is (M, QK_N) and V is (QK_N, D), so PV MMA has K=QK_N, N=D
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
self.mma_tiler = self.qk_mma_tiler
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
print(f'[DIAG] tP.layout: {tP.layout}')
print(f'[DIAG] tP.size: {cute.size(tP)}')
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline: write to A-fragment TMEM via BF16 store atom
# tP has the A-fragment layout with dense TMEM addresses.
# St32x32bOp is for C-fragment (sparse addressing) — WRONG for A-fragment.
# Try BF16 store atoms that match the A-fragment partition.
# First, create the P store target with the p0 offset applied.
# tP is F32 but represents BF16 positions. We need a BF16 view.
tP_bf16 = cute.make_tensor(
cute.recast_ptr(tP.iterator, dtype=self.q_dtype),
p_tmem_s_bf16_layout if False else tP.layout) # placeholder
# Actually, lets try make_tmem_copy with tP directly using F32 atom first
# If that fails, well try BF16 atoms
tP_offset = cute.make_tensor(tP.iterator + self.tmem_p0_offset, tP.layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St16x128bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
# tP_offset is F32-typed but we need BF16 for St16x128bOp
# Create BF16 view of tP
tP_bf16_offset = cute.make_tensor(
cute.recast_ptr(tP.iterator, dtype=self.q_dtype),
tP.layout)
tP_bf16_with_offset = cute.make_tensor(
tP_bf16_offset.iterator + self.tmem_p0_offset, tP_bf16_offset.layout)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tP_bf16_with_offset)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tP_bf16_with_offset)
cP = cute.make_identity_tensor(tP_bf16_with_offset.shape)
tTMEM_STOREcP = thr_store.partition_S(cP)
# 3. Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# 4. DIAGNOSTIC: Just write P=all-ones directly (skip load + F32→BF16)
# This tests whether the BF16 store atom + tP A-layout works
tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_STOREtP.shape, self.q_dtype)
for j in cutlass.range(cute.size(tTMEM_STORErP), unroll_full=True):
tTMEM_STORErP[j] = self.q_dtype(1.0)
# 5. Store into A-fragment TMEM
cute.copy(tiled_tmem_store, tTMEM_STORErP, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# 6. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = torch.ones(128, 128, dtype=torch.float32, device="cuda") @ kvf # P=all-ones @ V
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v12: BF16 store to A-fragment layout')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -41,6 +41,11 @@ class StageBIdentitySoftmax:
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
@@ -80,7 +85,7 @@ class StageBIdentitySoftmax:
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_p0_offset = 32 # Original
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
@@ -117,6 +122,16 @@ class StageBIdentitySoftmax:
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
# Introspect PV MMA atom
print(f"[ATOM] PV MMA type: {type(pv_mma)}")
print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, "op") else "no op"}")
print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, "_trait") else "no trait"}")
print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}")
print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}")
# Check a_src
print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}")
print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}")
print(f"[ATOM] PV MMA op: {pv_mma.op}")
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
@@ -205,6 +220,10 @@ class StageBIdentitySoftmax:
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
@@ -220,16 +239,73 @@ class StageBIdentitySoftmax:
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
print(f'[DIAG] tStS.layout: {tStS.layout}')
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# Compute nblk_pv for diagnostics
nblk_pv = cute.size(tOrP0, mode=[2])
nblk_qk = cute.size(tCrA, mode=[2])
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
# COMPREHENSIVE LAYOUT DUMP
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
print(f'[LAYOUT] tP.layout: {tP.layout}')
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
print(f'[DIAG] tP.layout: {tP.layout}')
print(f'[DIAG] tP.size: {cute.size(tP)}')
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──

View File

@@ -0,0 +1,445 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
print(f'[DIAG] tStS.layout: {tStS.layout}')
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# Compute nblk_pv for diagnostics
nblk_pv = cute.size(tOrP0, mode=[2])
nblk_qk = cute.size(tCrA, mode=[2])
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
# COMPREHENSIVE LAYOUT DUMP
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
print(f'[LAYOUT] tP.layout: {tP.layout}')
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
print(f'[DIAG] tP.layout: {tP.layout}')
print(f'[DIAG] tP.size: {cute.size(tP)}')
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# 3. Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# 4. Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: F32 → BF16
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,445 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
print(f'[DIAG] tStS.layout: {tStS.layout}')
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# Compute nblk_pv for diagnostics
nblk_pv = cute.size(tOrP0, mode=[2])
nblk_qk = cute.size(tCrA, mode=[2])
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
# COMPREHENSIVE LAYOUT DUMP
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
print(f'[LAYOUT] tP.layout: {tP.layout}')
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
print(f'[DIAG] tP.layout: {tP.layout}')
print(f'[DIAG] tP.size: {cute.size(tP)}')
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# 3. Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# 4. Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: F32 → BF16
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,445 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
print(f'[DIAG] tStS.layout: {tStS.layout}')
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# Compute nblk_pv for diagnostics
nblk_pv = cute.size(tOrP0, mode=[2])
nblk_qk = cute.size(tCrA, mode=[2])
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
# COMPREHENSIVE LAYOUT DUMP
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
print(f'[LAYOUT] tP.layout: {tP.layout}')
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
print(f'[DIAG] tP.layout: {tP.layout}')
print(f'[DIAG] tP.size: {cute.size(tP)}')
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# 3. Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# 4. Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: F32 → BF16
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,445 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
print(f'[DIAG] tStS.layout: {tStS.layout}')
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# Compute nblk_pv for diagnostics
nblk_pv = cute.size(tOrP0, mode=[2])
nblk_qk = cute.size(tCrA, mode=[2])
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
# COMPREHENSIVE LAYOUT DUMP
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
print(f'[LAYOUT] tP.layout: {tP.layout}')
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
print(f'[DIAG] tP.layout: {tP.layout}')
print(f'[DIAG] tP.size: {cute.size(tP)}')
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── Identity softmax: C-layout → A-layout transform ──
# 1. LOAD pipeline (reads scores from QK C-layout)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
# 3. Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# 4. Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# 5. IDENTITY: F32 → BF16
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
# 6. Store into A-layout
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
# 7. Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

403
tests/test_stage_b_v8.py Normal file
View File

@@ -0,0 +1,403 @@
"""
Stage B v8: Fix the identity softmax by using the A-fragment layout
for the TMEM store target instead of the C-fragment composition.
The bug in v7: tStS_P uses composition(tStS.layout, (128, tilePlikeFP32))
which gives layout (128,64):(65536,1) — C-fragment strides.
But the PV MMA reads from TMEM using the A-fragment layout
((128,16),1,4):((64,1),0,16) — physical TMEM strides.
For the store to be read correctly by the PV MMA, the store target
must use the same physical TMEM addressing as the A-fragment.
Key insight from CUTLASS source:
Physical TMEM for M=128, BK=64, BF16 A from TMEM (K-major):
tmem[dp=m, col=base_col + 16*mma_k + k_inner] for mma_k in 0..3, k_inner in 0..15
This means: A-fragment address = 64*m + k_inner + 16*mma_k
C-fragment address for (m, col) = ??? (virtual layout, not physical)
The St32x32b copy atom with tStS_P (C-composition) writes to C-layout addresses.
The PV MMA reads from A-layout addresses. These are different physical locations.
Fix: Use tP (from p_tmem_s, the A-fragment source layout) as the store target
instead of tStS_P (the C-fragment composition). This ensures the store writes
to physical TMEM addresses that the PV MMA's A-fragment will read correctly.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmaxV8:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
# TMEM tensors
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P tensor for PV MMA A-fragment
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# ── KEY FIX: Store target uses A-fragment layout (tP) not C-fragment composition ──
# The store must write to physical TMEM addresses that the PV MMA reads via A-fragment.
# tP has layout ((128,16),1,4,1):((64,1),0,16,0) — the A-fragment's physical TMEM layout.
# We need the store target at tmem_p0_offset = 32 columns into the S region.
# tP's iterator starts at tStS.iterator (base of S region).
# tOrP0 starts at tStS.iterator + 2*32 (scaled by F32/BF16 width ratio for the A-fragment).
# The store target should use the SAME layout as tP but with the p0 offset applied.
tP_store = cute.make_tensor(
tStS.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
p_tmem_s.outer)
print(f'[v8] tP.layout: {tP.layout}')
print(f'[v8] tP_store.layout: {tP_store.layout}')
print(f'[v8] tOrP0.layout: {tOrP0.layout}')
print(f'[v8] tStS.layout: {tStS.layout}')
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── LOAD from C-layout (reading QK scores) ──
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# ── STORE to A-layout (writing P for PV MMA) ──
# Use tP_store (A-fragment physical layout) as the store target
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tP_store)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tP_store)
# Identity tensor for the A-layout shape
cS_P = cute.make_identity_tensor((128, self.qk_mma_tiler[1])) # (M, K) for A-layout
# Wait, A-layout is (M, K) but the partition is different...
# make_fragment_A uses p_tmem_s which is an smem_layout_a — that's (M, K, L) shape
# tP has shape ((128,16),1,4,1) — the 4 is MMA_K, 16 is K_inner
# The identity tensor for partition_S should match tP's logical shape
# But tP is a TMEM tensor, not SMEM. The partition_S for the store uses tP_store
# as the destination. We need an identity tensor for the source.
# Actually in fmha, partition_S uses tScS_P (composed from C-fragment identity).
# Let me try the same approach: partition the store's source side with an
# identity tensor that matches the LOAD output (C-layout).
# The store takes register data (from the load) and writes to TMEM (A-layout).
# The register data has C-layout ordering (from the load).
# The store target has A-layout addresses (tP_store).
# We need the source partition of the store to match the register layout.
# Since the register layout came from the load's destination (tTMEM_LOADcS),
# and the store's source partition should match...
# Actually, the correct approach from fmha:
# thr_store.partition_S maps the source (register) identity tensor
# thr_store.partition_D maps the destination (TMEM) tensor
# The copy operation copies from register[partition_S] to TMEM[partition_D]
# The register layout must match what we computed (from the load)
# fmha uses tScS_P (composition of C-fragment identity) for partition_S
# But if our store target is A-layout, the partition_S should use A-layout identity
# Let me just try using the A-layout identity tensor
# tP_store has shape ((128,16),1,4,1) — need identity of (128, 64) for (M, K)
# Actually tP's outer shape is ((128,16),1,4,1):((64,1),0,16,0)
# The logical (M, K) identity: cS_A = make_identity_tensor((128, 64))
cS_A = cute.make_identity_tensor((128, 64))
tScS_A = tiled_tmem_store.get_slice(sfw_idx).partition_S(cS_A)
# Wait for scores
si_handle = mma_si_cons.wait_and_advance()
# Load from C-layout
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Identity: F32 -> BF16
tTMEM_STORErP = cute.make_rmem_tensor(tScS_A.shape, self.qk_acc_dtype)
tTMEM_STORErP_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErP.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
s_vec = tTMEM_LOADrS.load()
tTMEM_STORErP_e.store(s_vec.to(self.q_dtype))
# Store to A-layout TMEM
cute.copy(tiled_tmem_store, tTMEM_STORErP, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmaxV8(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v8: (Q @ K^T) @ V with identity softmax (A-layout store)')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

354
tests/test_stage_b_v8b.py Normal file
View File

@@ -0,0 +1,354 @@
"""
Stage B v8b: BF16 store directly to tOrP0 (diagnostic)
Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes
Architecture (matches fmha.py exactly):
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBIdentitySoftmax:
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1],
self.qk_mma_tiler[2],
)
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
# tilePlikeFP32 for the store-side composition
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
# ── TMEM LAYOUT (matching fmha.py) ──
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
# in the same TMEM region. The A-layout view starts partway through the S region.
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
# The P offset of 32 means the A-layout starts at column 32 within the S region.
# This is because the C-layout and A-layout partition TMEM differently per-thread;
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
#
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32 # Same as fmha.py
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
# Also compute via get_num_tmem_alloc_cols for the full allocation
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
# V shares the same SMEM as B (same data, different layout for PV MMA)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# ── DIAGNOSTIC: BF16 store directly to tOrP0 ──
# Skip C→A transform entirely. Write known BF16 values to the exact
# TMEM addresses MMA2 will read as P (the A-fragment).
# This tests: does writing to the A-fragment's TMEM addresses work?
# Create a BF16 store atom matching the A-fragment type
tmem_store_atom_bf16 = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype)
tiled_tmem_store_bf16 = tcgen05.make_tmem_copy(tmem_store_atom_bf16, tOrP0)
thr_store_bf16 = tiled_tmem_store_bf16.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store_bf16.partition_D(tOrP0)
# Wait for MMA1 scores (MMA1 must complete before we touch TMEM)
si_handle = mma_si_cons.wait_and_advance()
# Fill register buffer with 1.0 in BF16, store to A-fragment TMEM
tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_STOREtP.shape, self.q_dtype)
for j in cutlass.range(cute.size(tTMEM_STORErP), unroll_full=True):
tTMEM_STORErP[j] = self.q_dtype(1.0)
cute.copy(tiled_tmem_store_bf16, tTMEM_STORErP, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# Release back to MMA warp
si_handle.release()
# ── Epilogue ──
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = torch.ones(128, 128, dtype=torch.float32, device="cuda") @ kvf # P=all-ones @ V
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Stage B v8b: BF16 store to tOrP0, P=all-ones')
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

348
tests/test_stage_b_v9.py Normal file
View File

@@ -0,0 +1,348 @@
"""
Stage B v9: Identity softmax matching fmha.py softmax_step EXACTLY.
Line by line match of the fmha softmax_step, but with identity softmax:
- No masking
- scale = 1 (no log2 scaling)
- row_max = 0 (skip max, just do exp(x) = 1 for identity)
- Actually for identity softmax, P = S (scores). So we just ld S, st as P.
- But we need to go through the C→A layout transform properly.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StageBv9:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2
self.acc_dtype = Float32; self.epilog_sync_bar_id = 1
self.use_2cta_instrs = False
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = 128
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.c_layout = c_layout
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = s_cols # 128
self.tmem_alloc_cols = s_cols + o_cols # 256
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
cute.nvgpu.OperandMajorMode.K,
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
# ── TMEM copy atoms (matching fmha softmax_step exactly) ──
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_tmem_load = tiled_tmem_load.get_slice(tidx % (32 * len(self.epilogue_warp_id)))
tTMEM_LOADtS = thr_tmem_load.partition_S(tStS0)
# Store target: composition of C-fragment (matching fmha)
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset,
cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32))))
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_tmem_store = tiled_tmem_store.get_slice(tidx % (32 * len(self.epilogue_warp_id)))
tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P)
# Identity tensors
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tScS_P = cute.make_tensor(tScS.iterator,
cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32))))
tTMEM_LOADcS = thr_tmem_load.partition_D(tScS)
tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P)
print(f'[v9] tTMEM_LOADcS.shape: {tTMEM_LOADcS.shape}')
print(f'[v9] tTMEM_STOREcS.shape: {tTMEM_STOREcS.shape}')
print(f'[v9] LOAD size: {cute.size(tTMEM_LOADcS)}')
print(f'[v9] STORE size: {cute.size(tTMEM_STOREcS)}')
print(f'[v9] tilePlikeFP32: {self.tilePlikeFP32}')
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX / EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# Load scores (matching fmha exactly)
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
# Identity softmax: P = S (no transform, just copy C→A)
# Match fmha pattern: x4 register, BF16 view, .load()/.store()
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout,
)
# For identity softmax: exp(0) = 1 for all, so P[i] = exp(S[i] - max) = 1
# But identity means P = S. So just copy: FP32→BF16, no math.
# Use the fmha fragment pattern
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
s_vec = tTMEM_LOADrS_frg[None, j].load()
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T @ kvf
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StageBv9(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Stage B v9 (fmha-matched identity softmax): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

285
tests/test_store_verify.py Normal file
View File

@@ -0,0 +1,285 @@
"""
Test: Q@K^T → TMEM (scores), ld scores, st to P region,
then epilogue reads P region as C-fragment (not PV MMA).
If the ld/st roundtrip preserves the data, the epilogue should
output the same as Stage A (Q@K^T result).
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StoreVerify:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR
self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = s_cols # Only need scores region, no O region
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
# TMEM copy atoms for ld/st
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
thr_tmem_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_tmem_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_tmem_load.partition_D(tScS)
# Store target: P region (composition of C-fragment, offset 32)
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset,
cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32))))
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_tmem_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_tmem_store.partition_D(tStS_P)
tScS_P = cute.make_tensor(tScS.iterator,
cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32))))
tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX WARPS: ld from S, st to P, then epilogue reads P ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# Load scores from S region
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
# Identity: copy FP32→BF16, matching fmha pattern
tTMEM_STORErP_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErP_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErP_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
tTMEM_STORErP_x4_e_frg = cute.logical_divide(
tTMEM_STORErP_x4_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
s_vec = tTMEM_LOADrS_frg[None, j].load()
tTMEM_STORErP_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, tTMEM_STORErP_x4, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue: read from P region (offset 32) instead of S region (offset 0)
# This tests if the store wrote to the correct physical TMEM locations
# that the C-fragment epilogue can read back
tCtP_base = cute.make_tensor(tmem_ptr + self.tmem_p0_offset, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtP_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T # Just Q@K^T — the ld/st roundtrip should produce this
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StoreVerify(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Store verify (ld S → st P → epilogue P): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

272
tests/test_store_verify2.py Normal file
View File

@@ -0,0 +1,272 @@
"""Store verify v2: ld S (full 128x128), st P (full 128x128 at offset 128),
then epilogue reads P region. No subview, no composition."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class StoreVerify2:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = s_cols * 2 # Two full regions
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # offset 0
tStS1 = cute.make_tensor(tStS.iterator + s_cols, tStS.layout) # offset 128 (second region)
# Load from S0, Store to S1 (same layout, just different offset)
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS1)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS1 = thr_store.partition_D(tStS1)
tTMEM_STOREcS = thr_store.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── SOFTMAX WARPS: ld from S0, st to S1, epilogue reads S1 ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# Load from S0
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
# Identity: FP32→BF16→FP32 via recast
tTMEM_STORErS1_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS1_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS1_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
tTMEM_STORErS1_x4_e_frg = cute.logical_divide(
tTMEM_STORErS1_x4_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
s_vec = tTMEM_LOADrS_frg[None, j].load()
tTMEM_STORErS1_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
# Store to S1
cute.copy(tiled_tmem_store, tTMEM_STORErS1_x4, tTMEM_STOREtS1)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue reads from S1 (offset 128)
tCtS1_base = cute.make_tensor(tmem_ptr + s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = StoreVerify2(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Store verify v2 (ld S0, st S1, epi S1): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,271 @@
"""Minimal TMEM ld→st roundtrip. FP32 only, no BF16 cast.
Uses the fmha pattern: load to register, store via make_rmem_tensor + recast + load/store."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class TMEMCopyRoundtrip:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols * 2
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
# Copy atoms
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS0 = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS1)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS1 = thr_store.partition_D(tStS1)
tTMEM_STOREcS = thr_store.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ── TMA WARP ──
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# ── MMA WARP ──
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ── EPILOGUE WARPS ──
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# Load from S0
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS0, tTMEM_LOADrS)
# Store to S1 using fmha recast pattern (but FP32→FP32, so identity)
tTMEM_STORErS1 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
# BF16 view of store register, with load layout (fp32→bf16 packing)
tTMEM_STORErS1_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS1.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
tTMEM_STORErS1_e_frg = cute.logical_divide(
tTMEM_STORErS1_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
s_vec = tTMEM_LOADrS_frg[None, j].load()
tTMEM_STORErS1_e_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, tTMEM_STORErS1, tTMEM_STOREtS1)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Epilogue reads from S1
tCtS1_base = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref = qf @ kvf.T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = TMEMCopyRoundtrip(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('TMEM copy roundtrip (FP32 ld→BF16 recast→st→epi): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,255 @@
"""Minimal: Q@K^T → TMEM, ld FP32 from S0, st FP32 to S1, epi reads S1.
NO bf16 cast at all. Pure FP32 ld→st roundtrip."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class FP32Roundtrip:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols * 2
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
# Single tiled copy that can both ld and st
tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st_atom, tStS1)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tTMEM_LDtS = thr_ld.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LDcS = thr_ld.partition_D(tScS)
tTMEM_STtS = thr_st.partition_D(tStS1)
tTMEM_STcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld → register → FP32 st (NO bf16)
rS = cute.make_rmem_tensor(tTMEM_LDcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tTMEM_LDtS, rS)
# FP32 → FP32, same identity tensor, same layout
# Use the fmha x4 pattern but with FP32→BF16→FP32 packing
rS1 = cute.make_rmem_tensor(tTMEM_STcS.shape, self.qk_acc_dtype)
rS1_e = cute.make_tensor(cute.recast_ptr(rS1.iterator, dtype=self.q_dtype), rS.layout)
frg_cnt = 4
frg_tile = cute.size(rS) // frg_cnt
rS_frg = cute.logical_divide(rS, cute.make_layout(frg_tile))
rS1_e_frg = cute.logical_divide(rS1_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
v = rS_frg[None, j].load()
rS1_e_frg[None, j].store(v.to(self.q_dtype))
cute.copy(tiled_st, rS1, tTMEM_STtS)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# epi reads S1
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FP32Roundtrip(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('FP32 ld→BF16 recast→st roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,168 @@
"""
Diagnostic: understand the C-layout vs A-layout TMEM address mapping.
After Stage A writes Q@K^T results to TMEM via C-fragment (tStS0),
read the same TMEM data back using the A-fragment (tOrP0) layout
and compare against reading it via the C-fragment layout.
This will tell us whether the composition-based store target
produces the right physical TMEM addresses for the A-fragment to read.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
ref_qk = qf @ kvf.T # (128, 128)
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
c_layout = LayoutEnum.from_tensor(mC)
mma_tiler_mn = (128, 128)
mma_tiler = (*mma_tiler_mn, 1)
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, b_major,
Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, b_major,
Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM)
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
pv_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
# Build the fragments
qk_thr = qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
tStS = qk_thr.make_fragment_C(qk_acc_shape)
pv_thr = pv_mma.get_slice(0)
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
tilePlikeFP32 = qk_mma_tiler[1] * BFloat16.width // 32
tStS_P = cute.make_tensor(tStS.iterator + 32, cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))))
# Print the key layouts
print(f'tStS.layout: {tStS.layout}')
print(f'tStS.size: {cute.size(tStS)}')
print(f'tOrP.layout: {tOrP.layout}')
print(f'tOrP.size: {cute.size(tOrP)}')
print(f'tStS_P.layout: {tStS_P.layout}')
print(f'tP.layout: {tP.layout}')
print(f'tilePlikeFP32: {tilePlikeFP32}')
# For the C-fragment layout ((128,128),1,1):((65536,1),0,0)
# element at logical (m, n) maps to address 65536*m + n
# For the A-fragment layout ((128,16),1,4):((64,1),0,16)
# element at logical ((m_inner, k_inner), 1, mma_k) maps to address 64*m_inner + k_inner + 16*mma_k
# In the C-fragment, logical (m, col) = (m, col) with stride (65536, 1)
# In the A-fragment, logical ((m, k16), 1, block4) with stride ((64, 1), 0, 16)
# So A[m, k16, block4] -> address 64*m + k16 + 16*block4
# C[m, col] -> address 65536*m + col
# The KEY question: for the same physical TMEM column c and row m,
# what does C-layout say the logical index is, vs A-layout?
# C-layout: (m, col) -> addr = 65536*m + col
# A-layout: ((m, k16), 1, blk) -> addr = 64*m + k16 + 16*blk
# If both refer to the same physical TMEM, then:
# 65536*m_c + col = 64*m_a + k16 + 16*blk
# These can only be equal if the address spaces are different (C uses a different base)
# OR if the C-layout addresses wrap around modulo the TMEM size
# C-fragment cosize = 8323200 which is >> 16384 (actual data size)
# The C-fragment uses a STRIDED layout where stride-0 (M) = 65536
# This means the C-fragment is NOT a dense layout in TMEM address space
# The actual TMEM addresses used by the MMA hardware follow a different mapping
# Let's check: for M=128, N=128 C-fragment with ((128,128),1,1):((65536,1),0,0)
# The max address = 65536*127 + 127 = 8,321,023
# But TMEM only has 1MB = 256 columns * 4096 rows
# So addresses 65536*m must be modded or the layout is virtual
# AH. The C-fragment layout is VIRTUAL. cute.compile maps it to physical TMEM
# addresses when the actual MMA operation runs. The strides in the fragment layout
# are logical, not physical. The MMA hardware knows the physical column mapping.
# For the store, make_tmem_copy(St32x32bOp, tStS_P) creates a copy plan.
# The copy writes to tStS_P's layout addresses. These ARE the physical TMEM addresses
# that the MMA hardware will read from for the A-fragment.
# So the question becomes: is the COMPOSITION tStS_P correct?
# tStS.layout = ((128,128),1,1):((65536,1),0,0)
# composition with (128, 64) produces (128,64):(65536,1)
# This takes the first 64 columns of the C-layout.
# tOrP.layout = ((128,16),1,4):((64,1),0,16)
# This is the A-layout for 4 K=16 blocks.
# Let me compute what TMEM addresses each layout generates for the same (m, k) pair.
# For m in [0..127], for k in [0..63]:
# C-layout: logical (m, k) -> addr = 65536*m + k
# A-layout: k = 16*blk + k_inner
# logical ((m, k_inner), 1, blk) -> addr = 64*m + k_inner + 16*blk = 64*m + k
# So C-layout says addr = 65536*m + k
# A-layout says addr = 64*m + k
# These are DIFFERENT for m > 0.
# The C-fragment stride of 65536 is NOT the physical TMEM stride.
# Physical TMEM for M=128 has stride 64 (one 32B row per datapath, 64 FP32 columns)
# Wait, 128 BF16 = 256 bytes = 8 columns of 32B? No...
# TMEM is organized as columns of 128 rows of 32 bits each.
# For FP32 accumulator: each column holds 128 FP32 values (one per DP lane)
# For BF16 view: each FP32 column = 2 BF16 values packed? Or separate columns?
# The 64 stride in A-layout: 128 BF16 values need 64 FP32 columns (2 BF16 per FP32)
# stride-1 in the (128,16) sub-shape: (64,1) means m has stride 64, k_inner has stride 1
# 64 FP32 columns for 128 BF16 rows makes sense: each column holds 2 BF16, so 128/2=64 columns
# But wait, TMEM is 32-bit per entry per DP lane. For FP32, one column = 128 FP32 values.
# For BF16 A-fragment, each column stores 2 BF16 (high+low), so 128 BF16 = 64 columns.
# That matches the stride of 64 for m in the A-layout.
# For the C-fragment with FP32 accumulator, each column = 128 FP32 values.
# The C-layout for (128,128) FP32: stride (65536, 1)
# 65536 = 512 * 128. That's weird.
# Actually 65536 = 2^16. For 128 rows in FP32, each row in a column needs... hmm.
# The C-fragment stores FP32 values. stride-0 = 65536 means elements in the M direction
# are 65536 addresses apart. But we have 128*128 = 16384 elements total.
# cosize = 8323200 >> 16384. The layout is very sparse.
# I think the C-fragment layout addresses are NOT physical TMEM addresses.
# The MMA hardware and the ld/st copy atoms know how to map between logical and physical.
# The copy atom's partition_S/partition_D functions create the right physical mappings.
# The real question for Stage B is simpler:
# After Q@K^T writes to tStS0 (C-fragment), and we ld the scores,
# apply identity softmax, and st them to tStS_P (composed C-fragment)...
# does the PV MMA read the same data from tOrP0 (A-fragment)?
# fmha.py proves this works for F16. The composition pattern is identical.
# So either: (1) something about BF16 vs F16, or (2) something else in my code is wrong.
print('\nDiagnostic complete. Need to check if BF16 vs F16 affects the TMEM layout mapping.')
if __name__ == '__main__':
test()

View File

@@ -0,0 +1,237 @@
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
class PureFP32Copy:
def __init__(self, mma_tiler_mn):
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
self.c_dtype = BFloat16; self.acc_dtype = Float32
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2; self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
def _setup(self, qk_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.mma_tiler = self.qk_mma_tiler
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
self.num_ab_stage = 1; self.num_acc_stage = 1
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
self.s_cols = find_tmem_tensor_col_offset(tStS)
self.tmem_alloc_cols = self.s_cols * 2
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype,
LayoutEnum.from_tensor(a).mma_major_mode(),
LayoutEnum.from_tensor(b).mma_major_mode(),
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
a_smem_s, b_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
k_cnt = cute.size(gA, mode=[3])
qk_thr = qk_mma.get_slice(0)
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
# LD and ST on same layout
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
sfw = tidx % (32 * len(self.epilogue_warp_id))
thr_ld = tiled_ld.get_slice(sfw)
thr_st = tiled_st.get_slice(sfw)
tLdS = thr_ld.partition_S(tStS0)
tStS = thr_st.partition_D(tStS1)
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS_id)
tLdcS = thr_ld.partition_D(tScS)
tStcS = thr_st.partition_S(tScS)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrA, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
if warp_idx < self.mma_warp_id:
tmem.allocate(self.tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
si_handle = mma_si_cons.wait_and_advance()
# FP32 ld → FP32 st, NO recast
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
cute.copy(tiled_ld, tLdS, rLd)
cute.arch.fence_view_async_tmem_load()
# Direct copy: ld register → st register (same shape since same layout)
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
# Since ld and st have the same C-fragment layout and same identity tensor,
# the register shapes should match. Copy element by element.
for i in cutlass.range(cute.size(rLd), vectorize=True):
rSt[i] = rLd[i]
cute.copy(tiled_st, rSt, tStS)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# epi reads S1
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, k = 128, 128, 128
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
ref = q[:,:,0].float() @ kv[:,:,0].float().T
import cutlass.torch as ct
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = PureFP32Copy(mma_tiler_mn=(128, 128))
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print('Pure FP32 ld→st copy roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
if __name__ == '__main__':
test()