Stage B: pipeline deadlock fixed, V MN-major applied, PV output garbage
Pipeline deadlock fixed:
- No cta_layout_vmnk on mma_si PipelineUmmaAsync
- TMA warp excluded from tmem.wait_for_alloc
- PipelineTmaStore (not TmaStorePipeline)

Bug 1 (V MN-major): fix applied
- PV MMA uses v_major=OperandMajorMode.MN
- V shaped (64,128) strides(1,64) via as_strided

Bug 2 (softmax packing): C-fragment composition store applied
- FP32 to BF16 packing works
- St32x32bOp uses Float32 (not BFloat16)

Bug 3 (PV garbage): investigating
- PV MMA cosine ~0.01 against reference
- Suspected TMEM layout mismatch between softmax P store and PV A-fragment read

Test results:
- test_mma_si_only: cosine 0.999999 PASS
- test_mma_si_pv: cosine 0.01 FAIL (pipeline works, PV output wrong)
300
README.md
@@ -2,132 +2,105 @@
|
||||
|
||||
CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use `cutlass.cute` (CuTeDSL) with Blackwell tensor cores.
|
||||
|
||||
## File Map
|
||||
## Status (May 21, 2026 — 04:10 UTC)
|
||||
|
||||
```
cutedsl/
├── native_swa_decode.py # SWA decode attention — IN PROGRESS (v3 tcgen05 rewrite)
├── native_sparse_decode.py # Sparse (CSA/HCA) decode — NOT YET REWRITTEN
├── nvfp4_cutedsl.py # NVFP4 MoE runner (CuTeDSL) — WORKING
├── moe_pipeline.py # MoE fused SwiGLU pipeline — WORKING
├── blackwell_attention.py # vLLM bridge for Blackwell attention path
├── csa_attention.py # CSA/HCA sparse attention bridge
├── custom_ops.py # Custom CUDA ops registration
└── kernel/
    └── blockscaled_gemm/
        └── dense_blockscaled_gemm_persistent.py # REFERENCE: Blackwell TMEM/tcgen05 GEMM

tests/
├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM
├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + C-fragment softmax (runs, wrong output)
├── test_stage_b_afrag2.py # 🔨 Stage B: A-fragment store pattern (compiles, wrong output)
├── test_tmem_pure_fp32.py # ✅ FP32 ld→st roundtrip on C-fragment: cosine 0.999999
├── test_bf16_elemwise.py # ✅ FP32→BF16→FP32 elemwise + FP32 st: cosine 0.999999
├── test_recast_minimal.py # ✅ BF16 recast ld S0→st S1 via C-fragment: cosine 0.999999
├── test_bf16_recast_simple.py # ❌ BF16 recast ld/st same region (S0): zero (can't overwrite MMA output)
├── test_tmem_copy_roundtrip.py # ❌ BF16 recast + C→A mismatch: zero
├── test_stage_b_final.py # ❌ C-fragment st + A-fragment read: NaN (physical layout mismatch)
├── test_afrag_roundtrip.py # ❌ A-frag st corrupts S0 (overlapping TMEM region)
├── diag_tmem.py # Diagnostic: TMEM layout inspection
└── ...
```
|
||||
|
||||
## Current Status
|
||||
|
||||
### ✅ Stage A: Bare Q@K^T via tcgen05.mma — COMPLETE (May 20)
|
||||
### ✅ Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM — COMPLETE
|
||||
|
||||
**File**: `tests/test_stage_a_v2.py`
|
||||
**Result**: Q(128,128) @ K^T(128,128) → S(128,128), cosine 0.999999
|
||||
|
||||
Validates the full tcgen05.mma → TMEM → epilogue → GMEM path:
- tcgen05.mma with BF16 inputs, FP32 TMEM accumulator
- TMA load for A and B (cute.nvgpu.make_tiled_tma_atom_A/B)
- TMA store for C (cpasync.CopyBulkTensorTileS2GOp)
- Warp specialization: 4 epilogue warps + 1 MMA warp + 1 TMA warp = 192 threads
- PipelineTmaUmma for AB pipeline, PipelineUmmaAsync for acc pipeline
- TmemAllocator for TMEM allocation/deallocation
- utils.gemm.sm100.epilogue_tma_store for the TMEM→reg→SMEM→TMA→GMEM epilogue
|
||||
### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS
|
||||
|
||||
### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20-21)
|
||||
**Pipeline deadlock: FIXED. Kernel runs without deadlock.**
|
||||
**Bug 1 (V MN-major): Fix applied.**
|
||||
**Bug 2 (softmax packing): Fix applied, but PV output is garbage.**
|
||||
|
||||
**Core Problem**: The C-fragment (MMA accumulator) and A-fragment (MMA A-operand from TMEM) use **different physical TMEM address mappings** for the same logical (M,K) position. The softmax writes P via one mapping, but the PV MMA reads via the other. This produces garbage.
|
||||
#### Bug 1: V B-Operand Must Be MN-Major — ✅ FIX APPLIED
|
||||
|
||||
#### What's Been Proven
|
||||
V must be shaped (head_dim, seq) = (64, 128) with strides (1, 64) — MN-major.
|
||||
PV MMA uses `v_major` (OperandMajorMode.MN) instead of `b_major` (K).
|
||||
|
||||
| Test | Pattern | Result | Why |
|------|---------|--------|-----|
| test_tmem_pure_fp32 | FP32 ld→st, same C-fragment layout | ✅ cos=0.999999 | C-fragment addresses self-consistent |
| test_bf16_elemwise | FP32→BF16→FP32 elemwise, C-fragment st | ✅ cos=0.999999 | BF16 conversion works, C-fragment st works |
| test_recast_minimal | BF16 recast ld S0→st S1, C-fragment | ✅ cos=0.999999 | Recast works when writing to different region |
| test_bf16_recast_simple | BF16 recast ld/st same region S0 | ❌ zero | Can't overwrite MMA output in same region |
| test_stage_b_final | C-fragment st → A-fragment read (S1) | ❌ NaN | C-layout ≠ A-layout physical addresses |
| test_stage_b_afrag2 | A-fragment st (backward FMHA pattern) | ❌ cos=-0.02 | Store + PV MMA layout compatible, but register data flow wrong |
|
||||
V must use `as_strided` — default PyTorch (64,128) gives strides (128,1) which is K-major.
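The stride arithmetic behind this can be checked without a GPU. The sketch below is plain Python rather than PyTorch; `contiguous_strides`, `transposed_view`, and `offset` are hypothetical helpers, and it assumes V lives in memory as a row-major (seq, head_dim) = (128, 64) buffer that is then viewed as (64, 128).

```python
def contiguous_strides(shape):
    """Row-major (C-order) element strides for a dense tensor of `shape`."""
    strides = []
    step = 1
    for extent in reversed(shape):
        strides.append(step)
        step *= extent
    return tuple(reversed(strides))


def transposed_view(shape, strides):
    """Swap the two modes of a 2-D (shape, strides) pair without moving data."""
    return (shape[1], shape[0]), (strides[1], strides[0])


def offset(coord, strides):
    """Linear element offset of `coord` under `strides`."""
    return sum(c * s for c, s in zip(coord, strides))


head_dim, seq = 64, 128

# A freshly allocated (head_dim, seq) tensor: the seq (K) mode is contiguous.
k_major = contiguous_strides((head_dim, seq))

# V stored as (seq, head_dim) row-major, then viewed as (head_dim, seq):
# the head_dim mode is contiguous, which is the MN-major case described above.
v_shape, mn_major = transposed_view((seq, head_dim), contiguous_strides((seq, head_dim)))

print(k_major)            # (128, 1)
print(v_shape, mn_major)  # (64, 128) (1, 64)
```

Both views have the same logical shape; only the strides differ, which is why the major mode has to be derived from the strides and not from the shape.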
|
||||
|
||||
#### Root Cause: C-fragment vs A-fragment Physical TMEM Layout
|
||||
#### Bug 2 (Packing): C-Fragment Composition Store — ✅ APPLIED, ❌ PV OUTPUT WRONG
|
||||
|
||||
From the CUTLASS source (`mma_traits_sm100.hpp`):
|
||||
FP32→BF16 packing via C-fragment composition store (FMHA pattern) runs without error.
|
||||
The softmax packing overwrites part of S in TMEM (P at tmem_p0_offset=32 overlaps S at offset 0).
|
||||
This is intentional — S is no longer needed after softmax.
|
||||
|
||||
**C-fragment (MMA accumulator, FP32):**
|
||||
- Layout: `((128,128),1,1):((65536,1),0,0)` — **virtual** layout
|
||||
- Physical TMEM addresses determined by the MMA hardware's accumulator write path
|
||||
- St32x32bOp with C-fragment layout writes to C-fragment physical addresses
|
||||
⛔ **FOOTGUN**: `St32x32bOp` MUST use Float32, NOT BFloat16.
|
||||
⚠️ The recast view for P packing uses the LOAD layout (128 BF16 elements), not the store composition shape.
|
||||
|
||||
**A-fragment (MMA A-operand from TMEM, BF16, K-major, M=128):**
|
||||
- Layout: `((128,16),1,4):((65536,1),0,16)` — **physical** TMEM layout
|
||||
- A[m, k_inner] → `tmem[dp=m, col=base + 16*mma_k + k_inner]`
|
||||
- BK=64 = 4 K=16 MMA atoms, NOT one K=64 atom
|
||||
- The 4D fragment partition order is NOT the physical TMEM order
|
||||
#### Bug 3 (NEW): PV MMA Output Is Garbage — 🔨 INVESTIGATING
|
||||
|
||||
**The St32x32bOp with C-fragment composition writes to C-layout physical addresses. The PV MMA reads from A-layout physical addresses. These are different physical locations.**
|
||||
|
||||
#### Forward FMHA's Approach (FP16 Only!)
|
||||
|
||||
Forward FMHA uses a recast pattern to pack 2×FP16 into 1×FP32 register, then St32x32bOp writes to a C-fragment composition subview. **But forward FMHA explicitly rejects BF16:**
|
||||
```python
if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}:
    raise ValueError("in_dtype must be Float8E4M3FN or Float16")
```
|
||||
The recast softmax path is validated for FP16, NOT BF16. Our BF16 use is outside the tested path.
|
||||
|
||||
#### Backward FMHA's Approach (BF16 Supported)
|
||||
|
||||
Backward FMHA writes dV to TMEM using the A-fragment layout:
|
||||
1. `tdVrP_iter = cute.recast_ptr(tSTtST.iterator, dtype=self.element_dtype)` — recast C-fragment iterator to BF16
2. `tdVrP = cute.make_tensor(tdVrP_iter, tOrP.layout)` — A-fragment layout, C-fragment base
3. `tmem_store_atom = cute.make_copy_atom(St32x32bOp(Repetition(8)), self.element_dtype)` — BF16 store atom
4. Quantize via `make_rmem_tensor(input.shape, element_dtype)` + `.load()/.store(v.to(element_dtype))` — true BF16 register, NOT recast
5. Reshape: `cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape))` — match store partition shape
|
||||
|
||||
This compiles and runs for us (no crash), but the output is still wrong (cosine -0.02). The remaining issue is the **register layout mismatch**:
|
||||
- Load partition (C-fragment): 128 FP32 values per thread (full 128×128 QK tile)
|
||||
- Store partition (A-fragment): 64 BF16 values per thread (128×64 P tile for PV MMA K=64)
|
||||
- The backward FMHA uses `quantize()` + reshape, but our element counts differ because the QK tile is 128×128 while P only needs 128×64
|
||||
|
||||
#### Next Steps for Stage B
|
||||
|
||||
1. **Fix the register data flow** — properly subselect the P-relevant 64 BF16 columns from the 128 FP32 load columns, or use the backward FMHA's PdO MMA tiler (M=128, N=64) instead of (M=128, N=128)
2. **Verify A-fragment store roundtrip** — write known BF16 values via A-fragment store, have PV MMA read them back via A-fragment, confirm the physical TMEM addresses match
3. **Once data flow is correct, add online softmax** (Stage C)
|
||||
The PV MMA produces cosine ~0.01 against the reference. Suspected cause: TMEM layout mismatch between the softmax P store (C-fragment composition layout) and the PV MMA A-fragment read (`p_tmem_s` layout from `make_smem_layout_a`). These should alias the same physical TMEM columns by the sequential-flattening property, but the specific layout functions may compute different shapes/strides.
|
||||
|
||||
### 🔨 Stage C: Online Softmax — AFTER B
|
||||
|
||||
The hard part. Per the pseudocode:
|
||||
- Epilogue warps tcgen05.ld scores from TMEM into register fragments
- Compute per-row: tile_max, new_max, rescale = exp(old_max - new_max)
- Apply rescale to tmem_output in place (tmem_output *= rescale)
- Compute exp(scores - new_max), tcgen05.st back to TMEM as P operand for MMA2
- Update row_sum = row_sum * rescale + new_tile_sum
|
||||
|
||||
**The register fragment layout from tcgen05.ld is NOT (row, col).** It's determined by the MMA instruction's partition of the accumulator. Need to figure out the mapping from fragment indices to logical (head, kv_pos) positions for per-row softmax operations. fmha.py uses `tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)` for the row max — a built-in reduction that handles the layout.
|
||||
Per the pseudocode: epilogue warps compute per-row tile_max, rescale, exp, store P back to TMEM.
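The per-row recurrence can be sanity-checked on the host before it is mapped onto register fragments. The following is a minimal pure-Python sketch, not kernel code: `online_softmax_row` and `reference_softmax_row` are hypothetical names, each KV position carries a single scalar value in place of a head_dim vector, and `output` stands in for the tmem_output accumulator.

```python
import math


def online_softmax_row(score_tiles, value_tiles):
    """Process one query row tile by tile, keeping a running max and sum.

    score_tiles: list of lists of floats (scores for this row, one list per KV tile)
    value_tiles: matching list of lists of floats (one scalar per KV position)
    Returns the softmax-weighted average of the values.
    """
    row_max = -math.inf
    row_sum = 0.0
    output = 0.0  # stands in for the tmem_output accumulator
    for scores, values in zip(score_tiles, value_tiles):
        tile_max = max(scores)
        new_max = max(row_max, tile_max)
        rescale = math.exp(row_max - new_max)  # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - new_max) for s in scores]
        output = output * rescale + sum(pi * v for pi, v in zip(p, values))
        row_sum = row_sum * rescale + sum(p)
        row_max = new_max
    return output / row_sum


def reference_softmax_row(scores, values):
    """Single-pass softmax over the whole row, used only as a check."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    return sum(wi * v for wi, v in zip(w, values)) / sum(w)


score_tiles = [[0.5, 2.0, -1.0], [3.0, 0.1, 1.5]]
value_tiles = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
tiled = online_softmax_row(score_tiles, value_tiles)
full = reference_softmax_row(sum(score_tiles, []), sum(value_tiles, []))
print(abs(tiled - full) < 1e-12)
```

The division by row_sum happens once at the end, matching the final epilogue step in the per-tile flow.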
|
||||
|
||||
### 🔨 Stage D: FP8 Paged KV Gather — AFTER C
|
||||
|
||||
Replace BF16 TMA load of KV with:
|
||||
- Indexed cp.async gather from paged KV cache (fp8)
|
||||
- Per-position dequant scale (inv_scale) applied during or after gather
|
||||
- Keep KV in fp8 in SMEM, let the MMA's per-row scale handle dequant (like blockscaled GEMM)
|
||||
Replace BF16 TMA load with FP8 paged KV gather + per-position dequant.
|
||||
|
||||
### Architecture: Per-Tile Flow (from /root/fragile-kernel-example/README.md)
|
||||
---
|
||||
|
||||
## Pipeline Deadlock — ✅ FIXED (May 21)
|
||||
|
||||
v20-v25 all deadlocked on GPU. Three root causes found and fixed:
|
||||
|
||||
### Fix 1: PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk
|
||||
|
||||
FMHA's mma_s0/mma_s1 PipelineUmmaAsync calls do NOT pass cta_layout_vmnk. Removing it fixes the deadlock.
|
||||
|
||||
### Fix 2: TMA Warp Must NOT Call tmem.wait_for_alloc()
|
||||
|
||||
The tmem allocation barrier has `num_threads = 32 * (mma_warp + epilogue_warps)`. The TMA warp is NOT part of this barrier. Calling `wait_for_alloc()` from the TMA warp corrupts the barrier.
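As a quick arithmetic check, the sketch below plugs in the warp roles used by `tests/test_mma_si_only.py` (epilogue warps 0 to 3, MMA warp 4, TMA warp 5). It is plain Python with illustrative variable names, not CuTeDSL code.

```python
WARP_SIZE = 32

epilogue_warp_ids = (0, 1, 2, 3)
mma_warp_id = 4
tma_warp_id = 5

# Warps that arrive at the TMEM allocation barrier (MMA + epilogue only).
tmem_barrier_warps = (mma_warp_id, *epilogue_warp_ids)
tmem_barrier_threads = WARP_SIZE * len(tmem_barrier_warps)

# All warps in the CTA, including the TMA warp.
threads_per_cta = WARP_SIZE * (len(epilogue_warp_ids) + 2)

print(tmem_barrier_threads)               # 160
print(threads_per_cta)                    # 192
print(tma_warp_id in tmem_barrier_warps)  # False
```

The barrier is sized for 160 arrivals, so the 32 threads of the TMA warp are not accounted for in its expected count.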
|
||||
|
||||
### Fix 3: PipelineTmaStore (not TmaStorePipeline)
|
||||
|
||||
`pipeline.TmaStorePipeline` does not exist. The correct name is `pipeline.PipelineTmaStore`.
|
||||
|
||||
---
|
||||
|
||||
## ⛔ DEAD TEST: test_stage_b_v21.py — DELETED, DO NOT RECREATE
|
||||
|
||||
v21 attempted both Bug 1 and Bug 2 fixes in a hand-rolled pipeline kernel. It deadlocks on GPU. Root cause: pipeline synchronization mismatch. **Do not recreate.** Write from scratch using fmha.py as the reference.
|
||||
|
||||
---
|
||||
|
||||
## ⛔ FOOTGUNS — CUTLASS CuTeDSL Landmines
|
||||
|
||||
### 1. St32x32bOp with 16-bit dtype → ILLEGAL MEMORY ACCESS
|
||||
|
||||
`St32x32bOp(Repetition(N), BFloat16)` crashes at runtime. You MUST use `St32x32bOp(Repetition(N), Float32)` and pack 2×16-bit values into 1×Float32 backing words via `cute.recast_ptr`. The 16-bit type only appears in the recast view, never in the store atom itself.
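To make the "2×16-bit in 1×Float32 word" idea concrete, here is a host-side sketch of the bit manipulation in plain Python. It does not use CuTeDSL; the helper names are hypothetical, the rounding mode (round-to-nearest-even) and the element order inside the word (first value in the low 16 bits) are assumptions, and NaN inputs are not handled.

```python
import struct


def f32_to_bf16_bits(x):
    """Round an FP32 value to BF16 (round-to-nearest-even) and return its 16 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding) >> 16) & 0xFFFF


def bf16_bits_to_f32(h):
    """Expand 16 BF16 bits back to an FP32 value."""
    return struct.unpack("<f", struct.pack("<I", (h & 0xFFFF) << 16))[0]


def pack_bf16_pair(lo, hi):
    """Pack two values as BF16 into one 32-bit word (lo in bits 0..15: an assumption)."""
    return f32_to_bf16_bits(lo) | (f32_to_bf16_bits(hi) << 16)


def unpack_bf16_pair(word):
    """Recover the two BF16 values stored in a 32-bit word."""
    return bf16_bits_to_f32(word & 0xFFFF), bf16_bits_to_f32(word >> 16)


word = pack_bf16_pair(1.0, -2.0)
print(hex(word))               # 0xc0003f80
print(unpack_bf16_pair(word))  # (1.0, -2.0)
```

The store atom only ever sees 32-bit words like `word`; the 16-bit interpretation exists purely in how those words were filled.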
|
||||
|
||||
### 2. V B-Operand Major Mode ≠ K Major Mode
|
||||
|
||||
FMHA requires `v_major_mode == OperandMajorMode.MN`. Passing K's K-major mode for V is WRONG. V must be shaped (head_dim, seq) with strides (1, head_dim) to produce MN-major. Standard PyTorch row-major (seq, head_dim) gives K-major.
|
||||
|
||||
### 3. CuTe Nested Layout Modes Flatten Sequentially
|
||||
|
||||
A layout like `((128,16),1,(4,2)):((65536,1),0,(16,64))` looks "non-sequential" but flattens to `addr = m*65536 + k` with `k = k0 + 16*k1 + 64*k2`, where k0, k1, k2 index the modes of extent 16, 4, and 2 (leftmost mode fastest, which is CuTe's colexicographic order, not row-major). Do NOT assume nested modes imply non-sequential physical addressing. The C-fragment composition and A-fragment alias the same TMEM columns.
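The flattening claim is easy to verify by brute force. This sketch is plain Python (no `cutlass` import); `flatten` and `address` are hypothetical helpers that apply the usual layout rule, address = inner product of coordinate and stride. It only checks the layout arithmetic and says nothing about how the hardware maps those addresses.

```python
from itertools import product

shape = ((128, 16), 1, (4, 2))
stride = ((65536, 1), 0, (16, 64))


def flatten(t):
    """Flatten a nested tuple of ints into a flat list, left to right."""
    out = []
    for item in t:
        if isinstance(item, tuple):
            out.extend(flatten(item))
        else:
            out.append(item)
    return out


flat_shape = flatten(shape)    # [128, 16, 1, 4, 2]
flat_stride = flatten(stride)  # [65536, 1, 0, 16, 64]


def address(coord):
    """Inner product of a flat coordinate with the flat strides."""
    return sum(c * s for c, s in zip(coord, flat_stride))


# Enumerate every K coordinate for row m = 0 and collect the column offsets.
columns_for_row0 = sorted(
    address((0, k0, 0, k1, k2))
    for k0, k1, k2 in product(range(16), range(4), range(2))
)
print(columns_for_row0 == list(range(128)))  # True
```

Every column 0..127 is hit exactly once, so the nested K modes cover a contiguous range.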
|
||||
|
||||
### 4. PipelineUmmaAsync Consumer Group = Thread Count, NOT Warp Count
|
||||
|
||||
```python
# WRONG: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)
# CORRECT: consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(warp_ids))
```
|
||||
|
||||
### 5. PipelineUmmaAsync for mma_si Must NOT Pass cta_layout_vmnk
|
||||
|
||||
Passing `cta_layout_vmnk` to the mma_si PipelineUmmaAsync causes deadlock. FMHA does not pass it. Remove it.
|
||||
|
||||
### 6. TMA Warp Must NOT Call tmem.wait_for_alloc()
|
||||
|
||||
The tmem allocation barrier only includes MMA + epilogue warps. The TMA warp is excluded. Calling `wait_for_alloc()` from the TMA warp corrupts the barrier.
|
||||
|
||||
---
|
||||
|
||||
## Architecture: Per-Tile Flow
|
||||
|
||||
```
|
||||
For each KV tile:
|
||||
@@ -138,107 +111,57 @@ For each KV tile:
|
||||
a. tcgen05.ld scores from TMEM → register fragments
|
||||
b. Compute tile_max, new_max, rescale = exp(old_max - new_max)
|
||||
c. Apply rescale to tmem_output IN PLACE (tmem_output *= rescale)
|
||||
d. tcgen05.st exp(scores - new_max) back to TMEM → now it's the P operand
|
||||
d. tcgen05.st exp(scores - new_max) back to TMEM → P operand (via C-fragment composition)
|
||||
e. Release mma_si (softmax_done — MMA warp can re-acquire and issue PV MMA)
|
||||
4. MMA warp waits on mma_si acquire (softmax done), then MMA2: P @ sKV[stage] → tmem_output (accumulate=True)
|
||||
4. MMA warp waits on mma_si acquire (softmax done), MMA2: P @ sV → tmem_output (accumulate=True)
|
||||
5. Stage released, load warp can refill it
|
||||
|
||||
After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast to BF16, store to GMEM
|
||||
```
|
||||
|
||||
### ✅ NVFP4 MoE (CuTeDSL) — WORKING
|
||||
- `nvfp4_cutedsl.py` + `moe_pipeline.py`
|
||||
- CuTeDSL NVFP4 Linear (q_a, kv, q_b, o_b) — cosine 0.994+
|
||||
- CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) — cosine 0.988
|
||||
- Fused SwiGLU epilogue (granularity-8 weight interleave) — cosine 0.988
|
||||
---
|
||||
|
||||
### ✅ FP8 KV Quantize/Dequant — WORKING
|
||||
- FP8 KV: cosine 0.9997
|
||||
- NVFP4 KV: cosine 0.9943 (2x smaller than FP8)
|
||||
- Paged KV cache read/write: cosine 1.0
|
||||
## Test Results
|
||||
|
||||
### ❌ Sparse Decode Attention — NOT YET REWRITTEN
|
||||
`native_sparse_decode.py` still has the scalar FMA bug. Needs the same tcgen05.mma rewrite.
|
||||
| File | Description | Cosine | Status |
|------|-------------|--------|--------|
| `test_stage_a_v2.py` | Q@K^T only | 0.999999 | ✅ PASS |
| `test_mma_si_only.py` | Q@K^T + mma_si pipeline (no PV) | 0.999999 | ✅ PASS |
| `test_softmax_only.py` | Q@K^T + softmax packing, output S | 0.52 | ❌ S overwritten by P (expected) |
| `test_mma_si_pv.py` | Q@K^T + softmax + P@V (V MN-major) | 0.01 | ❌ PV output garbage |
| `test_stage_b_v7.py` | Q@K^T + C-fragment softmax (V=K, wrong major) | -0.02 | ❌ wrong major + P packing |
| `test_stage_b_v20.py` | Q@K^T + softmax (V=K, PipelineTmaStore bug) | N/A | ❌ compile error |
|
||||
|
||||
### ✅ Full Attention Pipeline (standalone tests) — WORKING
|
||||
- FP8 KV → full attention: cosine 0.9997
|
||||
- CSA sparse attention (cr=4): works
|
||||
- HCA sparse attention (cr=128): works
|
||||
- Merged CSA+SWA attention: works
|
||||
---
|
||||
|
||||
## Critical APIs & Lessons
|
||||
|
||||
### C-fragment ≠ A-fragment TMEM Physical Layout — THE MAY 20-21 FINDING
|
||||
|
||||
**The St32x32bOp with C-fragment composition writes to C-layout physical TMEM addresses. The PV MMA reads from A-layout physical TMEM addresses. These are DIFFERENT physical locations for the same logical (M,K) position.**
|
||||
|
||||
For the softmax to work, P must be written to TMEM using the A-fragment's physical layout, not the C-fragment's. The backward FMHA does this correctly by:
|
||||
1. Creating the store destination with A-fragment layout + recast C-fragment iterator
|
||||
2. Using a BF16 St32x32bOp atom
|
||||
3. True BF16 register (not FP32 recast) via quantize() pattern
|
||||
|
||||
### Forward FMHA Recast Pattern — FP16 ONLY
|
||||
|
||||
The `cute.recast_ptr` + `.store(v.to(FP16))` pattern for packing 2×16-bit into 1×FP32 register is validated for FP16 only. BF16 is rejected in forward FMHA. The BF16 recast produces zero output when writing to the same TMEM region as the MMA output, and NaN when writing to a different region read via A-fragment.
|
||||
|
||||
### PipelineUmmaAsync consumer group size — thread count, NOT warp count
|
||||
|
||||
```python
# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)  # warp count

# CORRECT (matches fmha.py):
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(softmax_warp_ids))  # thread count
```
|
||||
|
||||
### TMEM offset arithmetic
|
||||
- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count
|
||||
- QK accumulator: 128 TMEM columns
|
||||
- A-fragment offset: `acc_dtype.width // q_dtype.width * tmem_p0_offset` (F32/BF16=2)
|
||||
|
||||
- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count (with 0x8000 tag for A-fragments)
|
||||
- QK accumulator C fragment: 128 TMEM columns
|
||||
- PV A-fragment: offset 0x8020 = tag(0x8000) + col(32) — the 0x8000 is a TMEM memory-space identifier
|
||||
- `tOrP0 = cute.make_tensor(tOrP.iterator + acc_dtype.width // q_dtype.width * tmem_p0_offset, tOrP.layout)` — A-fragment offset scaled by dtype width ratio (F32/BF16 = 2)
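The numbers in these bullets can be reproduced with a few lines of arithmetic. The sketch below is plain Python; treating 0x8000 as a bit that can be masked off is an assumption based on the "tag + col" decomposition quoted above, not on a documented API.

```python
ACC_BITS = 32        # Float32 accumulator width
Q_BITS = 16          # BFloat16 operand width
tmem_p0_offset = 32  # P starts at 32-bit TMEM column 32 (value from this README)

# Offset expressed in units of the 16-bit operand type.
a_fragment_offset = ACC_BITS // Q_BITS * tmem_p0_offset

TMEM_TAG = 0x8000    # assumed memory-space tag bit
reported = 0x8020    # value reported for the PV A-fragment
column = reported & ~TMEM_TAG

print(a_fragment_offset)  # 64
print(column)             # 32
```

The factor of 2 appears because the A-fragment iterator counts 16-bit elements while `tmem_p0_offset` counts 32-bit columns.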
|
||||
|
||||
### A-fragment iterator must use recast C-fragment pointer
|
||||
|
||||
When creating the P tensor for PV MMA's A-operand, the iterator must be the C-fragment's iterator recast to BF16:
|
||||
### pv_mma_tiler — FMHA Convention
|
||||
```python
|
||||
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
```
|
||||
Without the recast, the A-fragment addresses are computed from an FP32 pointer base, giving wrong physical TMEM addresses (illegal memory access crash).
|
||||
|
||||
### V SMEM aliasing (K and V share SMEM)
|
||||
|
||||
```python
|
||||
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
pv_mma_tiler = (qk_mma_tiler[0], qk_mma_tiler[2], qk_mma_tiler[1])
|
||||
# = (M, head_dim, QK_N) = (128, 64, 128) for head_dim=64
|
||||
```
|
||||
|
||||
### `make_trivial_tiled_mma` has two overloads
|
||||
|
||||
### make_trivial_tiled_mma — Use New Overload
|
||||
```python
# New (preferred):
make_trivial_tiled_mma(a_dtype, b_dtype, a_leading_mode, b_leading_mode,
                       acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)

# Deprecated (still works, used by Stage A):
make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode,
                       acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)
```
|
||||
|
||||
### Other APIs discovered from Stage A
|
||||
### 3D tensors required
|
||||
Tensors must be 3D (M, K, L) for `cute.local_tile` — add L=1 dimension.
|
||||
|
||||
1. **`cute.Tensor` API** — `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)`
2. **3D tensors** — Tensors must be 3D (M, K, L) for `cute.local_tile` — add L=1 dimension
3. **`PipelineTmaUmma.create(...).make_participants()`** — returns `(producer, consumer)` pair
4. **`utils.gemm.sm100.epilogue_tma_store`** — handles transform + partition/dcopy. DO NOT hand-roll.
5. **`get_num_tmem_alloc_cols`** — correct TMEM allocation (accepts list of fragments, sums cols, rounds to power of 2)
6. **`smem.allocate_tensor()`** — for SMEM tensors (not SharedStorage struct for A/B/C)
7. **`LayoutEnum.from_tensor(a).mma_major_mode()`** — major mode from cute tensor
8. **Minimum valid N tile for tcgen05.mma BF16**: 32 (step 32, range 32-256)
|
||||
### Other APIs
|
||||
1. `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)` — CuTe tensor from PyTorch
2. `PipelineTmaUmma.create(...).make_participants()` — returns (producer, consumer) pair
3. `utils.gemm.sm100.epilogue_tma_store` — handles transform + partition/dcopy. DO NOT hand-roll.
4. `smem.allocate_tensor()` — for SMEM tensors
5. `LayoutEnum.from_tensor(a).mma_major_mode()` — major mode from cute tensor
|
||||
|
||||
## Environment
|
||||
|
||||
@@ -247,15 +170,6 @@ make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode,
|
||||
- **PYTHONPATH**: `/root/dsv4-nvfp4-workspace/kernel`
|
||||
- **Model**: `/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4`
|
||||
- **vLLM repo**: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
|
||||
- **Pseudocode**: `/root/fragile-kernel-example/README.md` — authoritative per-tile attention flow
|
||||
- **Pseudocode**: `/root/fragile-kernel-example/README.md`
|
||||
- **fmha.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
|
||||
- **fmha_bwd.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py`
|
||||
|
||||
## 4-Stage Build Plan
|
||||
|
||||
| Stage | Goal | Status |
|-------|------|--------|
| A | Bare Q@K^T via tcgen05.mma → TMEM → GMEM | ✅ COMPLETE |
| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 A-fragment store compiles, register data flow needs fixing |
| C | Online softmax between MMA1 and MMA2 (the hard part) | ⬜ TODO |
| D | FP8 paged KV gather + dequant (replace BF16 TMA load) | ⬜ TODO |
|
||||
|
||||
247
tests/test_mma_si_only.py
Normal file
@@ -0,0 +1,247 @@
"""
Minimal test: Stage A + mma_si pipeline (no PV, no V).
If this deadlocks, the mma_si pipeline is broken.
If this passes, the deadlock is caused by adding V/PV.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class MmaSiTest:
    def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
        self.acc_dtype = Float32; self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
        self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
        self.num_c_stage = 2
|
||||
|
||||
    def _setup(self, tiled_mma):
        mma_inst_k = cute.size(tiled_mma.shape_mnk, mode=[2])
        self.mma_tiler = (*self.mma_tiler_mn, mma_inst_k * 4)
        self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (tiled_mma.thr_id.shape,))
        self.cta_tile_shape_mnk = (
            self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
            self.mma_tiler[1], self.mma_tiler[2])
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
        self.num_ab_stage = 1; self.num_acc_stage = 1

        self.a_smem_s = utils.sm100.make_smem_layout_a(tiled_mma, self.mma_tiler, self.q_dtype, 1)
        self.b_smem_s = utils.sm100.make_smem_layout_b(tiled_mma, self.mma_tiler, self.q_dtype, 1)
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)

        acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
        tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
        self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")

        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        self.num_tma_load_bytes = (
            cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
        ) * cute.size(tiled_mma.thr_id.shape)
|
||||
|
||||
    @cute.jit
    def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
        self.q_dtype = a.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
        self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
        self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
        self.c_layout = LayoutEnum.from_tensor(c)

        tiled_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, self.a_major, self.b_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
        self._setup(tiled_mma)

        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
            a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
        tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
            b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
        epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)

        self._kernel(tiled_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
            self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
        ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, tiled_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(tiled_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2] # ADDED: mma_si
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
# ADDED: mma_si pipeline (same as v27)
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
thr_mma = tiled_mma.get_slice(0)
|
||||
tCgA = thr_mma.partition_A(gA); tCgB = thr_mma.partition_B(gB); tCgC = thr_mma.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = tiled_mma.make_fragment_A(sA); tCrB = tiled_mma.make_fragment_B(sB)
|
||||
acc_shape = thr_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# TMA WARP
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# MMA WARP — same as Stage A but with mma_si pipeline added (just acquire/commit, no PV)
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
tCtAcc = tCtAcc_base[(None, None, None, 0)]
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
|
||||
# ADDED: mma_si acquire (just like v27)
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(tiled_mma, tCtAcc, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tCtAcc)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
            # ADDED: mma_si commit + second acquire (like v27).
            # commit() publishes S to the epilogue warps; the second acquire then blocks
            # until they call release(). In the real kernel, softmax runs between the
            # epilogue's wait_and_advance() and release(); this test releases immediately.
            # Once the second acquire returns, execution continues to the acc_pipe commit.
            s0_handle.commit()
            s0_handle = mma_si_prod.acquire_and_advance()  # wait for "softmax"
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# EPILOGUE WARPS — same as Stage A but with mma_si wait+release added
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
|
||||
# ADDED: mma_si wait + release (simulating softmax)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
# (no actual softmax — just release immediately)
|
||||
si_handle.release()
|
||||
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtAcc_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, k = 128, 128, 64
    a = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    b = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    ref = a[:,:,0].float() @ b[:,:,0].float().T

    import cutlass.torch as ct
    mA = ct.from_dlpack(a).mark_layout_dynamic(leading_dim=ct.get_leading_dim(a))
    mB = ct.from_dlpack(b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(b))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = MmaSiTest(mma_tiler_mn=(128, 128))
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mA, mB, mC, stream)
    print('Running...', flush=True)
    compiled(mA, mB, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    print('MMA+mma_si only test: cosine {:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))

if __name__ == '__main__':
|
||||
test()
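
The mma_si handshake exercised by this test (MMA warp: acquire, QK MMA, commit, second acquire; epilogue warps: wait, release) can be modelled on the host with two semaphores. This is only a minimal sketch of the ordering: `OneStagePipeline` and `run_handshake` are hypothetical names, Python threads stand in for warps, and nothing here calls the CuTeDSL pipeline API.

```python
import threading


class OneStagePipeline:
    """Toy single-stage producer/consumer handshake (hypothetical, host-side only)."""

    def __init__(self):
        self._empty = threading.Semaphore(1)  # the single stage starts free
        self._full = threading.Semaphore(0)   # nothing has been published yet

    def producer_acquire(self):
        self._empty.acquire()

    def producer_commit(self):
        self._full.release()

    def consumer_wait(self):
        self._full.acquire()

    def consumer_release(self):
        self._empty.release()


def run_handshake():
    pipe = OneStagePipeline()
    log = []

    def mma_warp():
        pipe.producer_acquire()      # first acquire: the stage is free
        log.append("qk_mma")         # stands in for the QK MMA writing S
        pipe.producer_commit()       # publish S to the consumer
        pipe.producer_acquire()      # blocks until the consumer releases
        log.append("pv_mma")         # stands in for the PV MMA reading P

    def softmax_warps():
        pipe.consumer_wait()         # wait for S
        log.append("pack_p")         # stands in for the FP32 -> BF16 packing
        pipe.consumer_release()      # let the producer continue

    threads = [threading.Thread(target=f) for f in (softmax_warps, mma_warp)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


if __name__ == "__main__":
    print(run_handshake())
```

The second `producer_acquire` is what keeps the PV step from starting before the packing step has finished, whichever thread is scheduled first.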
|
||||
345
tests/test_mma_si_pv.py
Normal file
@@ -0,0 +1,345 @@
"""
Stage B test: MMA + mma_si + V TMA + PV MMA.
Built incrementally from working test_mma_si_only.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class MmaSiPvTest:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) +
|
||||
cute.size_in_bytes(self.q_dtype, v_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
|
||||
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
|
||||
tma_c, tma_tc, self.cluster_layout_vmnk,
|
||||
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
||||
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
|
||||
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
|
||||
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gQ, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
|
||||
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ═══ TMA LOAD WARP ═══
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ═══ MMA WARP ═══
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrQ, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
s0_handle = mma_si_prod.acquire_and_advance() # wait for softmax done
|
||||
|
||||
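            # PV MMA: O = P @ V. The O accumulator in TMEM is not initialized, so the
            # first k-block overwrites it (ACCUMULATE=False) and later k-blocks
            # accumulate, mirroring the QK loop above.
            pv_mma.set(tcgen05.Field.ACCUMULATE, False)
            tCrV_s = tCrV[(None, None, None, 0)]
            nblk_pv = cute.size(tOrP0, mode=[2])
            for kb in cutlass.range(nblk_pv, unroll_full=True):
                cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
                pv_mma.set(tcgen05.Field.ACCUMULATE, True)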
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ═══ EPILOGUE WARPS ═══
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# TMEM load/store setup for softmax packing
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 → BF16 packing
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
|
||||
# Output epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
    # MN-major V: shape (head_dim, n) with strides (1, head_dim) over v_base's storage.
    v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    # The reference must use the strided view `v` that the kernel actually reads.
    # v_base shares the same storage but maps (i, j) to a different element.
    qf = q[:,:,0].float(); kf = k[:,:,0].float(); vf = v[:,:,0].float()
    ref = qf @ kf.T @ vf.T

    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = MmaSiPvTest(mma_tiler_mn=(128, 128))
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('MMA+mma_si+PV test (V MN-major): cosine {:.6f}, max_err {:.6f} {}'.format(cos, max_err, 'PASS' if cos >= 0.99 else 'FAIL'))

if __name__ == '__main__':
|
||||
test()
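
The test above builds V as a `(head_dim, n)` view with strides `(1, head_dim)` over a contiguous `(head_dim, n)` buffer. As a rough illustration of what that view reads, the sketch below compares flat storage offsets for the two stride choices in plain Python. The helper names (`flat_offset`, `mismatched_count`, `matches_reshaped_transpose`) are made up for this example; only the shapes and strides come from the test.

```python
HEAD_DIM, N = 64, 128

CONTIGUOUS_STRIDES = (N, 1)        # row-major (head_dim, n), like v_base
MN_MAJOR_STRIDES = (1, HEAD_DIM)   # the as_strided view passed to the kernel


def flat_offset(index, strides):
    """Offset into the flat storage for a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


def mismatched_count():
    """Number of (i, j) positions where the two layouts read different storage."""
    return sum(
        flat_offset((i, j), CONTIGUOUS_STRIDES) != flat_offset((i, j), MN_MAJOR_STRIDES)
        for i in range(HEAD_DIM)
        for j in range(N)
    )


def matches_reshaped_transpose():
    """True if view[i, j] reads the same offset as reshape(n, head_dim)[j, i]."""
    reshaped_strides = (HEAD_DIM, 1)  # contiguous (n, head_dim)
    return all(
        flat_offset((i, j), MN_MAJOR_STRIDES) == flat_offset((j, i), reshaped_strides)
        for i in range(HEAD_DIM)
        for j in range(N)
    )


if __name__ == "__main__":
    print("positions that differ:", mismatched_count(), "of", HEAD_DIM * N)
    print("equals reshape(n, head_dim).T:", matches_reshaped_transpose())
```

Under these shapes the strided view is the transpose of the storage reinterpreted as `(n, head_dim)`. It is not a re-labelled copy of the row-major `(head_dim, n)` tensor, so a reference computation has to index the same view the kernel is given.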
|
||||
303
tests/test_pv_mma_mn_major.py
Normal file
@@ -0,0 +1,303 @@
"""
Isolated test for Bug 1: PV MMA with V MN-major.

Only tests the PV MMA (P@V) with V as the MN-major B operand.
No softmax and no mma_si pipeline; a QK MMA is used only to populate P.
P comes from TMEM (a_source=TMEM), V comes from SMEM (B operand from TMA load).

Architecture:
- TMA load V into SMEM
- P pre-populated in TMEM (via small QK MMA or direct write)
- PV MMA: P @ V → O in TMEM
- Epilogue: TMEM → GMEM

For simplicity, P is computed via a QK MMA first (Q@K^T → P in TMEM),
then the PV MMA uses P from TMEM. No softmax: identity pass-through.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class PvMmaTest:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.mma_warp_id = 0
|
||||
self.tma_warp_id = 1
|
||||
self.threads_per_cta = 64
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[pv_test] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[pv_test] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
|
||||
        # Epilogue tile: currently derived from the QK tiler shape, although the
        # PV output tile is pv_mma_tiler[:2].
|
||||
cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 0 # P = S (identity softmax, same TMEM)
|
||||
self.tmem_o0_offset = s_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = q.element_type; self.b_dtype = k.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
print(f"[pv_test] a_major (Q) = {self.a_major}")
|
||||
print(f"[pv_test] b_major (K) = {self.b_major}")
|
||||
print(f"[pv_test] v_major (V) = {self.v_major}")
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
# BUG 1 FIX: PV MMA uses V's MN-major mode
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
|
||||
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
|
||||
tma_c, tma_tc, self.cluster_layout_vmnk,
|
||||
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
||||
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32),
|
||||
            tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=64)
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=0, is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gK = cute.local_tile(mK, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gQ, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
|
||||
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
|
||||
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P from S TMEM — same location, MMA A-operand for PV
|
||||
tP = cute.make_tensor(tStS.iterator, self.p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = tOrP # P is at same TMEM offset as S (identity softmax)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# WARP 1: TMA load
|
||||
if warp_idx == self.tma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# WARP 0: MMA (both QK and PV)
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
|
||||
# QK MMA: Q @ K^T → S in TMEM
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrQ, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
            # PV MMA: P @ V → O in TMEM (identity softmax: P = S)
            # Nothing above zero-initializes O in TMEM, so the first k-block must
            # overwrite (ACCUMULATE=False); later k-blocks accumulate, as in the QK loop.
            pv_mma.set(tcgen05.Field.ACCUMULATE, False)
            tCrV_s = tCrV[(None, None, None, 0)]
            nblk_pv = cute.size(tOrP0, mode=[2])
            for kb in cutlass.range(nblk_pv, unroll_full=True):
                cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
                pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    # V: MN-major, logical shape (head_dim, seq) with strides (1, head_dim)
    v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
    v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    # as_strided reinterprets v_base's storage: v[i, j] = storage[i + head_dim * j],
    # whereas v_base[i, j] = storage[n * i + j]. The two hold different values, so the
    # reference is built from the tensor the kernel actually receives (v), not v_base.
    qf = q[:,:,0].float(); kf = k[:,:,0].float(); vf = v[:,:,0].float()
    # The PV MMA takes V as an MN-major B operand of logical shape (head_dim, seq)
    # and computes O[m, n] = sum_k P[m, k] * V[n, k], i.e. O = P @ V^T.
    # With identity softmax (P = S = Q @ K^T) the reference is therefore:
    ref = qf @ kf.T @ vf.T  # (128,128) @ (128,64) = (128,64)

    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    kernel = PvMmaTest(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('PV MMA test (V MN-major, no softmax):')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))

if __name__ == '__main__':
    test()
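
A note on the V layout used in the test above: `as_strided((head_dim, n), (1, head_dim))` reinterprets the storage of `v_base` rather than transposing it, so the resulting view generally holds different values at each `[i, j]` than `v_base` does. The sketch below illustrates the index arithmetic in plain Python with small stand-in sizes (4 and 6 instead of 64 and 128). The helper names `strided_view` and `transpose` are hypothetical and not part of the test file.

```python
def strided_view(storage, shape, strides):
    """Materialize a 2-D strided view of a flat list (the indexing torch.as_strided describes)."""
    rows, cols = shape
    return [[storage[i * strides[0] + j * strides[1]] for j in range(cols)]
            for i in range(rows)]


def transpose(mat):
    """Plain matrix transpose for a list of rows."""
    return [list(col) for col in zip(*mat)]


head_dim, seq = 4, 6                    # small stand-ins for 64 and 128
storage = list(range(head_dim * seq))   # flat backing buffer shared by all views

# Contiguous (head_dim, seq) tensor, like v_base.
v_base = strided_view(storage, (head_dim, seq), (seq, 1))
# Same buffer with strides (1, head_dim), like the as_strided view.
v_mn = strided_view(storage, (head_dim, seq), (1, head_dim))
# Same buffer read as a contiguous (seq, head_dim) tensor.
v_rows = strided_view(storage, (seq, head_dim), (head_dim, 1))

print(v_mn == v_base)             # False
print(v_mn == transpose(v_rows))  # True
```

If this reading is right, a host-side reference for the PV product should be computed from the strided view that is handed to the kernel (for example `v[:, :, 0].float()`), since `v_base` and the strided view describe different matrices over the same buffer.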
|
||||
288
tests/test_softmax_only.py
Normal file
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
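Test: QK + softmax packing only (no PV).
The epilogue warps load S from TMEM, pack it FP32→BF16, and store the packed P
back to TMEM. The packed P is not read back yet: the tensor written to GMEM is
still S (the QK accumulator), so the check below validates the QK path and that
the packing step runs to completion, not the packed values themselves.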
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class SoftmaxOnlyKernel:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, tiled_mma):
|
||||
mma_inst_k = cute.size(tiled_mma.shape_mnk, mode=[2])
|
||||
self.mma_tiler = (*self.mma_tiler_mn, mma_inst_k * 4)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (tiled_mma.thr_id.shape,))
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler[1], self.mma_tiler[2])
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(tiled_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(tiled_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake, arch="sm_100")
|
||||
self.tilePlikeFP32 = self.mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # BF16 P location in TMEM
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(tiled_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.q_dtype = a.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
tiled_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
self._setup(tiled_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(tiled_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, tiled_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(tiled_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
thr_mma = tiled_mma.get_slice(0)
|
||||
tCgA = thr_mma.partition_A(gA); tCgB = thr_mma.partition_B(gB); tCgC = thr_mma.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = tiled_mma.make_fragment_A(sA); tCrB = tiled_mma.make_fragment_B(sB)
|
||||
acc_shape = thr_mma.partition_shape_C(self.mma_tiler[:2])
|
||||
tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage))
|
||||
|
||||
# S in TMEM
|
||||
tStS = thr_mma.make_fragment_C(acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# TMA WARP
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# MMA WARP
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(tiled_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
tiled_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
# Don't wait for softmax — just commit acc_pipe directly
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# EPILOGUE WARPS
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# TMEM load/store setup for softmax packing
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.mma_tiler[0], self.mma_tiler[1]))
|
||||
tScS = thr_mma.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# Wait for S ready, then do softmax packing
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
|
||||
            # Output: ideally this would read the packed BF16 P back from TMEM and
            # write it to GMEM. However, acc_pipe is committed by the MMA warp for S
            # (not P), and a second TMA store for P is not implemented yet.
            # For now, output S (the QK accumulator), same as Stage A.
|
||||
tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtAcc_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, k = 128, 128, 64
    a = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    b = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    ref = a[:,:,0].float() @ b[:,:,0].float().T

    import cutlass.torch as ct
    mA = ct.from_dlpack(a).mark_layout_dynamic(leading_dim=ct.get_leading_dim(a))
    mB = ct.from_dlpack(b).mark_layout_dynamic(leading_dim=ct.get_leading_dim(b))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = SoftmaxOnlyKernel(mma_tiler_mn=(128, 128))
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mA, mB, mC, stream)
    print('Running...', flush=True)
    compiled(mA, mB, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    print('Softmax-only test (QK + packing, output S): cosine {:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))

if __name__ == '__main__':
    test()
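
The packing step in the kernel above stores two BF16 values per 32-bit TMEM word, which is why the store view is `tilePlikeFP32 = N // 32 * 16` columns wide instead of `N`. The following is a host-side sketch of that arithmetic in plain Python, not the kernel's actual data path. The helper names are hypothetical, and the choice of placing element `2i` in the low half of each word is an assumption made for illustration only; the real TMEM element order is determined by the copy atom and layouts.

```python
import struct


def f32_to_bf16_bits(x):
    """Round an FP32 value to BF16 (round-to-nearest-even); NaN handling omitted."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF


def bf16_bits_to_f32(h):
    """Widen 16 BF16 bits back to an FP32 value."""
    return struct.unpack("<f", struct.pack("<I", h << 16))[0]


def pack_row(values):
    """Pack consecutive BF16 pairs into 32-bit words (low half first: an assumption)."""
    halves = [f32_to_bf16_bits(v) for v in values]
    return [halves[i] | (halves[i + 1] << 16) for i in range(0, len(halves), 2)]


def unpack_row(words):
    """Inverse of pack_row."""
    out = []
    for w in words:
        out += [bf16_bits_to_f32(w & 0xFFFF), bf16_bits_to_f32(w >> 16)]
    return out


FP32_BITS, BF16_BITS = 32, 16
tile_n = 128
# Same expression as tilePlikeFP32 in the kernel setup.
tile_p_like_fp32 = tile_n // FP32_BITS * BF16_BITS

row = [float(i - 64) for i in range(tile_n)]  # small integers are exact in BF16
words = pack_row(row)
print(tile_p_like_fp32, len(words))  # 64 64
print(unpack_row(words) == row)      # True
```

One row of 128 scores therefore occupies 64 FP32-sized columns once packed, which matches the `(128, tilePlikeFP32)` composition used for the TMEM store view.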
|
||||
401
tests/test_stage_b_v13.py
Normal file
@@ -0,0 +1,401 @@
|
||||
"""
|
||||
Stage B v13: Two MMAs + Identity Softmax using FMHA's C-fragment packing pattern.

The key insight: the C→A "transform" is NOT a reordering; it is a PACKING.
When you write 128 BF16 values packed into 64 FP32 words via the C-fragment
composition (128, tilePlikeFP32) with St32x32bOp as FP32, the physical TMEM
locations used are exactly the ones the PV MMA's A-fragment reads from.

FMHA pattern:
  1. Load S from TMEM via C-fragment layout (FP32, 128×128)
  2. Convert to BF16 and pack: FP32 backing tensor + BF16 recast view
  3. Store packed FP32 backing to TMEM via C-fragment composition with St32x32bOp(FP32)
  4. PV MMA reads from A-fragment TMEM, the same physical locations as the packed BF16

Architecture:
  MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
  Identity softmax: tcgen05.ld C-layout → F32→BF16 packed → tcgen05.st C-fragment composition
  MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
# FMHA pattern: pv_mma_tiler = (M, QK_K, QK_N) — K=QK_N, N=head_dim
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# TMEM offsets (matching fmha.py)
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
|
||||
# TMEM layout: S0=0, P0=32, O0=s_cols (fmha.py itself uses S0=0, P0=32, O0=256)
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
self.tmem_alloc_cols = s_cols + o_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols = {s_cols}, o_cols = {o_cols}")
|
||||
print(f"[StageB] tmem_alloc_cols = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares SMEM with B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
# ── TMEM tensors ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# ── P A-fragment for PV MMA (matching fmha.py exactly) ──
|
||||
# tP uses C-fragment iterator but A-fragment (p_tmem_s) layout
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# ── Softmax store uses C-fragment composition (FMHA pattern) ──
|
||||
# tStS_P: C-fragment layout composed with (128, tilePlikeFP32)
|
||||
# This is where we store the packed BF16 P values
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tOrP.layout: {tOrP.layout}')
|
||||
print(f'[DIAG] tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[DIAG] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
# Wait for softmax to complete
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── 1. TMEM LOAD pipeline (C-fragment layout, FP32) ──
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# ── 2. TMEM STORE pipeline (C-fragment composition, FP32 St32x32bOp) ──
|
||||
# FMHA: compose the coordinate tensor with (128, tilePlikeFP32) for the store
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# Wait for QK scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# ── 3. LOAD scores from TMEM ──
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# ── 4. Identity softmax: convert FP32 → BF16 with packing ──
|
||||
# FMHA pattern: FP32 backing tensor + BF16 recast view
|
||||
# tTMEM_STORErS_x4: FP32 backing (store partition shape, 64 FP32 words)
|
||||
# tTMEM_STORErS_x4_e: BF16 recast view (load partition shape, 128 BF16 elements)
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout,
|
||||
)
|
||||
|
||||
# Convert all 128 FP32 values → 128 BF16, packed into 64 FP32 backing slots
|
||||
# The .load()/.store() on the recast view handles the packing automatically
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
|
||||
for j in range(frg_cnt):
|
||||
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
|
||||
# Identity: keep the value, just convert F32→BF16
|
||||
# (In real softmax, this is where exp(scores - row_max) happens)
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] # identity (no-op, just for clarity)
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
# ── 5. STORE packed BF16 to TMEM via C-fragment composition ──
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
    # Identity softmax = just convert F32→BF16, so P = BF16(S) ≈ Q@K^T
    # O = P @ V = (Q @ K^T) @ V (same as Q @ K^T @ K, since V=K in this test)
    ref = qf @ kvf.T @ kvf
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v13: FMHA C-fragment packing pattern (identity softmax)')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))

if __name__ == '__main__':
    test()
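
The packing arithmetic used by the identity softmax in the file above (128 FP32 scores converted to 128 BF16 values that occupy 64 FP32-sized slots, with `tilePlikeFP32 = qk_mma_tiler[1] // Float32.width * o_dtype.width`) can be checked on the host with a minimal sketch. The helper names `f32_to_bf16_bits`, `pack_bf16_pairs`, and `tile_p_like_fp32` are hypothetical and not part of CuTeDSL. The sketch assumes round-to-nearest-even conversion and that the even-indexed element lands in the low 16 bits of each word; the real element placement in TMEM is decided by the CuTe layouts in the kernel, not by this code.

```python
import struct


def f32_to_bf16_bits(x: float) -> int:
    """Return the 16-bit BF16 pattern for x (assumed round-to-nearest-even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # NaN: keep it a NaN instead of letting rounding carry into the exponent.
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return ((bits >> 16) | 0x0040) & 0xFFFF
    # Add 0x7FFF plus the lowest kept bit so that ties round to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def pack_bf16_pairs(values):
    """Pack floats into 32-bit words, two BF16 values per word.

    Assumption for illustration: even-indexed element in the low half.
    """
    if len(values) % 2 != 0:
        raise ValueError("need an even number of values")
    words = []
    for lo, hi in zip(values[0::2], values[1::2]):
        words.append(f32_to_bf16_bits(lo) | (f32_to_bf16_bits(hi) << 16))
    return words


def tile_p_like_fp32(n_cols: int, acc_width: int = 32, p_width: int = 16) -> int:
    """Mirror of the kernel's tilePlikeFP32 expression, widths in bits."""
    return n_cols // acc_width * p_width


if __name__ == "__main__":
    scores = [float(i) for i in range(128)]   # one row of FP32 scores
    packed = pack_bf16_pairs(scores)
    print(len(scores), "FP32 values ->", len(packed), "32-bit words")
    print("tilePlikeFP32 for N=128:", tile_p_like_fp32(128))
```

For a 128-column score tile this gives 64 packed words, which matches the "64 FP32 backing slots" noted in the kernel comments.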
352
tests/test_stage_b_v14.py
Normal file
@@ -0,0 +1,352 @@
"""
Stage B v14: Two MMAs + Identity Softmax using backward FMHA's A-fragment store pattern.

The backward FMHA writes P to TMEM using the A-fragment layout directly:
  1. Load S from TMEM via C-fragment layout (FP32)
  2. Quantize FP32 -> BF16 (make_rmem_tensor with BF16, load/store)
  3. Reshape quantized to match store coordinate shape
  4. Store via St32x32bOp(BF16) to A-fragment TMEM layout
  5. PV MMA reads from the same A-fragment addresses
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
self.tmem_alloc_cols = s_cols + o_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
# TMEM tensors
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P A-fragment (backward FMHA pattern)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tdVrP_base = pv_thr.make_fragment_A(tP)
|
||||
tdVrP = tdVrP_base[(None, None, None, 0)]
|
||||
tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tdVrP = cute.make_tensor(tdVrP_iter, tdVrP.layout)
|
||||
tdVrP0 = cute.make_tensor(
|
||||
tdVrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tdVrP.layout)
|
||||
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tdVrP.layout: {tdVrP.layout}')
|
||||
print(f'[DIAG] tdVrP0.layout: {tdVrP0.layout}')
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# TMA WARP
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# MMA WARP
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tdVrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tdVrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# SOFTMAX / EPILOGUE WARPS
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# 1. TMEM LOAD pipeline (C-fragment, FP32)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. TMEM STORE pipeline (A-fragment, BF16 St32x32bOp)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tdVrP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tRT_tP = thr_store.partition_D(tdVrP0)
|
||||
cS_store = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tdVcS = pv_thr.partition_A(cS_store)
|
||||
tRT_cS = thr_store.partition_S(tdVcS)
|
||||
|
||||
# Wait for QK scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 3. LOAD scores from TMEM
|
||||
tTR_rST = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTR_rST)
|
||||
|
||||
# 4. Quantize FP32 -> BF16 (backward FMHA pattern)
|
||||
tRT_rST_bf16 = cute.make_rmem_tensor(tTR_rST.shape, self.q_dtype)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTR_rST) // frg_cnt
|
||||
tTR_rST_frg = cute.logical_divide(tTR_rST, cute.make_layout(frg_tile))
|
||||
tRT_rST_bf16_frg = cute.make_tensor(tRT_rST_bf16.iterator, tTR_rST_frg.layout)
|
||||
for j in range(frg_cnt):
|
||||
frg_vec = tTR_rST_frg[None, j].load()
|
||||
tRT_rST_bf16_frg[None, j].store(frg_vec.to(self.q_dtype))
|
||||
|
||||
# 5. Reshape to match store coordinate shape
|
||||
tRT_rST_reshaped = cute.make_tensor(
|
||||
tRT_rST_bf16.iterator, cute.make_layout(tRT_cS.shape))
|
||||
|
||||
# 6. STORE to A-fragment TMEM
|
||||
cute.copy(tiled_tmem_store, tRT_rST_reshaped, tRT_tP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
    ref = qf @ kvf.T @ kvf
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v14: backward FMHA A-fragment store pattern (identity softmax)')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))

if __name__ == '__main__':
    test()
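
The tests above pass when the flattened cosine similarity against the reference is at least 0.99. As a rough host-side illustration of how that metric behaves, the sketch below compares a reference vector against (a) a copy with small additive noise and (b) a copy holding the same values in scrambled positions. It is pure Python with hypothetical names, uses synthetic Gaussian data rather than kernel output, and does not show that the PV failure is a layout permutation; it only shows that correct values in wrong positions would be consistent with a cosine near zero.

```python
import math
import random


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


rng = random.Random(0)                      # fixed seed for repeatability
ref = [rng.gauss(0.0, 1.0) for _ in range(128 * 128)]

# Case (a): right values, right positions, small rounding-like noise.
noisy = [x + rng.gauss(0.0, 0.01) for x in ref]

# Case (b): right values, wrong positions (stand-in for a layout mismatch).
scrambled = ref[:]
rng.shuffle(scrambled)

cos_noisy = cosine_similarity(noisy, ref)
cos_scrambled = cosine_similarity(scrambled, ref)

print("noisy     : {:.6f} -> {}".format(cos_noisy, "PASS" if cos_noisy >= 0.99 else "FAIL"))
print("scrambled : {:.6f} -> {}".format(cos_scrambled, "PASS" if cos_scrambled >= 0.99 else "FAIL"))
```

Because a scrambled output keeps the same norm as the reference, max error alone can look moderate while the cosine collapses, which is one reason the tests report both numbers.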
453
tests/test_stage_b_v16.py
Normal file
@@ -0,0 +1,453 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.

Key fixes over v6:
  - TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
  - P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
  - tilePlikeFP32 computed from qk_mma_tiler and dtype widths
  - tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
  - JIT-time diagnostic prints of all TMEM sizes

Architecture (matches fmha.py exactly):
  MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
  Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
  MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Original
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = 512 # FMHA: allocate max TMEM
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (hardcoded) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
# Introspect PV MMA atom
|
||||
print(f"[ATOM] PV MMA type: {type(pv_mma)}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}")
|
||||
print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}")
|
||||
print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
# Check a_src
|
||||
print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}")
|
||||
print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op}")
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tdVrP0 = cute.make_tensor(tdVrP.iterator + dtype_width_ratio * tmem_p0_offset, tdVrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tdVrP_base = pv_thr.make_fragment_A(tP)
|
||||
tdVrP = tdVrP_base[(None, None, None, 0)]
|
||||
tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tdVrP = cute.make_tensor(tdVrP_iter, tdVrP.layout)
|
||||
tdVrP0 = cute.make_tensor(
|
||||
tdVrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tdVrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tdVrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tdVrP.layout: {tdVrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tdVrP cosize: {cute.cosize(tdVrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tdVrP.size: {cute.size(tdVrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tdVrP0.layout: {tdVrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tdVrP0 cosize: {cute.cosize(tdVrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, "element_type") else "N/A"}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, "element_type") else "N/A"}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, "iterator") else "cannot compare"}')
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tdVrP0.size = {cute.size(tdVrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tdVrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tdVrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (backward FMHA: A-fragment, BF16 St32x32bOp)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tdVrP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tRT_tP = thr_store.partition_D(tdVrP0)
|
||||
cS_store = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tdVcS = pv_thr.partition_A(cS_store)
|
||||
tRT_cS = thr_store.partition_S(tdVcS)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. Quantize FP32 -> BF16 (backward FMHA pattern)
|
||||
tRT_rST_bf16 = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTR_rST_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tRT_rST_bf16_frg = cute.make_tensor(tRT_rST_bf16.iterator, tTR_rST_frg.layout)
|
||||
for j in range(frg_cnt):
|
||||
frg_vec = tTR_rST_frg[None, j].load()
|
||||
tRT_rST_bf16_frg[None, j].store(frg_vec.to(self.q_dtype))
|
||||
# 6. Reshape and store to A-fragment TMEM
|
||||
tRT_rST_reshaped = cute.make_tensor(
|
||||
tRT_rST_bf16.iterator, cute.make_layout(tRT_cS.shape))
|
||||
cute.copy(tiled_tmem_store, tRT_rST_reshaped, tRT_tP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
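The TMEM offset arithmetic in `_setup` can be checked by hand. The sketch below reproduces it in plain Python for the test's 128x128 tile. It assumes `s_cols` and `o_cols` are both 128 (taken from the `# 128` and `# 256` comments in the sources, not from an actual call to `find_tmem_tensor_col_offset`) and that offsets count 32-bit TMEM columns. The function names are hypothetical.

```python
# Assumed inputs, mirroring the values used in test() and the inline comments.
QK_TILE_N = 128   # qk_mma_tiler[1]
ACC_WIDTH = 32    # Float32 accumulator width in bits
P_WIDTH = 16      # BFloat16 width in bits (q_dtype)
S_COLS = 128      # assumed result of find_tmem_tensor_col_offset(tStS)
O_COLS = 128      # assumed result of find_tmem_tensor_col_offset(tOtO)
P0_OFFSET = 32    # tmem_p0_offset, as hardcoded in _setup


def tile_p_like_fp32(tile_n, p_width):
    """Number of 32-bit columns needed to hold tile_n elements of p_width bits."""
    return tile_n * p_width // 32


def p_iterator_offset(acc_width, p_width, p0_offset):
    """Offset applied to the BF16-typed iterator: column offset scaled by the width ratio."""
    return acc_width // p_width * p0_offset


def tmem_layout(s_cols, o_cols, p0_offset):
    """Column offsets for S, P, O and the summed allocation size."""
    return {"S0": 0, "P0": p0_offset, "O0": s_cols, "alloc": s_cols + o_cols}


if __name__ == "__main__":
    p_cols = tile_p_like_fp32(QK_TILE_N, P_WIDTH)
    print("tilePlikeFP32 =", p_cols)
    print("P iterator offset =", p_iterator_offset(ACC_WIDTH, P_WIDTH, P0_OFFSET))
    print("layout =", tmem_layout(S_COLS, O_COLS, P0_OFFSET))
    # Under these assumptions P ends at column P0 + tilePlikeFP32, inside the S region.
    print("P fits inside S:", P0_OFFSET + p_cols <= S_COLS)
```

If the JIT-time `[StageB]` prints disagree with these numbers, the assumed `s_cols` and `o_cols` are the first values to re-check.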
|
||||
450
tests/test_stage_b_v17.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Original
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
# Introspect PV MMA atom
|
||||
print(f"[ATOM] PV MMA type: {type(pv_mma)}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}")
|
||||
print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}")
|
||||
print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
# Check a_src
|
||||
print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}")
|
||||
print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op}")
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, "element_type") else "N/A"}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, "element_type") else "N/A"}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, "iterator") else "cannot compare"}')
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. IDENTITY: F32 → BF16
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Store into A-layout
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
    ref = qf @ kvf.T @ kvf
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))


if __name__ == '__main__':
    test()
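The pass criterion used by `test()` can be reproduced on the CPU without CuTeDSL. The sketch below is only an illustration under stated assumptions: it uses small 8x8 matrices instead of 128x128, plain Python lists instead of `torch` tensors, and the helper names (`matmul`, `transpose`, `cosine`, `verdict`) are hypothetical and not part of this repository. It builds the same reference chain `(Q @ K^T) @ V` with `V = K` and applies the same `cos >= 0.99` threshold.

```python
import math
import random


def matmul(a, b):
    """Multiply two row-major matrices stored as lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    """Return the transpose of a row-major matrix."""
    return [list(col) for col in zip(*a)]


def cosine(a, b):
    """Cosine similarity of two matrices after flattening them to vectors."""
    fa = [x for row in a for x in row]
    fb = [x for row in b for x in row]
    dot = sum(x * y for x, y in zip(fa, fb))
    return dot / (math.sqrt(sum(x * x for x in fa)) * math.sqrt(sum(x * x for x in fb)))


def verdict(cos, threshold=0.99):
    """Same PASS/FAIL rule as the test: PASS when cosine >= threshold."""
    return "PASS" if cos >= threshold else "FAIL"


random.seed(42)
m, n, k = 8, 8, 8  # small stand-in for the 128x128x128 problem in the test
q = [[random.gauss(0.0, 1.0) for _ in range(k)] for _ in range(m)]
kv = [[random.gauss(0.0, 1.0) for _ in range(k)] for _ in range(n)]

# Reference: (Q @ K^T) @ V, where V is the same tensor as K.
ref = matmul(matmul(q, transpose(kv)), kv)

# A sign-flipped output is an obviously wrong result and must not pass.
negated = [[-x for x in row] for row in ref]

cos_same = cosine(ref, ref)
cos_negated = cosine(ref, negated)
print(f"identical: {cos_same:.6f} {verdict(cos_same)}")
print(f"negated:   {cos_negated:.6f} {verdict(cos_negated)}")
```

Cosine similarity ignores overall scale, which is why the test also prints the maximum absolute error next to it.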
|
||||
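The TMEM column arithmetic used for the P view can be checked in isolation. The sketch below only mirrors the two integer expressions that appear in the kernel (`qk_mma_tiler[1] * q_dtype.width // 32` and `qk_acc_dtype.width // q_dtype.width * tmem_p0_offset`); the function names are hypothetical, and the value `o0 = 128` is an assumption taken from the `# 128` comment next to `tmem_o0_offset = s_cols`, not a value confirmed by running `find_tmem_tensor_col_offset`.

```python
ACC_BITS = 32  # Float32 accumulator width (qk_acc_dtype)
P_BITS = 16    # BFloat16 width (q_dtype)


def tile_p_like_fp32(tile_n, p_bits=P_BITS):
    """Number of 32-bit TMEM columns needed to hold tile_n elements of P."""
    return tile_n * p_bits // 32


def p_iterator_offset(p0_col, acc_bits=ACC_BITS, p_bits=P_BITS):
    """Offset of the P fragment in units of P elements, given a 32-bit column offset."""
    return acc_bits // p_bits * p0_col


tile_n = 128     # qk_mma_tiler[1] in the test configuration
s0, p0 = 0, 32   # tmem_s0_offset and tmem_p0_offset from _setup
o0 = 128         # assumption: s_cols, per the "# 128" comment in _setup

p_cols = tile_p_like_fp32(tile_n)   # 32-bit columns covered by the P store view
p_end = p0 + p_cols                 # first column past the P view
p_elem_off = p_iterator_offset(p0)  # offset applied to the BF16 A-fragment iterator

print(f"P view: columns [{p0}, {p_end}), BF16 iterator offset {p_elem_off}")
print(f"P stays inside the S region: {p_end <= o0}")
```

Under these assumptions the P store view spans 64 columns starting at column 32, which keeps it inside the S region and below the start of O.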
452
tests/test_stage_b_v18.py
Normal file
@@ -0,0 +1,452 @@
"""
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.

Key fixes over v6:
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
- JIT-time diagnostic prints of all TMEM sizes

Architecture (matches fmha.py exactly):
  MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
  Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
  MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
    def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
        self.acc_dtype = Float32; self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16; self.o_dtype = BFloat16
        self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
        self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Original
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = 512 # FMHA-style: allocate max TMEM # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
# Introspect PV MMA atom
|
||||
print(f"[ATOM] PV MMA type: {type(pv_mma)}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}")
|
||||
print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}")
|
||||
print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
# Check a_src
|
||||
print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}")
|
||||
print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op}")
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f"[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}")
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f"[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}")
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f"[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cannot compare'}")
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (A-fragment, BF16 St32x32bOp — no recast)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOrP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tRT_tP = thr_store.partition_D(tOrP0)
|
||||
cS_store = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tAcS = pv_thr.partition_A(cS_store)
|
||||
tRT_cS = thr_store.partition_S(tAcS)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. Quantize FP32 -> BF16 (backward FMHA pattern)
|
||||
tRT_rST_bf16 = cute.make_rmem_tensor(tTMEM_LOADrS.shape, self.q_dtype)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTR_rST_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tRT_rST_bf16_frg = cute.make_tensor(tRT_rST_bf16.iterator, tTR_rST_frg.layout)
|
||||
for j in range(frg_cnt):
|
||||
frg_vec = tTR_rST_frg[None, j].load()
|
||||
tRT_rST_bf16_frg[None, j].store(frg_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Reshape and store to A-fragment TMEM
|
||||
tRT_rST_reshaped = cute.make_tensor(
|
||||
tRT_rST_bf16.iterator, cute.make_layout(tRT_cS.shape))
|
||||
cute.copy(tiled_tmem_store, tRT_rST_reshaped, tRT_tP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
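The PASS/FAIL check in `test()` above is a cosine similarity between the flattened kernel output and the FP32 reference. As a rough illustration of why a layout mismatch between the softmax P store and the PV A-fragment read would be expected to show up as a cosine near zero rather than as a small numerical error, here is a minimal CPU-only sketch in plain Python. The helper names (`matmul`, `cosine`, `scramble`), the 32x32 size, and the random permutation standing in for a wrong TMEM layout are all assumptions made for illustration; none of this is part of the kernel or of the CuTeDSL API.

```python
import math
import random


def matmul(a, b):
    """Plain nested-list matrix product: (m x k) @ (k x n)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def cosine(a, b):
    """Cosine similarity of two matrices after flattening, as the test does."""
    fa = [x for row in a for x in row]
    fb = [x for row in b for x in row]
    dot = sum(x * y for x, y in zip(fa, fb))
    norm_a = math.sqrt(sum(x * x for x in fa))
    norm_b = math.sqrt(sum(y * y for y in fb))
    return dot / (norm_a * norm_b)


def scramble(mat, rng):
    """Return the same values in a permuted order (hypothetical stand-in for
    reading P back through a different layout than it was written with)."""
    n = len(mat[0])
    flat = [x for row in mat for x in row]
    rng.shuffle(flat)
    return [flat[i:i + n] for i in range(0, len(flat), n)]


rng = random.Random(42)
size = 32
p = [[rng.gauss(0.0, 1.0) for _ in range(size)] for _ in range(size)]
v = [[rng.gauss(0.0, 1.0) for _ in range(size)] for _ in range(size)]

ref = matmul(p, v)
cos_same = cosine(matmul(p, v), ref)                      # same layout on both sides
cos_scrambled = cosine(matmul(scramble(p, rng), v), ref)  # mismatched layout

print(f"matching layout:  cosine = {cos_same:.6f}")
print(f"scrambled layout: cosine = {cos_scrambled:.6f}")
```

Under these assumptions, the values of P are all present in the scrambled case, yet the product is almost uncorrelated with the reference. A cosine far below the 0.99 threshold therefore suggests an indexing problem rather than a precision problem.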
|
||||
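The two TMEM offset expressions that appear in these tests, `tilePlikeFP32 = qk_mma_tiler[1] * q_dtype.width // 32` and `qk_acc_dtype.width // q_dtype.width * tmem_p0_offset`, are plain integer arithmetic and can be checked off-device. The sketch below works them through for the values used here: a 128-wide N tile, BF16 P, FP32 accumulators, and a P offset of 32. The function names are hypothetical, and the sketch assumes one TMEM column holds 32 bits per lane, so treat it as a sanity check on the numbers and not as a description of how CuTeDSL lays out fragments.

```python
TMEM_COLUMN_BITS = 32  # assumption: one TMEM column is 32 bits wide per lane


def tile_p_like_fp32(tile_n, p_bits):
    """Number of 32-bit columns needed to hold tile_n elements of p_bits each."""
    return tile_n * p_bits // TMEM_COLUMN_BITS


def p_offset_in_p_elements(p_col_offset, acc_bits, p_bits):
    """Convert a column offset counted in accumulator elements into P elements."""
    return acc_bits // p_bits * p_col_offset


tile_n, acc_bits, p_bits = 128, 32, 16  # FP32 scores, BF16 P
s0_offset, p0_offset = 0, 32            # offsets used in the tests

p_cols = tile_p_like_fp32(tile_n, p_bits)
p_elem_offset = p_offset_in_p_elements(p0_offset, acc_bits, p_bits)
s_cols = tile_n * acc_bits // TMEM_COLUMN_BITS
p_range = (p0_offset, p0_offset + p_cols)

print(f"S occupies columns [{s0_offset}, {s0_offset + s_cols})")
print(f"P occupies columns [{p_range[0]}, {p_range[1]}) = {p_cols} columns")
print(f"P iterator offset in BF16 elements: {p_elem_offset}")
```

With these numbers the P view sits entirely inside the S region, which is consistent with the "P overlaps S" arrangement described in the comments of `tests/test_stage_b_v19.py`.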
450
tests/test_stage_b_v19.py
Normal file
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Original
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
# Introspect PV MMA atom
|
||||
print(f"[ATOM] PV MMA type: {type(pv_mma)}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, 'op') else 'no op'}")
|
||||
print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, '_trait') else 'no trait'}")
|
||||
print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
# Check a_src
|
||||
print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}")
|
||||
print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op}")
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f"[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}")
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f"[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}")
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f"[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cannot compare'}")
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. IDENTITY: F32 → BF16
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Store into A-layout
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, k = 128, 128, 128
    q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
    ref = qf @ kvf.T @ kvf
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))


if __name__ == '__main__':
    test()
|
||||
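Step 5 of the v7 kernel above converts the FP32 scores to BF16 and writes them through a register tensor that is still typed Float32, using a `recast_ptr` view. The sketch below illustrates that packing idea on the host in plain Python: two BF16 values share one 32-bit word, so 128 scores fit in 64 FP32-sized slots. The helper names are hypothetical. The rounding mode (round-to-nearest-even) and the element order inside a word (element 0 in the low half, as in a little-endian reinterpretation) are assumptions, not something the kernel source states.

```python
import struct


def f32_to_bf16_bits(x: float) -> int:
    """Return the 16 BF16 bits for x (assumed round-to-nearest-even; NaN not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                  # lowest bit that survives truncation
    bits = (bits + 0x7FFF + lsb) & 0xFFFFFFFF
    return bits >> 16


def bf16_bits_to_f32(b: int) -> float:
    """Widen 16 BF16 bits back to a Python float (exact)."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]


def pack_pair(lo: float, hi: float) -> int:
    """Pack two values into one 32-bit word; `lo` goes in the low half (assumption)."""
    return (f32_to_bf16_bits(hi) << 16) | f32_to_bf16_bits(lo)


def unpack_pair(word: int) -> tuple:
    """Inverse of pack_pair: recover (lo, hi) as floats."""
    return bf16_bits_to_f32(word & 0xFFFF), bf16_bits_to_f32(word >> 16)


def pack_row(scores: list) -> list:
    """Pack an even-length row of FP32 scores into half as many 32-bit words."""
    assert len(scores) % 2 == 0
    return [pack_pair(scores[i], scores[i + 1]) for i in range(0, len(scores), 2)]
```

For example, `pack_row` applied to a 128-element row returns 64 words, which matches the idea that the packed P tile needs half as many 32-bit columns as the FP32 S tile.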
362
tests/test_stage_b_v20.py
Normal file
@@ -0,0 +1,362 @@
|
||||
"""
Stage B v20: FMHA-matching test with head_dim=64, proper V layout.

KEY INSIGHT: The A-fragment ((128,16),1,(4,2)):((65536,1),0,(16,64)) is SEQUENTIAL
when flattened in CuTe order: addr = m*65536 + k0 + 16*k1 + 64*k2 = m*65536 + k.
So the C-fragment composition store aliases the SAME TMEM as the A-fragment read.

Previous -0.02 cosine was caused by V dimension mismatch:
  pv_mma_tiler=(128,64,128) expects V with N=64 (head_dim),
  but the square 128x128 test had V=K (N=128).

This test: Q=(128,64), K=(128,64), V=(64,128), O=(128,64)

FOOTGUN: St32x32bOp MUST use Float32, NOT BFloat16!
  The 16-bit values are packed via recast_ptr view.
  St32x32bOp(BFloat16) causes ILLEGAL MEMORY ACCESS.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v20] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v20] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
self.tmem_alloc_cols = s_cols + o_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
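        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
        # V is loaded on the same barrier as Q and K, so its bytes count toward tx_count.
        self.num_tma_load_bytes = (
            cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
            + cute.size_in_bytes(self.b_dtype, v_smem)
        ) * cute.size(qk_mma.thr_id.shape)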
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
# Separate TMA for V (uses pv_mma_tiler and v_smem layout)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_v, tma_tv, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_v, mV, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
# V partition (from mV with pv_mma_tiler and tma_v)
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# 1. TMEM LOAD (C-fragment, FP32)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. TMEM STORE — FOOTGUN: Float32 atom, NOT BFloat16!
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# 5. Identity softmax: FP32->BF16 packing (EXACT FMHA pattern)
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    # FMHA-matching dimensions: head_dim=64, seq=128
    # Q: (128, 64) K-major, K: (128, 64) K-major
    # V: (64, 128) MN-major (transposed!); FMHA requires v_major_mode=OperandMajorMode.MN
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(head_dim, n, 1, dtype=torch.bfloat16, device='cuda')  # Transposed!
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    qf = q[:,:,0].float(); kvf = kv[:,:,0].float(); vf = v[:,:,0].float()
    # Q@K^T = (128,128), P@V = (128,64)
    ref = qf @ kvf.T @ vf
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v20: FMHA-matching head_dim=64, proper V layout')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))


if __name__ == '__main__':
    test()
|
||||
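The v20 docstring claims that the A-fragment layout `((128,16),1,(4,2)):((65536,1),0,(16,64))` is sequential along K, and `_setup` derives `tilePlikeFP32` from element widths. A minimal host-side sketch can check both pieces of arithmetic in plain Python. The function names are hypothetical, the shapes and strides are copied from the docstring, and reading the scaled P offset as a count of 16-bit elements is an assumption.

```python
from itertools import product

# K sub-modes of the A-fragment layout quoted in the v20 docstring.
K_SHAPE = (16, 4, 2)      # (k0, k1, k2)
K_STRIDE = (1, 16, 64)
M_STRIDE = 65536          # stride of the M (row) mode


def a_fragment_offset(m: int, k0: int, k1: int, k2: int) -> int:
    """Offset of coordinate (m, k0, k1, k2): plain dot product with the strides."""
    return m * M_STRIDE + k0 * K_STRIDE[0] + k1 * K_STRIDE[1] + k2 * K_STRIDE[2]


def k_mode_is_sequential(m: int = 0) -> bool:
    """Walk K in CuTe order (k0 fastest) and check that offsets increase by 1."""
    expected = m * M_STRIDE
    # itertools.product varies its last argument fastest, so k0 is innermost.
    for k2, k1, k0 in product(range(K_SHAPE[2]), range(K_SHAPE[1]), range(K_SHAPE[0])):
        if a_fragment_offset(m, k0, k1, k2) != expected:
            return False
        expected += 1
    return expected == m * M_STRIDE + 128


def tile_p_like_fp32(n_cols: int = 128, acc_bits: int = 32, p_bits: int = 16) -> int:
    """Mirror of `qk_mma_tiler[1] // Float32.width * o_dtype.width` from _setup."""
    return n_cols // acc_bits * p_bits


def p_offset_in_p_elements(p0_offset: int = 32, acc_bits: int = 32, p_bits: int = 16) -> int:
    """Mirror of the tOrP0 offset: the FP32 column offset scaled by the width ratio."""
    return acc_bits // p_bits * p0_offset
```

With the values used in the test (128 score columns, FP32 accumulator, BF16 P, `tmem_p0_offset = 32`), both helper functions evaluate to 64.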
314
tests/test_stage_b_v22.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
Stage B v22: Q@K^T → S in TMEM, P@V → O in TMEM (identity softmax = S used as P)

Based on working Stage A v2 pattern. Two MMAs on the MMA warp, no softmax pipeline.
P stays in TMEM between QK and PV: no copy, no packing, just use S directly as P.

Bug 1 fix: V is MN-major, PV MMA uses v_major (OperandMajorMode.MN).
The PV MMA's A-operand (P) comes from TMEM (same as S accumulator).
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentityKernel:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = self.o_dtype # alias for epilogue_tma_store
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v22] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v22] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_o0_offset = s_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, v_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.q_dtype = q.element_type; self.b_dtype = k.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
print(f"[v22] a_major (Q) = {self.a_major}")
|
||||
print(f"[v22] b_major (K) = {self.b_major}")
|
||||
print(f"[v22] v_major (V) = {self.v_major}")
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
# BUG 1 FIX: PV MMA b_leading_mode = v_major (MN), NOT b_major (K)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, # Bug 1 fix
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
|
||||
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
|
||||
tma_c, tma_tc, self.cluster_layout_vmnk,
|
||||
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
||||
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
|
||||
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
|
||||
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gQ, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
|
||||
|
||||
# V partition — pv_mma with pv_mma_tiler
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
# QK accumulator (S) in TMEM
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
# PV accumulator (O) in TMEM
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P tensor for PV MMA A-operand (from TMEM, same as S)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
# P is at the same TMEM location as S (identity softmax: no copy/packing)
|
||||
tOrP0 = tOrP
|
||||
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ═══ TMA LOAD WARP ═══
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ═══ MMA WARP: QK then PV ═══
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
|
||||
# --- QK MMA: Q @ K^T → S in TMEM ---
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrQ, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# --- PV MMA: P @ V → O in TMEM ---
|
||||
# Identity softmax: P = S (FP32 in TMEM, used directly as A-operand)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ═══ EPILOGUE WARPS ═══
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    # V: MN-major, shape (head_dim, seq) with strides (1, head_dim)
    v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
    v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    # The reference reads V through the same strided view the kernel receives;
    # v_base has the same shape but a different element order.
    qf = q[:,:,0].float(); kf = k[:,:,0].float(); vf = v[:,:,0].float()
    # Q@K^T = (128,128), then PV MMA: P(128,128) @ V(64,128) MN-major
    # MN-major B-operand means the MMA interprets V as V[n,k] where n (head_dim) is contiguous
    # So PV computes: sum_k P[m,k] * V[n,k] = O[m,n] = P @ V^T
    ref = qf @ kf.T @ vf.T  # (128,64)

    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentityKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v22: Q@K^T + P@V (V MN-major, identity softmax)')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))


if __name__ == '__main__':
    test()
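The MN-major view built with `as_strided` in the test above reinterprets the storage of `v_base` rather than copying it, so the view and `v_base` share a shape but not an element order. A minimal pure-Python sketch of that indexing follows, using a small (2, 3) shape as a stand-in for (64, 128). `strided_view` is a hypothetical helper written for illustration; it is not a PyTorch or CuTe API.

```python
def strided_view(storage, shape, strides):
    """Materialize a 2-D strided view of a flat buffer as nested lists."""
    rows, cols = shape
    s0, s1 = strides
    return [[storage[i * s0 + j * s1] for j in range(cols)] for i in range(rows)]


head_dim, seq = 2, 3                    # stand-ins for 64 and 128
storage = list(range(head_dim * seq))   # flat buffer behind v_base

# v_base: contiguous (head_dim, seq) tensor, strides (seq, 1)
v_base = strided_view(storage, (head_dim, seq), (seq, 1))
# v: MN-major view of the same buffer, strides (1, head_dim)
v = strided_view(storage, (head_dim, seq), (1, head_dim))

print(v_base)  # [[0, 1, 2], [3, 4, 5]]
print(v)       # [[0, 2, 4], [1, 3, 5]]
```

Because the two differ element by element, a host-side reference is only comparable to the kernel output if it indexes the same view that is handed to the kernel.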
364
tests/test_stage_b_v22_bug1fix.py
Normal file
@@ -0,0 +1,364 @@
"""
Stage B v22: Bug 1 Fix - V B-Operand Must Be MN-Major

Fix over v20: PV MMA uses V's major mode (MN) instead of K's major mode (K).
V is shaped (head_dim, seq) = (64, 128) with strides (1, 64) → OperandMajorMode.MN.
K is shaped (seq, head_dim) = (128, 64) with strides (64, 1) → OperandMajorMode.K.

These are DIFFERENT. The PV MMA's b_leading_mode MUST come from V, not from K.

Also: separate TMA descriptor for V (already in v20), separate SMEM layout for V.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda


class StageBIdentitySoftmax:
    def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
        self.acc_dtype = Float32; self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16; self.o_dtype = BFloat16
        self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
        self.num_c_stage = 2
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v22] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v22] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
# V uses pv_mma with MN-major B — its own SMEM layout
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
self.tmem_alloc_cols = s_cols + o_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() # V's major mode — MN!
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
print(f"[v22] a_major (Q) = {self.a_major}")
|
||||
print(f"[v22] b_major (K) = {self.b_major}")
|
||||
print(f"[v22] v_major (V) = {self.v_major}")
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
# BUG 1 FIX: PV MMA uses V's major mode (MN), NOT K's major mode (K)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
# Separate TMA for V — uses pv_mma and pv_mma_tiler
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_v, tma_tv, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_v, mV, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
# V partition — uses pv_mma, pv_mma_tiler, and tma_v
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# 1. TMEM LOAD (C-fragment, FP32)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. TMEM STORE — FOOTGUN: Float32 atom, NOT BFloat16!
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# 5. Identity softmax: FP32->BF16 packing (EXACT FMHA pattern)
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    # FMHA-matching dimensions: head_dim=64, seq=128
    # Q: (128, 64) K-major, K: (128, 64) K-major
    # V: (64, 128) MN-major; FMHA requires v_major_mode=OperandMajorMode.MN
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    kv = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
    v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)  # MN-major: strides (1, 64)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    qf = q[:,:,0].float(); kvf = kv[:,:,0].float(); vf = v[:,:,0].float()
    # Q@K^T = (128,128), P@V = (128,64)
    ref = qf @ kvf.T @ vf.T
    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v22: Bug 1 fix - V MN-major')
    print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))


if __name__ == '__main__':
    test()
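Step 5 of the epilogue in the file above converts the FP32 scores to BF16 and writes them through a 32-bit store atom, so two BF16 values share one 32-bit TMEM column. For a 128-column tile that gives `tilePlikeFP32 = 128 // 32 * 16 = 64`. The following is a host-side sketch of that packing in plain Python. It assumes round-to-nearest-even conversion and that the first value of a pair lands in the low half of the word; neither detail is stated in the source, and NaN/Inf are not special-cased. `f32_to_bf16_bits`, `pack_bf16_pair`, and `packed_columns` are hypothetical helpers, not CuTeDSL calls.

```python
import struct


def f32_to_bf16_bits(x):
    """Return the 16-bit BF16 pattern of x (round-to-nearest-even, assumed)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit, then drop the low 16 bits.
    rounded = bits + 0x7FFF + ((bits >> 16) & 1)
    return (rounded >> 16) & 0xFFFF


def pack_bf16_pair(lo, hi):
    """Pack two BF16 values into one 32-bit word (low half first, assumed)."""
    return f32_to_bf16_bits(lo) | (f32_to_bf16_bits(hi) << 16)


def packed_columns(n_cols, acc_bits=32, p_bits=16):
    """Mirror of the tilePlikeFP32 expression: n_cols // acc_bits * p_bits."""
    return n_cols // acc_bits * p_bits


word = pack_bf16_pair(1.0, 2.0)
print(hex(word))            # 0x40003f80
print(packed_columns(128))  # 64
```

The sketch only illustrates why the store atom can stay `Float32` while the payload is BF16: the register tensor is viewed as 32-bit words, and each word carries two 16-bit values.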
364
tests/test_stage_b_v23.py
Normal file
@@ -0,0 +1,364 @@
"""
Stage B v23: FMHA-matching test with head_dim=64, proper V layout.

KEY INSIGHT: The A-fragment ((128,16),1,(4,2)):((65536,1),0,(16,64)) is SEQUENTIAL
when flattened in CuTe order: addr = m*65536 + k0 + 16*k1 + 64*k2 = m*65536 + k.
So the C-fragment composition store aliases the SAME TMEM as the A-fragment read.

Previous -0.02 cosine was caused by V dimension mismatch:
  pv_mma_tiler=(128,64,128) expects V with N=64 (head_dim),
  but the square 128x128 test had V=K (N=128).

This test: Q=(128,64), K=(128,64), V=(64,128), O=(128,64)

FOOTGUN: St32x32bOp MUST use Float32, NOT BFloat16!
  The 16-bit values are packed via recast_ptr view.
  St32x32bOp(BFloat16) causes ILLEGAL MEMORY ACCESS.
"""
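The "sequential" claim in the docstring above can be checked with plain arithmetic. The sketch below enumerates the K coordinates of the quoted layout in CuTe's colexicographic order (first mode fastest) and confirms that the offsets for one row are 0..127 in order. The constants are copied from the docstring; `offset` is a hypothetical helper, not a CuTe function.

```python
from itertools import product

# K modes of the A-fragment layout quoted above: shape (16, 4, 2), stride (1, 16, 64).
K_SHAPE = (16, 4, 2)
K_STRIDE = (1, 16, 64)
ROW_STRIDE = 65536  # stride of the row index m


def offset(m, k_coord):
    """Offset of row m at K coordinate (k0, k1, k2)."""
    return m * ROW_STRIDE + sum(c * s for c, s in zip(k_coord, K_STRIDE))


# product() varies its last argument fastest, so k0 is listed last here
# to make it the fastest-moving mode, matching CuTe's flattening order.
k_coords = [(k0, k1, k2)
            for k2, k1, k0 in product(range(K_SHAPE[2]),
                                      range(K_SHAPE[1]),
                                      range(K_SHAPE[0]))]

row0 = [offset(0, kc) for kc in k_coords]
row5 = [offset(5, kc) for kc in k_coords]

print(row0 == list(range(128)))  # True
print(row5[0], row5[-1])         # 327680 327807
```

This only confirms that the layout's index arithmetic is contiguous in k; it says nothing about whether the TMEM store from the softmax warps and the A-operand read by the MMA agree on which physical columns those offsets name.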
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda


class StageBIdentitySoftmax:
    def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
        self.acc_dtype = Float32; self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16; self.o_dtype = BFloat16
        self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
        self.num_c_stage = 2
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v23] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v23] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
self.tmem_alloc_cols = s_cols + o_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode() # Bug 1: V MN-major
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, # Bug 1 fix: V MN-major
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
# Separate TMA for V (uses pv_mma_tiler and v_smem layout)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_v, tma_tv, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_v, mV, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
# V partition (from mV with pv_mma_tiler and tma_v)
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)  # V uses its own TMA atom (tma_v), not tma_b
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
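tmem.allocate(self.num_tmem_alloc_cols)  # use the count from get_num_tmem_alloc_cols; s_cols + o_cols need not be a valid allocation size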
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# 1. TMEM LOAD (C-fragment, FP32)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. TMEM STORE — FOOTGUN: Float32 atom, NOT BFloat16!
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# 5. Identity softmax: FP32->BF16 packing (EXACT FMHA pattern)
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
# FMHA-matching dimensions: head_dim=64, seq=128
|
||||
# Q: (128, 64) K-major, K: (128, 64) K-major
|
||||
# V: (64, 128) MN-major (transposed!) — FMHA requires v_major_mode=OperandMajorMode.MN
|
||||
m, n, head_dim = 128, 128, 64
|
||||
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
|
||||
v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1) # MN-major: strides (1, 64)
|
||||
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
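qf = q[:,:,0].float(); kvf = kv[:,:,0].float(); vf = v[:,:,0].float()  # reference reads the strided view the kernel sees; as_strided reinterprets storage, so v != v_base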
|
||||
# Q@K^T = (128,128), P@V = (128,64)
|
||||
ref = qf @ kvf.T @ vf.T # P@V with V MN-major = Q@K^T @ V^T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v23: FMHA-matching head_dim=64, proper V layout')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
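The V setup in the test above builds its MN-major tensor with `as_strided((head_dim, n), (1, head_dim))`. The sketch below is a pure-Python model of strided indexing, not the author's code: `strided_view` is a hypothetical helper, and the sizes 4 and 6 are small stand-ins for 64 and 128. It assumes only that `as_strided` reinterprets existing storage without moving data. Under that assumption the strided view exposes different logical elements than `v_base`, so a reference computed from `v_base` may not match what the kernel reads. That may be worth ruling out before attributing the low cosine to a TMEM layout mismatch.

```python
# Pure-Python model of strided indexing (no torch needed); names are illustrative.
def strided_view(flat, shape, strides):
    """Return the 2-D list that a (shape, strides) view of `flat` would expose."""
    rows, cols = shape
    s0, s1 = strides
    return [[flat[i * s0 + j * s1] for j in range(cols)] for i in range(rows)]

head_dim, n = 4, 6                     # small stand-ins for 64 and 128
flat = list(range(head_dim * n))       # the storage behind v_base

# Contiguous (head_dim, n) tensor, as torch.randn(head_dim, n) would lay it out.
v_base = strided_view(flat, (head_dim, n), (n, 1))
# Same storage viewed with strides (1, head_dim), as in the test.
v_view = strided_view(flat, (head_dim, n), (1, head_dim))

# Same storage, different logical elements.
same_as_base = v_view == v_base

# The view equals the storage read as a contiguous (n, head_dim) matrix, transposed.
as_n_by_h = strided_view(flat, (n, head_dim), (head_dim, 1))
transposed = [[as_n_by_h[j][i] for j in range(n)] for i in range(head_dim)]
same_as_transposed_reshape = v_view == transposed
```

If the intent is an MN-major tensor holding the same logical values as `v_base`, one option is to permute the data first (for example a transposed contiguous copy viewed back through a transpose) rather than restriding in place. The other option is to compute the reference from the strided view itself.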
377
tests/test_stage_b_v24.py
Normal file
@@ -0,0 +1,377 @@
|
||||
"""
|
||||
Stage B v24: Q@K^T + identity softmax P packing + P@V
|
||||
Bug 1 fix: V MN-major
|
||||
Bug 2 fix: FP32->BF16 P packing (C-fragment composition store)
|
||||
Pipeline fix: Use NamedBarrier instead of PipelineUmmaAsync for mma_si sync
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentityKernel:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
# Named barrier IDs for mma_si coordination
|
||||
self.scores_ready_bar_id = 4 # MMA warp signals S ready
|
||||
self.softmax_done_bar_id = 5 # Epilogue warps signal softmax done
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v24] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v24] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # P region in TMEM (for BF16 packing)
|
||||
self.tmem_o0_offset = s_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
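cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem) + cute.size_in_bytes(self.b_dtype, cute.slice_(self.v_smem_s, (None, None, None, 0)))  # V is copied against the same ab barrier, so tx_count must include its bytes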
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.q_dtype = q.element_type; self.b_dtype = k.element_type
|
||||
self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
print(f"[v24] a_major (Q) = {self.a_major}")
|
||||
print(f"[v24] b_major (K) = {self.b_major}")
|
||||
print(f"[v24] v_major (V) = {self.v_major}")
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
|
||||
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
|
||||
tma_c, tma_tc, self.cluster_layout_vmnk,
|
||||
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
||||
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
|
||||
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
# Named barriers for MMA ↔ softmax coordination
|
||||
scores_ready_bar = pipeline.NamedBarrier(
|
||||
barrier_id=self.scores_ready_bar_id,
|
||||
num_threads=32 * (1 + len(self.epilogue_warp_id))) # MMA warp + epilogue warps
|
||||
softmax_done_bar = pipeline.NamedBarrier(
|
||||
barrier_id=self.softmax_done_bar_id,
|
||||
num_threads=32 * (1 + len(self.epilogue_warp_id))) # MMA warp + epilogue warps
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
|
||||
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gQ, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
|
||||
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P from TMEM — at tmem_p0_offset, read by PV MMA as A-operand
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ═══ TMA LOAD WARP ═══
|
||||
if warp_idx == self.tma_warp_id:
|
||||
# The TMA warp does not call tmem.wait_for_alloc(): tmem_bar counts only the MMA warp and the epilogue warps
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ═══ MMA WARP: QK → signal → wait → PV ═══
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
|
||||
# --- QK MMA: Q @ K^T → FP32 S in TMEM ---
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrQ, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Signal: S is ready (MMA warp arrives, waits for epilogue warps)
|
||||
scores_ready_bar.arrive_and_wait()
|
||||
|
||||
# Wait: softmax done (epilogue warps arrive, MMA warp waits)
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
|
||||
# --- PV MMA: BF16 P @ V → FP32 O in TMEM ---
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ═══ EPILOGUE WARPS: softmax packing + output store ═══
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# --- TMEM LOAD (C-fragment, FP32) ---
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# --- TMEM STORE (C-fragment composition, FP32 atom with BF16 recast) ---
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# Wait for S ready
|
||||
scores_ready_bar.arrive_and_wait()
|
||||
|
||||
# Load FP32 S from TMEM
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# FP32→BF16 packing (C-fragment composition store pattern)
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Signal: softmax done
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
|
||||
# --- Output epilogue ---
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, head_dim = 128, 128, 64
|
||||
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
# V: MN-major — (head_dim, seq) with strides (1, head_dim)
|
||||
v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
|
||||
v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1) # MN-major: strides (1, 64)
|
||||
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
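qf = q[:,:,0].float(); kf = k[:,:,0].float(); vf = v[:,:,0].float()  # read V through the same as_strided view the kernel receives (v_base holds different values at [d, n])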
|
||||
# Q@K^T = (128,128); PV MMA: P (128,128) with B operand V (64,128), MN-major
# MN-major B = V[n,k] with n (head_dim) contiguous => MMA computes P @ V^T
|
||||
ref = qf @ kf.T @ vf.T # (128,64)
|
||||
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentityKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v24: Q@K^T + identity softmax (NamedBarrier) + P@V (V MN-major)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
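The V tensor in the test above is built with `as_strided`, which reinterprets the underlying buffer rather than transposing it. The following is a minimal pure-Python sketch of that effect, not code from the test: the helper names `offset` and `view` are hypothetical, and a 2x3 shape stands in for the 64x128 tile. It suggests that a `(head_dim, seq)` view with strides `(1, head_dim)` reads the same memory as the transpose of a contiguous `(seq, head_dim)` tensor, and does not hold the same values as the contiguous `(head_dim, seq)` tensor it was created from.

```python
def offset(index, strides):
    """Linear offset of a multi-index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


def view(buf, shape, strides):
    """Materialise a 2-D strided view of a flat buffer as nested lists."""
    rows, cols = shape
    return [[buf[offset((r, c), strides)] for c in range(cols)] for r in range(rows)]


head_dim, seq = 2, 3               # small stand-ins for 64 and 128
buf = list(range(head_dim * seq))  # flat storage shared by every view

# Contiguous (head_dim, seq), like v_base: strides (seq, 1)
v_base = view(buf, (head_dim, seq), (seq, 1))
# Same shape with strides (1, head_dim), like as_strided((head_dim, n), (1, head_dim))
v_mn = view(buf, (head_dim, seq), (1, head_dim))
# Contiguous (seq, head_dim) over the same buffer, then transposed
seq_major = view(buf, (seq, head_dim), (head_dim, 1))
seq_major_t = [list(col) for col in zip(*seq_major)]

print("v_base:", v_base)
print("v_mn:  ", v_mn)
print("v_mn == seq_major^T:", v_mn == seq_major_t)
print("v_mn == v_base:     ", v_mn == v_base)
```

Under these assumptions, a reference computed from the contiguous tensor only matches the kernel input if it is read through the same strided view.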
|
||||
380
tests/test_stage_b_v25.py
Normal file
@@ -0,0 +1,380 @@
|
||||
"""
|
||||
Stage B v25: Q@K^T + identity softmax P packing + P@V
|
||||
Uses PipelineAsync (not Umma) for mma_si sync — two separate one-shot pipelines.
|
||||
Bug 1 fix: V MN-major.
|
||||
Bug 2 fix: FP32→BF16 P packing (C-fragment composition store).
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentityKernel:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v25] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v25] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
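a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
# Q, K and V are all copied against the same ab_bar stage in the TMA warp,
# so the expected transaction byte count has to cover all three tiles.
self.num_tma_load_bytes = (
    cute.size_in_bytes(self.q_dtype, a_smem)
    + cute.size_in_bytes(self.b_dtype, b_smem)
    + cute.size_in_bytes(self.b_dtype, v_smem)
) * cute.size(qk_mma.thr_id.shape)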
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.q_dtype = q.element_type; self.b_dtype = k.element_type
|
||||
self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
print(f"[v25] a_major (Q) = {self.a_major}")
|
||||
print(f"[v25] b_major (K) = {self.b_major}")
|
||||
print(f"[v25] v_major (V) = {self.v_major}")
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
|
||||
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
|
||||
tma_c, tma_tc, self.cluster_layout_vmnk,
|
||||
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
||||
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
|
||||
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
# Two separate mbarriers for mma→softmax and softmax→mma signaling
|
||||
scores_ready_bar: cute.struct.MemRange[cutlass.Int64, 2] # MMA→softmax
|
||||
softmax_done_bar: cute.struct.MemRange[cutlass.Int64, 2] # softmax→MMA
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
# PipelineAsync for scores_ready (MMA → softmax)
|
||||
scores_ready_prod, scores_ready_cons = pipeline.PipelineAsync.create(
|
||||
barrier_storage=st.scores_ready_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
).make_participants()
|
||||
|
||||
# PipelineAsync for softmax_done (softmax → MMA)
|
||||
softmax_done_prod, softmax_done_cons = pipeline.PipelineAsync.create(
|
||||
barrier_storage=st.softmax_done_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
).make_participants()
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
|
||||
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gQ, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
|
||||
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ═══ TMA LOAD WARP ═══
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ═══ MMA WARP ═══
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
|
||||
# QK MMA
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrQ, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Signal: S ready → epilogue warps can start softmax
|
||||
scores_ready_prod.acquire_and_advance().commit()
|
||||
scores_ready_prod.tail()
|
||||
|
||||
# Wait: softmax done → can start PV MMA
|
||||
softmax_done_cons.reset()
|
||||
softmax_done_cons.wait_and_advance().release()
|
||||
|
||||
# PV MMA
|
||||
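# The O accumulator in TMEM has not been written yet, so the first k-block
# must overwrite it; later k-blocks accumulate (same pattern as the QK loop).
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
    cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
    pv_mma.set(tcgen05.Field.ACCUMULATE, True)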
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ═══ EPILOGUE WARPS ═══
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# Wait: S ready
|
||||
scores_ready_cons.reset()
|
||||
scores_ready_cons.wait_and_advance().release()
|
||||
|
||||
# Load FP32 S
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# FP32→BF16 packing
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Signal: softmax done
|
||||
softmax_done_prod.acquire_and_advance().commit()
|
||||
softmax_done_prod.tail()
|
||||
|
||||
# Output epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, head_dim = 128, 128, 64
|
||||
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
|
||||
v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
|
||||
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
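qf = q[:,:,0].float(); kf = k[:,:,0].float(); vf = v[:,:,0].float()  # read V through the same as_strided view the kernel receives (v_base holds different values at [d, n])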
|
||||
ref = qf @ kf.T @ vf.T
|
||||
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentityKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v25: PipelineAsync for mma↔softmax sync, V MN-major')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
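The `_setup` method above derives `tilePlikeFP32` and offsets the P operand iterator by a width ratio. The sketch below works through that arithmetic in plain Python for the shapes used in this test. The function names are hypothetical, and it assumes that TMEM columns are counted in 32-bit units and that the FP32 S tile spans `n_tile` columns starting at `tmem_s0_offset = 0`.

```python
FP32_BITS = 32
BF16_BITS = 16


def p_cols_as_fp32(n_tile, p_bits=BF16_BITS):
    """Number of 32-bit columns needed for n_tile packed P elements.

    Mirrors: qk_mma_tiler[1] // Float32.width * o_dtype.width
    """
    return n_tile // FP32_BITS * p_bits


def p_iterator_offset(p0_offset_cols, acc_bits=FP32_BITS, p_bits=BF16_BITS):
    """Offset of P in units of P elements.

    Mirrors: qk_acc_dtype.width // q_dtype.width * tmem_p0_offset
    """
    return acc_bits // p_bits * p0_offset_cols


n_tile = 128         # qk_mma_tiler[1]
tmem_p0_offset = 32  # in 32-bit columns, as in _setup

cols = p_cols_as_fp32(n_tile)
elem_offset = p_iterator_offset(tmem_p0_offset)

print("P width in FP32 columns:", cols)
print("P iterator offset in BF16 elements:", elem_offset)
print("P column range:", (tmem_p0_offset, tmem_p0_offset + cols))
print("P fits inside the S region:", tmem_p0_offset + cols <= n_tile)
```

With these numbers the packed P tile would occupy 32-bit columns 32 to 96, inside the region that holds S. That is why the softmax warps load S into registers before storing P.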
|
||||
367
tests/test_stage_b_v26.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Stage B v26: Q@K^T + identity softmax P packing + P@V
|
||||
Uses SMEM spin-wait flags for mma↔softmax sync (no PipelineUmmaAsync).
|
||||
Bug 1 fix: V MN-major.
|
||||
Bug 2 fix: FP32→BF16 P packing (C-fragment composition store).
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentityKernel:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v26] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v26] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.q_dtype = q.element_type; self.b_dtype = k.element_type
|
||||
self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
|
||||
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
|
||||
tma_c, tma_tc, self.cluster_layout_vmnk,
|
||||
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
||||
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
|
||||
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
# Use acc_pipe's barrier for BOTH QK→softmax and PV→epilogue signaling
|
||||
# The trick: re-acquire acc_pipe on the producer side after softmax
|
||||
# This works because PipelineUmmaAsync with 1 stage blocks producer
|
||||
# until consumer releases
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
|
||||
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gQ, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
|
||||
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ═══ TMA LOAD WARP (no tmem.wait_for_alloc!) ═══
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ═══ MMA WARP ═══
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
# --- QK MMA ---
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrQ, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# --- Signal acc_pipe: S ready (producer commit) ---
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
|
||||
# --- Wait for softmax done (producer re-acquire blocks until consumer releases) ---
|
||||
acc_prod_st2 = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st2)
|
||||
|
||||
# --- PV MMA ---
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st2)
|
||||
acc_prod_st2.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st2)
|
||||
|
||||
# ═══ EPILOGUE WARPS ═══
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# TMEM load/store setup
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# --- Wait for S ready (consumer wait) ---
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
acc_pipe.consumer_wait(acc_cons_st)
|
||||
|
||||
# --- Softmax: FP32→BF16 P packing ---
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# --- Release acc_pipe: softmax done (consumer release) ---
|
||||
acc_pipe.consumer_release(acc_cons_st)
|
||||
acc_cons_st.advance()
|
||||
|
||||
# --- Wait for PV done (consumer wait on 2nd producer commit) ---
|
||||
acc_cons_st2 = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
acc_pipe.consumer_wait(acc_cons_st2)
|
||||
|
||||
# --- Output epilogue ---
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st2 = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st2, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
acc_pipe.consumer_release(acc_cons_st2)
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
    # MN-major V: shape (head_dim, n) with strides (1, head_dim) over v_base's buffer.
    v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    # Build the reference from `v`, the view the kernel actually reads.
    # as_strided reinterprets the buffer, so v != v_base whenever head_dim != n.
    qf = q[:, :, 0].float(); kf = k[:, :, 0].float(); vf = v[:, :, 0].float()
    ref = qf @ kf.T @ vf.T

    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentityKernel(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v26: Reuse acc_pipe for 2-phase sync, V MN-major')
    print('  Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print('  {}'.format('PASS' if cos >= 0.99 else 'FAIL'))


if __name__ == '__main__':
    test()
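The test builds `v` by calling `as_strided((head_dim, n), (1, head_dim))` on the buffer of `v_base`, which was allocated as a contiguous `(head_dim, n)` tensor. The sketch below, in plain Python without torch, compares the buffer offsets that each logical index addresses under the two stride sets. The helper `offset` and all variable names are hypothetical and only illustrate the stride arithmetic; they are not part of the test file.

```python
def offset(index, strides):
    """Linear buffer offset of a multi-dimensional index under the given strides."""
    return sum(i * s for i, s in zip(index, strides))


head_dim, n = 64, 128
base_strides = (n, 1)             # contiguous (head_dim, n) tensor, like v_base
view_strides = (1, head_dim)      # strides passed to as_strided for v
reshaped_strides = (head_dim, 1)  # contiguous (n, head_dim) view of the same buffer

indices = [(d, j) for d in range(head_dim) for j in range(n)]

# Count logical positions that address a different buffer element in v than in v_base.
mismatches = sum(offset(ix, base_strides) != offset(ix, view_strides) for ix in indices)

# v[d, j] addresses the same element as the buffer reshaped to (n, head_dim) at [j, d].
same_as_reshaped_transpose = all(
    offset((d, j), view_strides) == offset((j, d), reshaped_strides) for d, j in indices
)

# The strided view stays inside the head_dim * n element buffer.
max_view_offset = max(offset(ix, view_strides) for ix in indices)

print(mismatches, same_as_reshaped_transpose, max_view_offset)
```

For these sizes the two layouts agree at only two indices, `(0, 0)` and `(63, 127)`, so `v` and `v_base` hold different logical matrices unless `head_dim == n`. In torch terms the strided view corresponds to `v_base.reshape(n, head_dim).T`.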
369
tests/test_stage_b_v27.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
Stage B v27: Based on v20 with Bug 1 fix (V MN-major).

Fixes over v20:
  1. V MN-major + pv_mma uses v_major
  2. PipelineTmaStore (not TmaStorePipeline)
  3. TMA warp does NOT call tmem.wait_for_alloc
  4. mma_si PipelineUmmaAsync without cta_layout_vmnk (like FMHA)
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[v27] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[v27] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
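a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
# The TMA warp copies Q, K and V against the same ab barrier, so the
# expected transaction byte count has to cover all three tiles.
self.num_tma_load_bytes = (
    cute.size_in_bytes(self.q_dtype, a_smem)
    + cute.size_in_bytes(self.b_dtype, b_smem)
    + cute.size_in_bytes(self.b_dtype, v_smem)
) * cute.size(qk_mma.thr_id.shape)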
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q: cute.Tensor, k: cute.Tensor, v: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.q_dtype = q.element_type; self.b_dtype = k.element_type
|
||||
self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
print(f"[v27] a_major (Q) = {self.a_major}")
|
||||
print(f"[v27] b_major (K) = {self.b_major}")
|
||||
print(f"[v27] v_major (V) = {self.v_major}")
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
|
||||
|
||||
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
|
||||
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
|
||||
tma_c, tma_tc, self.cluster_layout_vmnk,
|
||||
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
||||
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
|
||||
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
# mma_si pipeline — NO cta_layout_vmnk (matches FMHA)
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.b_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
|
||||
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gQ, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
|
||||
|
||||
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
|
||||
tCgV = pv_thr.partition_B(gV)
|
||||
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ═══ TMA LOAD WARP (no tmem.wait_for_alloc) ═══
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ═══ MMA WARP ═══
|
||||
if warp_idx == self.mma_warp_id:
|
||||
if tidx == 128: print("[MMA] before tmem.wait_for_alloc")
|
||||
tmem.wait_for_alloc()
|
||||
if tidx == 128: print("[MMA] after tmem.wait_for_alloc")
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
if tidx == 128: print("[MMA] before mma_si acquire 1")
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
if tidx == 128: print("[MMA] after mma_si_prod acquire 1")
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrQ, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
if tidx == 128: print("[MMA] after s0_handle commit")
|
||||
if tidx == 128: print("[MMA] before mma_si_prod acquire 2 (wait for softmax)")
s0_handle = mma_si_prod.acquire_and_advance()
if tidx == 128: print("[MMA] after mma_si_prod acquire 2 (wait for softmax)")
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ═══ EPILOGUE WARPS ═══
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
if tidx == 0: print("[EPI] after tmem.wait_for_alloc")
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
if tidx == 0: print("[EPI] before mma_si_cons wait")
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
if tidx == 0: print("[EPI] before si release")
|
||||
si_handle.release()
|
||||
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
    torch.manual_seed(42)
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    v_base = torch.randn(head_dim, n, dtype=torch.bfloat16, device='cuda')
    # MN-major V: shape (head_dim, n) with strides (1, head_dim) over v_base's buffer.
    v = v_base.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    # Build the reference from `v`, the view the kernel actually reads.
    # as_strided reinterprets the buffer, so v != v_base whenever head_dim != n.
    qf = q[:, :, 0].float(); kf = k[:, :, 0].float(); vf = v[:, :, 0].float()
    ref = qf @ kf.T @ vf.T

    import cutlass.torch as ct
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:, :, 0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    max_err = (out - ref).abs().max().item()
    print('Stage B v27: V MN-major + no cta_layout on mma_si')
    print('  Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
    print('  {}'.format('PASS' if cos >= 0.99 else 'FAIL'))


if __name__ == '__main__':
    test()
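In `_setup` and `_kernel`, the BF16 `P` tile is placed inside the FP32 `S` accumulator region, so its width and offset are converted between FP32 columns and BF16 elements. The sketch below reproduces that arithmetic in plain Python. It assumes 32-bit FP32 and 16-bit BF16 element widths and the 128-column tile used by this test. The helper names are hypothetical and are not CuTeDSL API.

```python
FP32_BITS = 32  # assumed value of Float32.width
BF16_BITS = 16  # assumed value of BFloat16.width


def p_tile_cols_as_fp32(n_cols, acc_bits=FP32_BITS, p_bits=BF16_BITS):
    """Mirrors `qk_mma_tiler[1] // Float32.width * o_dtype.width` from _setup."""
    return n_cols // acc_bits * p_bits


def p_offset_in_p_elems(p0_offset_cols, acc_bits=FP32_BITS, p_bits=BF16_BITS):
    """Mirrors `qk_acc_dtype.width // q_dtype.width * tmem_p0_offset` from _kernel."""
    return acc_bits // p_bits * p0_offset_cols


# 128 BF16 values of P occupy the space of 64 FP32 columns.
tile_p_like_fp32 = p_tile_cols_as_fp32(128)

# An offset of 32 FP32 columns corresponds to 64 BF16 elements.
p0_offset_elems = p_offset_in_p_elems(32)

print(tile_p_like_fp32, p0_offset_elems)
```

Both expressions depend only on the ratio of the two element widths, so the softmax store (addressed in FP32 columns) and the PV A-fragment read (addressed in BF16 elements) should agree on where `P` starts as long as both use the same `tmem_p0_offset`.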