Stage B: C-fragment vs A-fragment TMEM layout mismatch diagnosed
Key finding: C-fragment and A-fragment use different physical TMEM address mappings. St32x32bOp with C-fragment writes to C-layout addresses, but PV MMA reads from A-layout addresses. Forward FMHA recast validated FP16 only, not BF16. Working: FP32 ld/st roundtrip, BF16 elemwise, BF16 recast ld S0->st S1 (all cos 0.999999) Broken: C-frag st + A-frag read (NaN), A-frag store + PV MMA (cos -0.02) Next: Fix register data flow (128 FP16/thread load vs 64 BF16/thread store mismatch)
This commit is contained in:
170
README.md
170
README.md
@@ -18,15 +18,17 @@ cutedsl/
|
||||
└── dense_blockscaled_gemm_persistent.py # REFERENCE: Blackwell TMEM/tcgen05 GEMM
|
||||
|
||||
tests/
|
||||
├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM
|
||||
├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + identity softmax (runs, wrong output)
|
||||
├── test_stage_b_minimal.py # ✅ Stage B minimal: two MMAs, no softmax (runs, NaN expected)
|
||||
├── test_stage_b_pipeline_only.py # ✅ Stage B pipeline-only: PipelineUmmaAsync, no ld/st (runs, NaN expected)
|
||||
├── diag_tmem.py # Diagnostic: TMEM layout inspection
|
||||
├── test_stage_b_v6.py # ❌ Stage B v6 (hardcoded offsets, crashes)
|
||||
├── test_stage_a_qk.py # ❌ Stage A v1 (broken, superseded by v2)
|
||||
├── test_stage_a_minimal.py # ❌ Stage A minimal (broken, superseded by v2)
|
||||
├── test_attention_path_b200.py # Full attention path test (uses naive BF16 attn)
|
||||
├── test_stage_a_v2.py # ✅ Stage A: bare Q@K^T via tcgen05.mma → TMEM → GMEM
|
||||
├── test_stage_b_v7.py # 🔨 Stage B: two MMAs + C-fragment softmax (runs, wrong output)
|
||||
├── test_stage_b_afrag2.py # 🔨 Stage B: A-fragment store pattern (compiles, wrong output)
|
||||
├── test_tmem_pure_fp32.py # ✅ FP32 ld→st roundtrip on C-fragment: cosine 0.999999
|
||||
├── test_bf16_elemwise.py # ✅ FP32→BF16→FP32 elemwise + FP32 st: cosine 0.999999
|
||||
├── test_recast_minimal.py # ✅ BF16 recast ld S0→st S1 via C-fragment: cosine 0.999999
|
||||
├── test_bf16_recast_simple.py # ❌ BF16 recast ld/st same region (S0): zero (can't overwrite MMA output)
|
||||
├── test_tmem_copy_roundtrip.py # ❌ BF16 recast + C→A mismatch: zero
|
||||
├── test_stage_b_final.py # ❌ C-fragment st + A-fragment read: NaN (physical layout mismatch)
|
||||
├── test_afrag_roundtrip.py # ❌ A-frag st corrupts S0 (overlapping TMEM region)
|
||||
├── diag_tmem.py # Diagnostic: TMEM layout inspection
|
||||
└── ...
|
||||
```
|
||||
|
||||
@@ -46,68 +48,66 @@ Validates the full tcgen05.mma → TMEM → epilogue → GMEM path:
|
||||
- TmemAllocator for TMEM allocation/deallocation
|
||||
- utils.gemm.sm100.epilogue_tma_store for the TMEM→reg→SMEM→TMA→GMEM epilogue
|
||||
|
||||
### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20)
|
||||
### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS (May 20-21)
|
||||
|
||||
**Latest**: `tests/test_stage_b_v7.py`
|
||||
**Status**: Kernel compiles and runs without crashing. Identity softmax produces wrong output (cosine ≈ -0.02).
|
||||
**Core Problem**: The C-fragment (MMA accumulator) and A-fragment (MMA A-operand from TMEM) use **different physical TMEM address mappings** for the same logical (M,K) position. The softmax writes P via one mapping, but the PV MMA reads via the other. This produces garbage.
|
||||
|
||||
**What was fixed today:**
|
||||
#### What's Been Proven
|
||||
|
||||
1. **PipelineUmmaAsync consumer group size crash (THE bug):**
|
||||
`PipelineUmmaAsync` with `Agent.Thread` requires **thread count** (128), NOT warp count (4), for the consumer group. fmha.py uses `32 * len(softmax_warp_ids) = 128`. Using 4 caused `CUDA_ERROR_LAUNCH_FAILED` (not a deadlock — the barrier reached wrong threshold causing illegal TMEM access).
|
||||
|
||||
```python
|
||||
# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4)
|
||||
# CORRECT (matches fmha.py):
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 128)
|
||||
```
|
||||
| Test | Pattern | Result | Why |
|
||||
|------|---------|--------|-----|
|
||||
| test_tmem_pure_fp32 | FP32 ld→st, same C-fragment layout | ✅ cos=0.999999 | C-fragment addresses self-consistent |
|
||||
| test_bf16_elemwise | FP32→BF16→FP32 elemwise, C-fragment st | ✅ cos=0.999999 | BF16 conversion works, C-fragment st works |
|
||||
| test_recast_minimal | BF16 recast ld S0→st S1, C-fragment | ✅ cos=0.999999 | Recast works when writing to different region |
|
||||
| test_bf16_recast_simple | BF16 recast ld/st same region S0 | ❌ zero | Can't overwrite MMA output in same region |
|
||||
| test_stage_b_final | C-fragment st → A-fragment read (S1) | ❌ NaN | C-layout ≠ A-layout physical addresses |
|
||||
| test_stage_b_afrag2 | A-fragment st (backward FMHA pattern) | ❌ cos=-0.02 | Store + PV MMA layout compatible, but register data flow wrong |
|
||||
|
||||
2. **TMEM offset computation (no more hardcoding):**
|
||||
- `s_cols = find_tmem_tensor_col_offset(tStS) = 128` — QK accumulator physical TMEM columns
|
||||
- `o_cols = find_tmem_tensor_col_offset(tOtO) = 128` — PV accumulator physical TMEM columns
|
||||
- `tmem_s0_offset = 0, tmem_p0_offset = 32, tmem_o0_offset = 128` — matches fmha.py
|
||||
- `find_tmem_tensor_col_offset(tOrP_sliced) = 32800 = 0x8020` — 0x8000 is TMEM space tag, column offset = 32
|
||||
- Total: 256 TMEM cols (verified by `get_num_tmem_alloc_cols`)
|
||||
#### Root Cause: C-fragment vs A-fragment Physical TMEM Layout
|
||||
|
||||
3. **P fragment construction (matching fmha.py):**
|
||||
```python
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer) # A-layout from PV MMA
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator + 2 * tmem_p0_offset, tOrP.layout)
|
||||
```
|
||||
Previously used `cute.composition` on C-layout — wrong, must use PV MMA's A-layout.
|
||||
From the CUTLASS source (`mma_traits_sm100.hpp`):
|
||||
|
||||
4. **V SMEM aliasing:**
|
||||
V shares the same SMEM as K with a different layout interpretation:
|
||||
```python
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # Uses MN-major V layout
|
||||
```
|
||||
**C-fragment (MMA accumulator, FP32):**
|
||||
- Layout: `((128,128),1,1):((65536,1),0,0)` — **virtual** layout
|
||||
- Physical TMEM addresses determined by the MMA hardware's accumulator write path
|
||||
- St32x32bOp with C-fragment layout writes to C-fragment physical addresses
|
||||
|
||||
**What's still broken:**
|
||||
**A-fragment (MMA A-operand from TMEM, BF16, K-major, M=128):**
|
||||
- Layout: `((128,16),1,4):((65536,1),0,16)` — **physical** TMEM layout
|
||||
- A[m, k_inner] → `tmem[dp=m, col=base + 16*mma_k + k_inner]`
|
||||
- BK=64 = 4 K=16 MMA atoms, NOT one K=64 atom
|
||||
- The 4D fragment partition order is NOT the physical TMEM order
|
||||
|
||||
The identity softmax C→A layout transform produces garbage output (cosine ≈ -0.02). The kernel runs, Stage A (Q@K^T) gives cosine 0.999999, but the full (Q@K^T)@V pipeline is wrong. The issue is in the tcgen05.ld/st identity softmax path — either the ld/st copy atoms, the register conversion, or the A-layout write positions are incorrect.
|
||||
**The St32x32bOp with C-fragment composition writes to C-layout physical addresses. The PV MMA reads from A-layout physical addresses. These are different physical locations.**
|
||||
|
||||
**Bisection results:**
|
||||
- ✅ Stage B minimal (no pipeline, no softmax): runs, NaN (expected — no C→A transform)
|
||||
- ✅ Stage B pipeline-only (PipelineUmmaAsync, no ld/st): runs, NaN (expected)
|
||||
- 🔨 Stage B full (identity softmax): runs, cosine -0.02 (wrong — softmax transform is broken)
|
||||
- All three crash with consumer_group=4, all run with consumer_group=128
|
||||
#### Forward FMHA's Approach (FP16 Only!)
|
||||
|
||||
**TMEM layout diagnostic data:**
|
||||
Forward FMHA uses a recast pattern to pack 2×FP16 into 1×FP32 register, then St32x32bOp writes to a C-fragment composition subview. **But forward FMHA explicitly rejects BF16:**
|
||||
```python
|
||||
if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}:
|
||||
raise ValueError(in_dtype must be Float8E4M3FN or Float16)
|
||||
```
|
||||
QK accumulator C fragment:
|
||||
tStS.layout = ((128,128),1,1):((65536,1),0,0)
|
||||
cute.size = 16384, cute.cosize = 8323200
|
||||
find_tmem_tensor_col_offset = 128
|
||||
The recast softmax path is validated for FP16, NOT BF16. Our BF16 use is outside the tested path.
|
||||
|
||||
PV A-fragment (P operand):
|
||||
tOrP_sliced.layout = ((128,16),1,4):((65536,1),0,16)
|
||||
cute.size = 8192, cute.cosize = 8323136
|
||||
find_tmem_tensor_col_offset = 32800 = 0x8020 (0x8000 tag + col 32)
|
||||
```
|
||||
#### Backward FMHA's Approach (BF16 Supported)
|
||||
|
||||
Backward FMHA writes dV to TMEM using the A-fragment layout:
|
||||
1. `tdVrP_iter = cute.recast_ptr(tSTtST.iterator, dtype=self.element_dtype)` — recast C-fragment iterator to BF16
|
||||
2. `tdVrP = cute.make_tensor(tdVrP_iter, tOrP.layout)` — A-fragment layout, C-fragment base
|
||||
3. `tmem_store_atom = cute.make_copy_atom(St32x32bOp(Repetition(8)), self.element_dtype)` — BF16 store atom
|
||||
4. Quantize via `make_rmem_tensor(input.shape, element_dtype)` + `.load()/.store(v.to(element_dtype))` — true BF16 register, NOT recast
|
||||
5. Reshape: `cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape))` — match store partition shape
|
||||
|
||||
This compiles and runs for us (no crash), but the output is still wrong (cosine -0.02). The remaining issue is the **register layout mismatch**:
|
||||
- Load partition (C-fragment): 128 FP32 values per thread (full 128×128 QK tile)
|
||||
- Store partition (A-fragment): 64 BF16 values per thread (128×64 P tile for PV MMA K=64)
|
||||
- The backward FMHA uses `quantize()` + reshape, but our element counts differ because the QK tile is 128×128 while P only needs 128×64
|
||||
|
||||
#### Next Steps for Stage B
|
||||
|
||||
1. **Fix the register data flow** — properly subselect the P-relevant 64 BF16 columns from the 128 FP32 load columns, or use the backward FMHA's PdO MMA tiler (M=128, N=64) instead of (M=128, N=128)
|
||||
2. **Verify A-fragment store roundtrip** — write known BF16 values via A-fragment store, have PV MMA read them back via A-fragment, confirm the physical TMEM addresses match
|
||||
3. **Once data flow is correct, add online softmax** (Stage C)
|
||||
|
||||
### 🔨 Stage C: Online Softmax — AFTER B
|
||||
|
||||
@@ -168,9 +168,20 @@ After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast
|
||||
|
||||
## Critical APIs & Lessons
|
||||
|
||||
### PipelineUmmaAsync consumer group size — THE MAY 20 BUG
|
||||
### C-fragment ≠ A-fragment TMEM Physical Layout — THE MAY 20-21 FINDING
|
||||
|
||||
**For `Agent.Thread` groups in `PipelineUmmaAsync`: use thread count, NOT warp count.**
|
||||
**The St32x32bOp with C-fragment composition writes to C-layout physical TMEM addresses. The PV MMA reads from A-layout physical TMEM addresses. These are DIFFERENT physical locations for the same logical (M,K) position.**
|
||||
|
||||
For the softmax to work, P must be written to TMEM using the A-fragment's physical layout, not the C-fragment's. The backward FMHA does this correctly by:
|
||||
1. Creating the store destination with A-fragment layout + recast C-fragment iterator
|
||||
2. Using a BF16 St32x32bOp atom
|
||||
3. True BF16 register (not FP32 recast) via quantize() pattern
|
||||
|
||||
### Forward FMHA Recast Pattern — FP16 ONLY
|
||||
|
||||
The `cute.recast_ptr` + `.store(v.to(FP16))` pattern for packing 2×16-bit into 1×FP32 register is validated for FP16 only. BF16 is rejected in forward FMHA. The BF16 recast produces zero output when writing to the same TMEM region as the MMA output, and NaN when writing to a different region read via A-fragment.
|
||||
|
||||
### PipelineUmmaAsync consumer group size — thread count, NOT warp count
|
||||
|
||||
```python
|
||||
# WRONG (caused CUDA_ERROR_LAUNCH_FAILED):
|
||||
@@ -180,18 +191,32 @@ consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 4) # warp count
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(softmax_warp_ids)) # thread count
|
||||
```
|
||||
|
||||
This applies to ALL PipelineUmmaAsync consumers where the consumer is multiple warps. fmha.py line 671: `self.threads_per_warp * len(self.softmax0_warp_ids) = 32 * 4 = 128`.
|
||||
|
||||
**Note:** The earlier README incorrectly stated that warp count was correct. That was wrong. The `Agent.Thread` agent type measures group size in threads.
|
||||
|
||||
### TMEM offset arithmetic
|
||||
|
||||
- `find_tmem_tensor_col_offset(fragment)` — returns physical TMEM column count (with 0x8000 tag for A-fragments)
|
||||
- QK accumulator C fragment: 128 TMEM columns
|
||||
- PV A-fragment: offset 0x8020 = tag(0x8000) + col(32) — the 0x8000 is a TMEM memory-space identifier
|
||||
- P OVERLAPS S in TMEM — P is written at column 32 within the S region (C-layout columns 0..127)
|
||||
- `tOrP0 = cute.make_tensor(tOrP.iterator + acc_dtype.width // q_dtype.width * tmem_p0_offset, tOrP.layout)` — A-fragment offset scaled by dtype width ratio (F32/BF16 = 2)
|
||||
|
||||
### A-fragment iterator must use recast C-fragment pointer
|
||||
|
||||
When creating the P tensor for PV MMA's A-operand, the iterator must be the C-fragment's iterator recast to BF16:
|
||||
```python
|
||||
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
```
|
||||
Without the recast, the A-fragment addresses are computed from an FP32 pointer base, giving wrong physical TMEM addresses (illegal memory access crash).
|
||||
|
||||
### V SMEM aliasing (K and V share SMEM)
|
||||
|
||||
```python
|
||||
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
```
|
||||
|
||||
### `make_trivial_tiled_mma` has two overloads
|
||||
|
||||
```python
|
||||
@@ -204,16 +229,6 @@ make_trivial_tiled_mma(ab_dtype, a_leading_mode, b_leading_mode,
|
||||
acc_dtype, cta_group, mma_tiler_mn, a_source=SMEM)
|
||||
```
|
||||
|
||||
### V SMEM aliasing (K and V share SMEM)
|
||||
|
||||
```python
|
||||
# K and V share the same SMEM buffer, but with different layouts:
|
||||
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, b_dtype, 1)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
```
|
||||
|
||||
### Other APIs discovered from Stage A
|
||||
|
||||
1. **`cute.Tensor` API** — `cutlass_torch.from_dlpack(t).mark_layout_dynamic(leading_dim=...)`
|
||||
@@ -234,12 +249,13 @@ tCrV = pv_mma.make_fragment_B(sV)
|
||||
- **vLLM repo**: `/root/dsv4-nvfp4-workspace/vllm` (modified for Blackwell)
|
||||
- **Pseudocode**: `/root/fragile-kernel-example/README.md` — authoritative per-tile attention flow
|
||||
- **fmha.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py`
|
||||
- **fmha_bwd.py reference**: `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha_bwd.py`
|
||||
|
||||
## 4-Stage Build Plan
|
||||
|
||||
| Stage | Goal | Status |
|
||||
|-------|------|--------|
|
||||
| A | Bare Q@K^T via tcgen05.mma → TMEM → GMEM | ✅ COMPLETE |
|
||||
| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 Runs without crash, identity softmax produces wrong output |
|
||||
| B | Two MMAs + identity softmax (validates TMEM A operand, shared KV, layout transform, barrier ordering) | 🔨 A-fragment store compiles, register data flow needs fixing |
|
||||
| C | Online softmax between MMA1 and MMA2 (the hard part) | ⬜ TODO |
|
||||
| D | FP8 paged KV gather + dequant (replace BF16 TMA load) | ⬜ TODO |
|
||||
|
||||
85
tests/diag_layouts.py
Normal file
85
tests/diag_layouts.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
|
||||
a_dtype = BFloat16; b_dtype = BFloat16
|
||||
from cutlass.cute.nvgpu import OperandMajorMode
|
||||
a_major = OperandMajorMode.K
|
||||
b_major = OperandMajorMode.K
|
||||
mma_tiler_mn = (128, 128)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
a_dtype, b_dtype, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
a_dtype, b_dtype, OperandMajorMode.K, b_major, Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(mma_tiler_mn)
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
pv_mma_tiler = (*mma_tiler_mn, pv_inst_k * 4)
|
||||
|
||||
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + Float32.width // BFloat16.width * 32,
|
||||
tOrP.layout)
|
||||
|
||||
print('=== Layout diagnostics ===')
|
||||
print('tStS.layout:', tStS.layout)
|
||||
print('tStS.size:', cute.size(tStS))
|
||||
print('tStS s_cols:', find_tmem_tensor_col_offset(tStS))
|
||||
print()
|
||||
print('tOtO.layout:', tOtO.layout)
|
||||
print('tOtO.size:', cute.size(tOtO))
|
||||
print('tOtO o_cols:', find_tmem_tensor_col_offset(tOtO))
|
||||
print()
|
||||
print('tOrP.layout:', tOrP.layout)
|
||||
print('tOrP.size:', cute.size(tOrP))
|
||||
print('tOrP0.layout:', tOrP0.layout)
|
||||
print('tOrP0.size:', cute.size(tOrP0))
|
||||
print()
|
||||
|
||||
tilePlikeFP32 = 128 * 16 // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
print('tStS_P_layout:', tStS_P_layout)
|
||||
print()
|
||||
|
||||
# LOAD
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
|
||||
thr_load = tiled_tmem_load.get_slice(0)
|
||||
cS = cute.make_identity_tensor((128, 128))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
print('LOAD tTMEM_LOADcS.shape:', tTMEM_LOADcS.shape)
|
||||
print('LOAD per-thread elements:', cute.size(tTMEM_LOADcS))
|
||||
|
||||
# STORE (composition)
|
||||
tStS_P = cute.make_tensor(tStS.iterator + 32, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(0)
|
||||
tTMEM_STOREcS = thr_store.partition_S(cute.make_identity_tensor(tStS_P.shape))
|
||||
print('STORE tTMEM_STOREcS.shape:', tTMEM_STOREcS.shape)
|
||||
print('STORE per-thread elements:', cute.size(tTMEM_STOREcS))
|
||||
|
||||
# What about the tOrP0 shape for store?
|
||||
print()
|
||||
print('tOrP0.shape:', tOrP0.shape if hasattr(tOrP0, 'shape') else 'N/A')
|
||||
# tOrP0 is BF16 so we'd need a BF16 store atom - but cute.copy requires equal bit widths
|
||||
# The F32 store to a BF16 target doesn't work either
|
||||
# This is the fundamental tension
|
||||
175
tests/test_afrag_roundtrip.py
Normal file
175
tests/test_afrag_roundtrip.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Test: ld FP32 from S0, st BF16 to A-fragment layout tdVrP,
|
||||
ld BF16 back from tdVrP, epi the result.
|
||||
If this works, the A-fragment store is correct and the issue is in the PV MMA."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class AFragRoundtrip:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols # Only need S region
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
# A-fragment for pv_mma
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
tdVrP = cute.make_tensor(tOrP.iterator, tOrP.layout)
|
||||
# TMEM ld (C-fragment)
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
# TMEM st (A-fragment layout, BF16)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tStP = thr_st.partition_D(tdVrP)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
|
||||
# EPILOGUE WARPS: ld FP32 → BF16 → st A-frag → ld A-frag BF16 → FP32 → epi
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
# 1. ld FP32 from S0
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# 2. Convert FP32 → BF16
|
||||
rBf16 = cute.make_rmem_tensor(tLdcS.shape, self.q_dtype)
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rBf16[i] = rLd[i].to(self.q_dtype)
|
||||
# 3. st BF16 to A-fragment layout
|
||||
cute.copy(tiled_st, rBf16, tStP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
# 4. Store to A-frag done. Check if S0 epi still works.
|
||||
si_handle.release()
|
||||
tCtS0 = cute.make_tensor(tmem_ptr, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtS0, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42); m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = AFragRoundtrip(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('A-frag roundtrip: cos={:.6f} (expect 0.999 from Stage A)'.format(cos))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
216
tests/test_b_afrag2.py
Normal file
216
tests/test_b_afrag2.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""Stage B: Store P via A-fragment layout with recast C-fragment iterator.
|
||||
|
||||
Matching the backward FMHA pattern exactly:
|
||||
1. tOrP = pv_thr.make_fragment_A(tP)[None,None,None,0] (A-fragment layout)
|
||||
2. tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=BF16) (C-fragment base, recast to BF16)
|
||||
3. tdVrP = cute.make_tensor(tdVrP_iter + offset, tOrP.layout)
|
||||
4. make_tmem_copy(St32x32bOp(Repetition(8)), BF16, tdVrP)
|
||||
5. Store BF16 registers to tdVrP
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StageBAfrag2:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 0
|
||||
self.tmem_o0_offset = self.s_cols * 2
|
||||
self.tmem_alloc_cols = 512
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
# ── P A-fragment (backward FMHA pattern) ──
|
||||
# 1. Get A-fragment layout from pv_mma
|
||||
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
# 2. Recast C-fragment iterator to BF16 (matching backward FMHA line 962)
|
||||
tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
# 3. Create store target with A-fragment layout + recast iterator
|
||||
# The offset for P within TMEM: qk_acc_dtype.width / q_dtype.width * tmem_p0_offset
|
||||
# But since we recast to BF16, the offset should be in BF16 units
|
||||
tdVrP = cute.make_tensor(
|
||||
tdVrP_iter + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
# PV MMA's A-fragment (for reading)
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.s_cols, tOrP.layout)
|
||||
# ── TMEM LOAD from C-fragment ──
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
# ── TMEM STORE via A-fragment layout (backward FMHA pattern) ──
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tStP = thr_st.partition_D(tdVrP)
|
||||
# Source identity for store (A-fragment shape)
|
||||
cS_P = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[2]))
|
||||
tScS_P = pv_thr.partition_A(cS_P)
|
||||
tStcS = thr_st.partition_S(tScS_P)
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
print(f'[A2] tdVrP.layout: {tdVrP.layout}')
|
||||
print(f'[A2] tOrP0.layout: {tOrP0.layout}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
|
||||
# PV MMA
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
|
||||
# SOFTMAX/EPILOGUE WARPS
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
# ld FP32 from S0
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# Convert FP32 → BF16 (backward-style: true BF16 register, not recast)
|
||||
rBf16 = cute.make_rmem_tensor(tStcS.shape, self.q_dtype)
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rBf16[i] = rLd[i].to(self.q_dtype)
|
||||
# Store BF16 to TMEM via A-fragment layout
|
||||
cute.copy(tiled_st, rBf16, tStP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42); m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBAfrag2(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Stage B A-frag2 (backward FMHA pattern): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
237
tests/test_bf16_elemwise.py
Normal file
237
tests/test_bf16_elemwise.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
|
||||
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class BF16Elemwise:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols * 2
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
|
||||
|
||||
# LD and ST on same layout
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
tStS = thr_st.partition_D(tStS1)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
tStcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld → FP32 st, NO recast
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Direct copy: ld register → st register (same shape since same layout)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
# Since ld and st have the same C-fragment layout and same identity tensor,
|
||||
# the register shapes should match. Copy element by element.
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rSt[i] = rLd[i].to(self.q_dtype).to(self.qk_acc_dtype)
|
||||
|
||||
cute.copy(tiled_st, rSt, tStS)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# epi reads S1
|
||||
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = BF16Elemwise(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('BF16 elemwise ld→st copy roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
237
tests/test_bf16_pack.py
Normal file
237
tests/test_bf16_pack.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
|
||||
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class BF16PackTest:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols * 2
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
|
||||
|
||||
# LD and ST on same layout
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
tStS = thr_st.partition_D(tStS1)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
tStcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld → FP32 st, NO recast
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Direct copy: ld register → st register (same shape since same layout)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
# Since ld and st have the same C-fragment layout and same identity tensor,
|
||||
# the register shapes should match. Copy element by element.
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rSt[i] = rLd[i].to(self.q_dtype).to(self.qk_acc_dtype)
|
||||
|
||||
cute.copy(tiled_st, rSt, tStS)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# epi reads S1
|
||||
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = BF16PackTest(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('BF16 elemwise ld→st to S1: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
242
tests/test_bf16_recast_full.py
Normal file
242
tests/test_bf16_recast_full.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""Test BF16 recast pattern with FULL C-fragment layout (no subview).
|
||||
ld from S0 (full 128x128), recast BF16, st to S1 (full 128x128 at offset 128).
|
||||
Since both ld and st use the same layout, the recast should work (shapes match).
|
||||
Then epi reads S1. If this works, the recast pattern IS correct for same-layout cases.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class BF16RecastFull:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols * 2
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
|
||||
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
tStS1_dst = thr_st.partition_D(tStS1)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
tStcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# BF16 recast pattern (same layout for ld and st, shapes match)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
rSt_e = cute.make_tensor(cute.recast_ptr(rSt.iterator, dtype=self.q_dtype), rLd.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(rLd) // frg_cnt
|
||||
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
|
||||
rSt_e_frg = cute.logical_divide(rSt_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
v = rLd_frg[None, j].load()
|
||||
rSt_e_frg[None, j].store(v.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_st, rSt, tStS1_dst)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# epi reads S1
|
||||
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = BF16RecastFull(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('BF16 recast full layout (ld S0, st S1, epi S1): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
239
tests/test_bf16_recast_simple.py
Normal file
239
tests/test_bf16_recast_simple.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Simplest BF16 recast test: ld FP32 from S0, use recast to convert,
|
||||
st back to S0, then epi reads S0. Single region, no subview."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class SimpleBF16Recast:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
|
||||
# ld and st on the SAME tensor (S0), same layout
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS0)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
tStS_dst = thr_st.partition_D(tStS0)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
tStcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# BF16 recast pattern (identity, no math)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
rSt_e = cute.make_tensor(cute.recast_ptr(rSt.iterator, dtype=self.q_dtype), rLd.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(rLd) // frg_cnt
|
||||
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
|
||||
rSt_e_frg = cute.logical_divide(rSt_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
v = rLd_frg[None, j].load()
|
||||
rSt_e_frg[None, j].store(v.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_st, rSt, tStS_dst)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# epi reads S0
|
||||
tCtS0 = cute.make_tensor(tmem_ptr, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS0, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = SimpleBF16Recast(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('BF16 recast same region (ld/st S0, epi S0): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
85
tests/test_error_pattern.py
Normal file
85
tests/test_error_pattern.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
BF16 Packing Diagnostic: Run identity softmax with K=V=randn,
|
||||
compare output vs reference to identify the error pattern.
|
||||
"""
|
||||
import torch, sys
|
||||
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
|
||||
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from test_stage_b_v7 import StageBIdentitySoftmax
|
||||
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
qf = q[:,:,0].float()
|
||||
kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf # identity softmax: (Q @ K^T) @ V
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
|
||||
print(f'\nCosine: {cos:.6f}')
|
||||
print(f'Output row 0[:8]: {out[0,:8].tolist()}')
|
||||
print(f'Ref row 0[:8]: {ref[0,:8].tolist()}')
|
||||
|
||||
# Key diagnostic: compare Q@K^T stage (which we know is correct) vs PV stage
|
||||
# If Q@K^T is correct but PV is wrong, the output will show the PV error pattern
|
||||
# With identity softmax, P = Q@K^T. So output = P @ V = (Q@K^T) @ V
|
||||
# If V is being read wrong, the output will be P @ V_permuted
|
||||
|
||||
# Check: is the output a permutation of the reference?
|
||||
# Sort both and compare
|
||||
out_sorted = out.flatten().sort()[0]
|
||||
ref_sorted = ref.flatten().sort()[0]
|
||||
cos_sorted = torch.nn.functional.cosine_similarity(out_sorted.unsqueeze(0), ref_sorted.unsqueeze(0)).item()
|
||||
print(f'\nCosine after sorting: {cos_sorted:.6f}')
|
||||
print(f'(If ~1.0, output is a permutation of reference)')
|
||||
|
||||
# Check: does output match P @ V^T? (V transposed)
|
||||
ref_vt = qf @ kvf.T @ kvf.T # (Q@K^T) @ V^T
|
||||
cos_vt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_vt.flatten().unsqueeze(0)).item()
|
||||
print(f'Cosine with P @ V^T: {cos_vt:.6f}')
|
||||
|
||||
# Check: does output match P^T @ V? (P transposed)
|
||||
# P = Q@K^T, so P^T = K@Q^T
|
||||
ref_pt = kvf @ qf.T @ kvf
|
||||
cos_pt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_pt.flatten().unsqueeze(0)).item()
|
||||
print(f'Cosine with P^T @ V: {cos_pt:.6f}')
|
||||
|
||||
# Check: is output simply the Q@K^T scores (MMA2 produced identity)?
|
||||
# If MMA2 didn't run or produced P unchanged, output = P = Q@K^T
|
||||
qkt = qf @ kvf.T
|
||||
cos_qkt = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), qkt.flatten().unsqueeze(0)).item()
|
||||
print(f'Cosine with just Q@K^T: {cos_qkt:.6f}')
|
||||
|
||||
# Check: all output rows identical? (means P has identical rows, like all-ones)
|
||||
all_same = torch.allclose(out[0], out[1], atol=1e-3)
|
||||
print(f'All output rows identical: {all_same}')
|
||||
if not all_same:
|
||||
cos_r01 = torch.nn.functional.cosine_similarity(out[0].unsqueeze(0), out[1].unsqueeze(0)).item()
|
||||
print(f'Cosine between row 0 and row 1: {cos_r01:.6f}')
|
||||
|
||||
# Check: is output just V (P=I case)?
|
||||
cos_v = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), kvf.flatten().unsqueeze(0)).item()
|
||||
print(f'Cosine with V alone: {cos_v:.6f}')
|
||||
|
||||
# Print some output statistics
|
||||
print(f'\nOutput stats: min={out.min().item():.4f}, max={out.max().item():.4f}, mean={out.mean().item():.4f}')
|
||||
print(f'Ref stats: min={ref.min().item():.4f}, max={ref.max().item():.4f}, mean={ref.mean().item():.4f}')
|
||||
133
tests/test_packing_diag.py
Normal file
133
tests/test_packing_diag.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
BF16 Packing Diagnostic: Write specific F32 bit patterns to P TMEM via St32x32bOp.
|
||||
Then run PV MMA with V=identity. Output = P (as MMA reads it).
|
||||
|
||||
This reveals the BF16 packing order within F32 TMEM words.
|
||||
|
||||
Strategy:
|
||||
- Skip MMA1 entirely (no QK computation)
|
||||
- Write packed F32 values where low BF16 ≠ high BF16
|
||||
- Use K=V=identity so output[i,j] = P[i,j]
|
||||
- Compare output against both packing orders
|
||||
"""
|
||||
import torch, struct, sys
|
||||
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
|
||||
|
||||
# Use the existing diag test that writes P=all-ones, but modify to write specific patterns
|
||||
# The diag test is test_stage_b_diag.py - let me copy and modify v7 instead
|
||||
|
||||
# Actually, let me just use the EXISTING kernel (test_stage_b_v7) with:
|
||||
# 1. K = identity matrix
|
||||
# 2. The identity softmax running normally (it writes softmax(Q@K^T) as P)
|
||||
# 3. If Q is also identity, then Q@K^T = I@I = I, softmax(I) = softmax of identity
|
||||
# 4. That's not what we want either.
|
||||
|
||||
# Better: use the diag test (writes P=all-ones) with K=identity
|
||||
# Output should be all-ones * I = all-ones. That gives no new info.
|
||||
|
||||
# The REAL test: modify the kernel to write specific F32 values to the BF16 recast view
|
||||
# Instead of writing 1.0 BF16, write alternating different values
|
||||
|
||||
# Simplest approach: modify test_stage_b_diag.py to write j+1 in BF16 for each position
|
||||
# Then with V=identity, output[i,j] = j+1
|
||||
|
||||
# But modifying the kernel requires JIT changes. Let me use the existing v7 kernel
|
||||
# with a clever input choice instead.
|
||||
|
||||
# Approach: Use the identity softmax (which writes Q@K^T scores as P)
|
||||
# With Q=randn and K=identity: Q@K^T = Q (since K=I)
|
||||
# Then P = softmax(Q) and output = softmax(Q) @ V
|
||||
# With V=I: output = softmax(Q)
|
||||
# This tests the FULL pipeline but doesn't isolate packing.
|
||||
|
||||
# Let me just write the packing test kernel from scratch, minimally.
|
||||
# It only needs: TMA load V, write F32 to P TMEM, PV MMA, epilogue.
|
||||
|
||||
# Actually, the simplest isolation test:
|
||||
# Use the existing test with P=all-ones and V=K (K=randn)
|
||||
# We already know cos=0.08 for this. The 0.08 is not 0 or 1.
|
||||
# If I use V where V[j] = 1 if j even, 0 if j odd, then:
|
||||
# With correct packing: output = (number of ones in even K positions) per row
|
||||
# This doesn't help because P=all-ones means all K positions are 1.0
|
||||
|
||||
# I think the key issue is that with P=all-ones (cos=0.08), the output should be
|
||||
# EXACTLY sum(V, dim=0) for each row. Let me compare more carefully.
|
||||
|
||||
# Let me run the EXISTING P=all-ones diag test and compare the output values
|
||||
# against the reference. The PATTERN of errors will tell us about packing.
|
||||
|
||||
print("Running existing P=all-ones diag with K=V=randn")
|
||||
print("Comparing output vs reference to identify the error pattern")
|
||||
|
||||
import cutlass.cute as cute
|
||||
import cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass import Float32, BFloat16
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
kvf = kv[:,:,0].float()
|
||||
ref = torch.ones(128, 128, dtype=torch.float32, device='cuda') @ kvf
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
from test_stage_b_diag import StageBDiag
|
||||
kernel = StageBDiag(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling diag test...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
|
||||
print(f'\nCosine: {cos:.6f}')
|
||||
print(f'Output row 0[:8]: {out[0,:8].tolist()}')
|
||||
print(f'Ref row 0[:8]: {ref[0,:8].tolist()}')
|
||||
print(f'Ratio out/ref row 0[:8]: {(out[0,:8] / ref[0,:8]).tolist()}')
|
||||
|
||||
# Check if output is a scaled version of the reference
|
||||
ratio = out[0] / ref[0]
|
||||
ratio_clean = ratio[torch.isfinite(ratio)]
|
||||
print(f'Ratio mean: {ratio_clean.mean().item():.6f}, std: {ratio_clean.std().item():.6f}')
|
||||
|
||||
# Check if odd/even columns have different ratios
|
||||
ratio_even = (out[0, 0::2] / ref[0, 0::2])[torch.isfinite(out[0, 0::2] / ref[0, 0::2])]
|
||||
ratio_odd = (out[0, 1::2] / ref[0, 1::2])[torch.isfinite(out[0, 1::2] / ref[0, 1::2])]
|
||||
print(f'Even col ratio mean: {ratio_even.mean().item():.6f}')
|
||||
print(f'Odd col ratio mean: {ratio_odd.mean().item():.6f}')
|
||||
|
||||
# Check if output matches a different V interpretation
|
||||
# What if the V SMEM is being read in a different column order?
|
||||
# Compute: what if V columns were permuted?
|
||||
# With P=all-ones, output[i] = sum(V[:,j]) for each j
|
||||
# This is the column sum of V, broadcast to all rows
|
||||
# If V is read in a different order, the output would be the sum of
|
||||
# differently-ordered V columns, but still all the same sum
|
||||
# So output should be uniform across columns = sum of all V values per column
|
||||
# But the output IS uniform (all rows identical). The VALUES are wrong.
|
||||
# That means the SUM is wrong, which means the V values being read are wrong.
|
||||
|
||||
# Let me check: does the output column structure match any known transform of V?
|
||||
print(f'\nV (K) row 0[:8]: {kvf[0,:8].tolist()}')
|
||||
print(f'V col sums (reference): {ref[0,:4].tolist()}')
|
||||
print(f'Output col 0[:4]: {out[0,:4].tolist()}')
|
||||
|
||||
# What if V is being read transposed?
|
||||
# sum(V[j,:]) per column j = row sums of V
|
||||
row_sums = kvf.sum(dim=1)
|
||||
col_sums = kvf.sum(dim=0)
|
||||
print(f'V row sums[:4]: {row_sums[:4].tolist()}')
|
||||
print(f'V col sums[:4]: {col_sums[:4].tolist()}')
|
||||
119
tests/test_pair_swap.py
Normal file
119
tests/test_pair_swap.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
BF16 Pair-Swap Test: compare kernel output against reference with V rows
|
||||
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
|
||||
"""
|
||||
import torch, sys
|
||||
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
|
||||
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from test_stage_b_v7 import StageBIdentitySoftmax
|
||||
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
qf = q[:,:,0].float()
|
||||
kvf = kv[:,:,0].float()
|
||||
|
||||
# Standard reference: P @ V where P = Q @ K^T
|
||||
ref = qf @ kvf.T @ kvf
|
||||
|
||||
# Pair-swapped reference: if BF16 within each F32 are swapped,
|
||||
# MMA reads K[2j] and K[2j+1] swapped, which means V rows are swapped in pairs
|
||||
# Output = P @ V_with_row_pairs_swapped
|
||||
V_swap = kvf.clone()
|
||||
V_swap[0::2], V_swap[1::2] = kvf[1::2].clone(), kvf[0::2].clone()
|
||||
ref_swap = qf @ kvf.T @ V_swap
|
||||
|
||||
# Also try: just the P@K^T with V where every 2 consecutive rows are swapped
|
||||
# but starting from different offsets
|
||||
# Try 1-based offset: swap rows (1,2), (3,4), ...
|
||||
V_swap1 = kvf.clone()
|
||||
V_swap1[1::2], V_swap1[2::2] = kvf[2::2].clone(), kvf[1::2].clone()
|
||||
V_swap1[0] = kvf[0] # row 0 unchanged
|
||||
ref_swap1 = qf @ kvf.T @ V_swap1
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out = c[:,:,0].float()
|
||||
cos_ref = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
cos_swap = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap.flatten().unsqueeze(0)).item()
|
||||
cos_swap1 = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap1.flatten().unsqueeze(0)).item()
|
||||
|
||||
print(f'\nCosine with standard ref: {cos_ref:.6f}')
|
||||
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
|
||||
print(f'Cosine with offset-1 swapped: {cos_swap1:.6f}')
|
||||
|
||||
# Also try: what if the entire P row is shifted by some amount?
|
||||
# Or what if P columns are reordered according to the A-fragment's K partitioning?
|
||||
# The A-fragment has K laid out as (16 inner, 4 outer chunks of 16)
|
||||
# If the store writes columns sequentially but the MMA reads them in the
|
||||
# chunked order, we'd get a specific permutation
|
||||
|
||||
# The A-fragment K layout: (col, k0, k1) where col=0..15, k0=0..3, k1=0..1
|
||||
# K position = col + 16*k0 + 64*k1
|
||||
# If the store writes K sequentially (0,1,2,...,127) but the MMA reads
|
||||
# in the fragment order, the mapping would be:
|
||||
# Fragment index (col, k0, k1) -> K position (col + 16*k0 + 64*k1)
|
||||
# But this is sequential (0,1,2,...), so no permutation.
|
||||
# UNLESS the store doesn't fill the fragment's K blocks correctly.
|
||||
|
||||
# Try: what if only the FIRST 64 K values (not 128) are filled?
|
||||
# The store writes 64 F32 = 128 BF16. But what if the MMA's A-fragment
|
||||
# K dimension is 64 (not 128)? Then only 64 K values have P, rest are garbage.
|
||||
# But we already verified nblk_pv=4 and the fragment covers 128 BF16.
|
||||
|
||||
# Try: what if the 64 F32 columns in the store map to 128 BF16 K positions
|
||||
# but with the packing where BF16[K] and BF16[K+1] come from the same F32 word
|
||||
# and K is determined by the A-fragment's partition order?
|
||||
|
||||
# The A-fragment reads K in groups of 16 (inner K), then 4 outer groups, then 2 BF16 per F32
|
||||
# So K values are accessed as: [group0_bf16_0, group0_bf16_1, group1_bf16_0, group1_bf16_1, ...]
|
||||
# where group = (col, k0, k1) and bf16_0/bf16_1 are the two BF16 in each F32 word
|
||||
|
||||
# This is getting complex. Let me just check the specific permutation by
|
||||
# matching output columns to reference columns for row 0.
|
||||
|
||||
# For each output column j, find which reference column it matches
|
||||
# (using the entire row as a signature)
|
||||
print('\nTrying to identify the column permutation for row 0...')
|
||||
# Since all values might not be unique, use the dot product with a known vector
|
||||
# Actually, the simplest: for output[0, j], check if it equals ref[0, k] for any k
|
||||
# But we need a unique signature. Let's use the output of the ENTIRE row.
|
||||
|
||||
# Compare output row i with reference rows
|
||||
for i in [0, 1, 2]:
|
||||
best_cos = 0
|
||||
best_j = -1
|
||||
for j in range(128):
|
||||
c = torch.nn.functional.cosine_similarity(out[i].unsqueeze(0), ref[j].unsqueeze(0)).item()
|
||||
if c > best_cos:
|
||||
best_cos = c
|
||||
best_j = j
|
||||
print(f' Output row {i} best matches ref row {best_j} (cos={best_cos:.4f})')
|
||||
|
||||
# Also: compare output column j with reference columns
|
||||
print('\nColumn matching (output col j vs ref col k):')
|
||||
for j in [0, 1, 2, 3, 4, 5, 6, 7]:
|
||||
best_cos = 0
|
||||
best_k = -1
|
||||
for k in range(128):
|
||||
c = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item()
|
||||
if c > best_cos:
|
||||
best_cos = c
|
||||
best_k = k
|
||||
print(f' Output col {j} best matches ref col {best_k} (cos={best_cos:.4f})')
|
||||
95
tests/test_pair_swap2.py
Normal file
95
tests/test_pair_swap2.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
BF16 Pair-Swap Test: compare kernel output against reference with V rows
|
||||
swapped in even/odd pairs (simulating BF16 packing swap within F32 words).
|
||||
Also tries to identify the exact column permutation.
|
||||
"""
|
||||
import torch, sys
|
||||
sys.path.insert(0, '/root/dsv4-nvfp4-workspace/kernel/tests')
|
||||
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from test_stage_b_v7 import StageBIdentitySoftmax
|
||||
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
qf = q[:,:,0].float()
|
||||
kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
|
||||
# Pair-swapped: V rows 0↔1, 2↔3, 4↔5, ...
|
||||
V_swap = kvf.clone()
|
||||
V_swap[0::2], V_swap[1::2] = kvf[1::2].clone(), kvf[0::2].clone()
|
||||
ref_swap = qf @ kvf.T @ V_swap
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out = c[:,:,0].float()
|
||||
cos_ref = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
cos_swap = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref_swap.flatten().unsqueeze(0)).item()
|
||||
|
||||
print(f'\nCosine with standard ref: {cos_ref:.6f}')
|
||||
print(f'Cosine with pair-swapped ref: {cos_swap:.6f}')
|
||||
|
||||
# Column permutation identification
|
||||
print('\nColumn permutation (output col j matches ref col k):')
|
||||
perm = []
|
||||
for j in range(min(16, 128)):
|
||||
best_cos = 0
|
||||
best_k = -1
|
||||
for k in range(128):
|
||||
c_val = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item()
|
||||
if c_val > best_cos:
|
||||
best_cos = c_val
|
||||
best_k = k
|
||||
perm.append(best_k)
|
||||
if j < 16:
|
||||
print(f' out col {j} -> ref col {best_k} (cos={best_cos:.4f})')
|
||||
|
||||
# Check if permutation has a pattern
|
||||
print(f'\nPermutation (first 16): {perm}')
|
||||
diffs = [perm[j+1] - perm[j] for j in range(len(perm)-1)]
|
||||
print(f'Permutation diffs: {diffs}')
|
||||
|
||||
# Check: is perm[j] = j with every pair swapped? (0→1, 1→0, 2→3, 3→2, ...)
|
||||
expected_pair_swap = [j^1 for j in range(128)] # XOR with 1 swaps even/odd
|
||||
matches_pair_swap = all(perm[j] == expected_pair_swap[j] for j in range(len(perm)))
|
||||
print(f'Matches pair-swap (0↔1, 2↔3, ...): {matches_pair_swap}')
|
||||
|
||||
# Also check the full permutation for all 128 columns
|
||||
full_perm = []
|
||||
for j in range(128):
|
||||
best_cos = 0
|
||||
best_k = -1
|
||||
for k in range(128):
|
||||
c_val = torch.nn.functional.cosine_similarity(out[:, j].unsqueeze(0), ref[:, k].unsqueeze(0)).item()
|
||||
if c_val > best_cos:
|
||||
best_cos = c_val
|
||||
best_k = k
|
||||
full_perm.append(best_k)
|
||||
|
||||
print(f'\nFull permutation: {full_perm}')
|
||||
# Check if it's pair-swap for all columns
|
||||
all_pair_swap = all(full_perm[j] == (j^1) for j in range(128))
|
||||
print(f'All columns match pair-swap: {all_pair_swap}')
|
||||
|
||||
# If not pair-swap, what's the pattern?
|
||||
if not all_pair_swap:
|
||||
# Check if it's a 16-element block permutation
|
||||
for block in range(8):
|
||||
block_perm = full_perm[block*16:(block+1)*16]
|
||||
print(f' Block {block}: {block_perm}')
|
||||
237
tests/test_recast_minimal.py
Normal file
237
tests/test_recast_minimal.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
|
||||
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class RecastMinimal:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols * 2
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
|
||||
|
||||
# LD and ST on same layout
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
tStS = thr_st.partition_D(tStS1)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
tStcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld → FP32 st, NO recast
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Direct copy: ld register → st register (same shape since same layout)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
# Since ld and st have the same C-fragment layout and same identity tensor,
|
||||
# the register shapes should match. Copy element by element.
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rSt[i] = rLd[i]
|
||||
|
||||
cute.copy(tiled_st, rSt, tStS)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# epi reads S1
|
||||
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = RecastMinimal(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('BF16 recast minimal→st copy roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
210
tests/test_stage_b_afrag.py
Normal file
210
tests/test_stage_b_afrag.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Stage B: Store P via A-fragment layout, not C-fragment.
|
||||
|
||||
Key insight from Mike: the C-fragment store and A-fragment read use different
|
||||
physical TMEM address mappings. The fix is to construct the TMEM store
|
||||
using the A-fragment layout (from p_tmem_s / make_fragment_A).
|
||||
|
||||
Steps:
|
||||
1. Q@K^T → TMEM (C-fragment, offset 0)
|
||||
2. ld scores from TMEM (C-fragment)
|
||||
3. Convert FP32→BF16
|
||||
4. st P to TMEM using A-fragment layout (p_tmem_s / tOrP0)
|
||||
5. PV MMA reads from TMEM using A-fragment (same layout = same physical addresses)
|
||||
6. Epilogue writes output
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StageBAfrag:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
# TMEM layout: S0 (scores) at offset 0, P (A-fragment) at offset 32, O at offset s_cols
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = self.s_cols
|
||||
self.tmem_alloc_cols = self.s_cols + self.o_cols
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
# TMEM tensors
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
# P A-fragment (matching fmha exactly)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
# ── TMEM LOAD from C-fragment ──
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
# ── TMEM STORE to A-fragment layout ──
|
||||
# Use St16x128bOp with BF16 for the A-fragment layout
|
||||
# (St32x32bOp only works with C-fragment layout, not A-fragment)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St16x128bOp(tcgen05.copy.Repetition(16)), self.q_dtype)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tP)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tStP = thr_st.partition_D(tP)
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
|
||||
# PV MMA
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
|
||||
# SOFTMAX/EPILOGUE WARPS
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
# ld FP32 from S0
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# Convert FP32 → BF16 into a true BF16 register (backward-style)
|
||||
rBf16 = cute.make_rmem_tensor(tLdcS.shape, self.q_dtype)
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rBf16[i] = rLd[i].to(self.q_dtype)
|
||||
# Store BF16 to TMEM via A-fragment layout
|
||||
cute.copy(tiled_st, rBf16, tStP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42); m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBAfrag(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Stage B A-frag store: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
217
tests/test_stage_b_afrag2.py
Normal file
217
tests/test_stage_b_afrag2.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Stage B: Store P via A-fragment layout with recast C-fragment iterator.
|
||||
|
||||
Matching the backward FMHA pattern exactly:
|
||||
1. tOrP = pv_thr.make_fragment_A(tP)[None,None,None,0] (A-fragment layout)
|
||||
2. tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=BF16) (C-fragment base, recast to BF16)
|
||||
3. tdVrP = cute.make_tensor(tdVrP_iter + offset, tOrP.layout)
|
||||
4. make_tmem_copy(St32x32bOp(Repetition(8)), BF16, tdVrP)
|
||||
5. Store BF16 registers to tdVrP
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StageBAfrag2:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 0
|
||||
self.tmem_o0_offset = self.s_cols * 2
|
||||
self.tmem_alloc_cols = 512
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
# ── P A-fragment (backward FMHA pattern) ──
|
||||
# 1. Get A-fragment layout from pv_mma
|
||||
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
# 2. Recast C-fragment iterator to BF16 (matching backward FMHA line 962)
|
||||
tdVrP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
# 3. Create store target with A-fragment layout + recast iterator
|
||||
# The offset for P within TMEM: qk_acc_dtype.width / q_dtype.width * tmem_p0_offset
|
||||
# But since we recast to BF16, the offset should be in BF16 units
|
||||
tdVrP = cute.make_tensor(
|
||||
tdVrP_iter + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
# PV MMA's A-fragment (for reading)
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator, tOrP.layout)
|
||||
# ── TMEM LOAD from C-fragment ──
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
# ── TMEM STORE via A-fragment layout (backward FMHA pattern) ──
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tStP = thr_st.partition_D(tdVrP)
|
||||
# Source identity for store (A-fragment shape)
|
||||
cS_P = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[2]))
|
||||
tScS_P = pv_thr.partition_A(cS_P)
|
||||
tStcS = thr_st.partition_S(tScS_P)
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
print(f'[A2] tdVrP.layout: {tdVrP.layout}')
|
||||
print(f'[A2] tOrP0.layout: {tOrP0.layout}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
|
||||
# PV MMA
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
|
||||
# SOFTMAX/EPILOGUE WARPS
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
# ld FP32 from S0
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# Convert FP32 → BF16 (backward-style: true BF16 register, not recast)
|
||||
rBf16 = cute.make_rmem_tensor(tStcS.shape, self.q_dtype)
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rBf16[i] = cutlass.BFloat16(1.0)
|
||||
# Store BF16 to TMEM via A-fragment layout
|
||||
# SKIP STORE
|
||||
#cute.copy(tiled_st, rBf16, tStP)
|
||||
#cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42); m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBAfrag2(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Stage B A-frag2 (backward FMHA pattern): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
384
tests/test_stage_b_diag.py
Normal file
384
tests/test_stage_b_diag.py
Normal file
@@ -0,0 +1,384 @@
|
||||
"""
|
||||
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# 3. DIAGNOSTIC: Print per-thread element counts
|
||||
if tidx == 0:
|
||||
print(f"[DIAG] LOAD per-thread elements: {cute.size(tTMEM_LOADcS)}")
|
||||
print(f"[DIAG] STORE per-thread elements: {cute.size(tTMEM_STOREcS)}")
|
||||
|
||||
# 4. Wait for scores (MMA1 must complete before we touch TMEM)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 5. DIAGNOSTIC: Skip the load entirely. Write 1.0 to P via store atom.
|
||||
# If store atom works: output should be correlated with P@V.
|
||||
# P=all-ones means MMA2 computes sum of V columns per row.
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
# Need load rmem tensor for its layout (fmha.py pattern: recast uses load layout)
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
for j in cutlass.range(cute.size(tTMEM_STORErS_x4_e), unroll_full=True):
|
||||
tTMEM_STORErS_x4_e[j] = self.q_dtype(1.0)
|
||||
|
||||
# 6. Store P=all-ones into A-layout TMEM
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = torch.ones(128, 128, dtype=torch.float32, device="cuda") @ kvf # P=all-ones @ V
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B DIAG: P=all-ones (store atom test)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
print(' Output row 0[:8]:', out[0,:8].tolist())
|
||||
print(' Output row 1[:8]:', out[1,:8].tolist())
|
||||
print(' Output row 63[:8]:', out[63,:8].tolist())
|
||||
print(' Ref row 0[:8]:', ref[0,:8].tolist())
|
||||
print(' Output col 0[:8]:', out[:8,0].tolist())
|
||||
print(' Any zeros:', (out==0).sum().item(), 'of', out.numel())
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
207
tests/test_stage_b_final.py
Normal file
207
tests/test_stage_b_final.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Stage B final: ld FP32 from S0, BF16 recast, st to S1 (offset s_cols),
|
||||
PV MMA reads A-fragment from S1 at the appropriate offset.
|
||||
|
||||
The recast pattern WORKS when writing to a different TMEM region.
|
||||
PV MMA A-fragment reads from S1 with tmem_p0_offset adjustment.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StageBFinal:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
# S0 at offset 0, P0 (BF16 view) at offset 32 within S0, O0 at offset s_cols
|
||||
# But for the recast test, we write the FULL C-fragment to S1 (offset s_cols)
|
||||
# PV MMA A-fragment needs to read from S1. The A-fragment offset is:
|
||||
# tOrP0 starts at tStS.iterator + (F32_width/BF16_width) * tmem_p0_offset
|
||||
# For S1, we need A-fragment pointing to the start of S1.
|
||||
# A-fragment offset for column c: width_ratio * c
|
||||
# S1 starts at s_cols (128 FP32 columns). A-fragment for P at S1 start:
|
||||
# offset = (F32_width / BF16_width) * s_cols = 2 * 128 = 256
|
||||
# But tOrP layout has stride 64 for M, so base offset is in BF16 column units
|
||||
# Actually, tOrP0.iterator is the base of the A-fragment in the TMEM address space.
|
||||
# The A-fragment offset is (qk_acc_dtype.width / q_dtype.width) * tmem_p0_offset.
|
||||
# For our case, we want P at S1 start, so tmem_p0_offset should map to s_cols.
|
||||
# But tmem_p0_offset is in FP32 columns, and the A-fragment width ratio converts it.
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 0 # P starts at the beginning of S1
|
||||
self.tmem_o0_offset = self.s_cols # O region after S1
|
||||
self.tmem_alloc_cols = self.s_cols + self.o_cols # S1 + O
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2]); tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # S0: MMA output
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout) # S1: BF16 copy of scores (P)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
# P A-fragment pointing to S1 (offset s_cols)
|
||||
tP = cute.make_tensor(tStS.iterator + self.s_cols, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
# Since P is at the START of S1 (offset s_cols), no additional offset needed
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator, tOrP.layout)
|
||||
# TMEM ld/st atoms
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0); tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id)); thr_ld = tiled_ld.get_slice(sfw); thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0); tStS1_dst = thr_st.partition_D(tStS1)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])); tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS); tStcS = thr_st.partition_S(tScS)
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
|
||||
# PV MMA
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
|
||||
# SOFTMAX/EPILOGUE WARPS
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
# ld FP32 from S0
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# BF16 recast pattern (identity softmax: just FP32→BF16 packing)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
rSt_e = cute.make_tensor(cute.recast_ptr(rSt.iterator, dtype=self.q_dtype), rLd.layout)
|
||||
frg_cnt = 4; frg_tile = cute.size(rLd) // frg_cnt
|
||||
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
|
||||
rSt_e_frg = cute.logical_divide(rSt_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
v = rLd_frg[None, j].load()
|
||||
rSt_e_frg[None, j].store(v.to(self.q_dtype))
|
||||
cute.copy(tiled_st, rSt, tStS1_dst); cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42); m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBFinal(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Stage B Final (ld S0, BF16 recast, st S1, PV MMA): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
321
tests/test_stage_b_v10.py
Normal file
321
tests/test_stage_b_v10.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Stage B v10: Use FP32 ld→st on FULL (128,128) layout, not subview.
|
||||
The ld reads S0 (full C-fragment), the st writes S1 (full C-fragment at offset 128),
|
||||
then PV MMA reads S1 via A-fragment.
|
||||
|
||||
Key insight: The FP32 ld→st on the SAME layout works (cosine 0.999999).
|
||||
The BF16 recast pattern with DIFFERENT layout sizes is broken.
|
||||
For identity softmax, we need to write the data back to TMEM in a format
|
||||
that the PV MMA's A-fragment can read. Since the C-fragment and A-fragment
|
||||
both access the same physical TMEM, writing via C-fragment (store) and
|
||||
reading via A-fragment (PV MMA) should work IF the data is in the right place.
|
||||
|
||||
The pv_mma_tiler has N=128 (from mma_tiler_mn[1]) but P's K dimension = 128.
|
||||
Wait, pv_mma_tiler = (M, K, K_inner*4). For the PV MMA, A=P with K=128 (full KV dim).
|
||||
But p_tmem_s is make_smem_layout_a(pv_mma, pv_mma_tiler, BF16, 1) which gives the
|
||||
A-fragment layout for (128, 128, 64). The A-fragment covers K=64 at CTA level...
|
||||
Hmm, pv_mma_tiler[1] should be the V dimension (N in MMA terms), not K.
|
||||
|
||||
Actually, for PV MMA: P (A operand, M×K) × V (B operand, K×N) → O (C operand, M×N)
|
||||
With a_source=TMEM, the P operand's K dimension matches the CTA tile K.
|
||||
pv_mma_tiler = (128, 128, 64): M=128, N=128, K=64.
|
||||
So P is M×K = 128×64. That's the A-fragment shape.
|
||||
|
||||
So the A-fragment only reads 64 columns of TMEM (K=64 in 4 blocks of 16).
|
||||
The C-fragment for S has 128 columns. The P region starts at offset 32 (tmem_p0_offset).
|
||||
With s_cols=128 and tmem_p0_offset=32, the P data starts at column 32 and spans 64 columns
|
||||
(columns 32-95), which is within the S region (columns 0-127).
|
||||
|
||||
So the store should write to columns 32-95 of the S region. The C-fragment composition
|
||||
(128, tilePlikeFP32=64) with offset 32 does exactly this.
|
||||
|
||||
But we showed the subview store + recast doesn't work. Let me try: write the FULL
|
||||
C-fragment (128×128) at offset 128 (tStS1), and adjust the PV MMA's A-fragment
|
||||
offset so it reads from the right columns within that region.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StageBv10:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
self.o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = self.s_cols
|
||||
self.tmem_alloc_cols = self.s_cols + self.o_cols # 256
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P A-fragment
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# LD and ST copy atoms on the SAME layout (full C-fragment)
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS0) # SAME layout for ld and st
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
tStS_dst = thr_st.partition_D(tStS0) # St writes to S0 (same as MMA output)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
tStcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
# PV MMA
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# EPILOGUE WARPS: ld from S0, FP32→BF16 elementwise, st back to S0, then MMA warp uses it for PV
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld from S0
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# FP32 → BF16 → FP32 elementwise (identity softmax, no math, just dtype roundtrip)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rSt[i] = rLd[i].to(self.q_dtype).to(self.qk_acc_dtype)
|
||||
|
||||
# FP32 st back to S0 (overwrite scores with identity-softmaxed values)
|
||||
cute.copy(tiled_st, rSt, tStS_dst)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(ing_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBv10(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Stage B v10 (FP32 ld→BF16 elementwise→st, PV MMA): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
210
tests/test_stage_b_v11.py
Normal file
210
tests/test_stage_b_v11.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Stage B v11: Backward FMHA pattern exactly.
|
||||
|
||||
1. ld FP32 from S0 (C-fragment)
|
||||
2. Quantize FP32→BF16 (same register shape, .load()/.store())
|
||||
3. Reshape BF16 register to store partition shape
|
||||
4. St32x32bOp BF16 store to tdVrP (A-fragment layout, offset s_cols)
|
||||
5. PV MMA reads from tOrP0 (A-fragment, offset s_cols)
|
||||
6. Epilogue writes output
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StageBv11:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2; self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape), self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0); qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape); self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape); self.o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_o0_offset = self.s_cols * 2 # After S0 (128) + P (128 BF16 = 64 FP32, but A-frag offset uses width ratio)
|
||||
self.tmem_alloc_cols = 512 # Enough for S0 + P + O
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, LayoutEnum.from_tensor(a).mma_major_mode(), LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, LayoutEnum.from_tensor(b).mma_major_mode(), self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0)); b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id), a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id), b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc, self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()); tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]; mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]; tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)), cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner); sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0); tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
pv_thr = pv_mma.get_slice(0); pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
# ── P A-fragment (backward FMHA pattern) ──
|
||||
tP_iter = cute.recast_ptr(tStS.iterator, dtype=self.q_dtype)
|
||||
tP = cute.make_tensor(tP_iter, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
# Apply offset for P region (after S0)
|
||||
p_bf16_offset = self.qk_acc_dtype.width // self.q_dtype.width * self.s_cols
|
||||
tdVrP = cute.make_tensor(tOrP.iterator + p_bf16_offset, tOrP.layout)
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator + p_bf16_offset, tOrP.layout)
|
||||
# ── TMEM LOAD from C-fragment ──
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
# ── TMEM STORE via A-fragment layout ──
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tdVrP)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tStP = thr_st.partition_D(tdVrP)
|
||||
tdVcST = pv_thr.partition_A(cS_id) # A-operand partition of identity
|
||||
tStcS = thr_st.partition_S(tdVcST)
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek); cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc(); ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st); qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek); nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store(); s0_handle.commit()
|
||||
# PV MMA
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
acc_pipe.producer_commit(acc_prod_st); acc_prod_st.advance(); acc_pipe.producer_tail(acc_prod_st)
|
||||
# SOFTMAX/EPILOGUE WARPS (backward FMHA quantize pattern)
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols); tmem.wait_for_alloc(); tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
# 1. ld FP32 from S0
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype); cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
# 2. Quantize FP32→BF16 (backward FMHA pattern)
|
||||
rBf16 = cute.make_rmem_tensor(rLd.shape, self.q_dtype)
|
||||
frg_cnt = 4; frg_tile = cute.size(rLd) // frg_cnt
|
||||
rLd_frg = cute.logical_divide(rLd, cute.make_layout(frg_tile))
|
||||
rBf16_frg = cute.make_tensor(rBf16.iterator, rLd_frg.layout)
|
||||
for i in cutlass.range(frg_tile, unroll_full=True):
|
||||
v = rLd_frg[None, i].load()
|
||||
rBf16_frg[None, i].store(v.to(self.q_dtype))
|
||||
# 3. Reshape BF16 to store partition shape
|
||||
rBf16_reshaped = cute.make_tensor(rBf16.iterator, cute.make_layout(tStcS.shape))
|
||||
# 4. Store BF16 to TMEM via A-fragment layout
|
||||
cute.copy(tiled_st, rBf16_reshaped, tStP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail(); tmem.relinquish_alloc_permit(); tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42); m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T @ kv[:,:,0].float()
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBv11(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream); torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Stage B v11 (backward FMHA pattern): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
398
tests/test_stage_b_v11b.py
Normal file
398
tests/test_stage_b_v11b.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
Stage B v11b: Identity Softmax - store to composition(tP, (128,128))
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
# FIX: PV tiler swaps N and K from QK (fmha.py: pv_mma_tiler = (M, QK_K, QK_N))
|
||||
# P is (M, QK_N) and V is (QK_N, D), so PV MMA has K=QK_N, N=D
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline: write to PV MMA A-layout (tP) composed to 2D
|
||||
# tP has correct A-fragment addresses. Compose to (128,128) for the store atom.
|
||||
tP_2d_layout = cute.composition(tP.layout, cute.make_layout((128, 128)))
|
||||
tP_2d = cute.make_tensor(tP.iterator + self.tmem_p0_offset, tP_2d_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tP_2d)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tP_2d)
|
||||
cP_2d = cute.make_identity_tensor(tP_2d.shape)
|
||||
tTMEM_STOREcP = thr_store.partition_S(cP_2d)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. IDENTITY: F32 → BF16
|
||||
tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_STOREtP.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErP_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErP.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErP_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Store into A-layout TMEM
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErP, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v11b: store to tP composed (128,128)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
408
tests/test_stage_b_v12.py
Normal file
408
tests/test_stage_b_v12.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Stage B v12: BF16 St16x128bOp store to tP A-layout, P=all-ones
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
# FIX: PV tiler swaps N and K from QK (fmha.py: pv_mma_tiler = (M, QK_K, QK_N))
|
||||
# P is (M, QK_N) and V is (QK_N, D), so PV MMA has K=QK_N, N=D
|
||||
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline: write to A-fragment TMEM via BF16 store atom
|
||||
# tP has the A-fragment layout with dense TMEM addresses.
|
||||
# St32x32bOp is for C-fragment (sparse addressing) — WRONG for A-fragment.
|
||||
# Try BF16 store atoms that match the A-fragment partition.
|
||||
# First, create the P store target with the p0 offset applied.
|
||||
# tP is F32 but represents BF16 positions. We need a BF16 view.
|
||||
tP_bf16 = cute.make_tensor(
|
||||
cute.recast_ptr(tP.iterator, dtype=self.q_dtype),
|
||||
p_tmem_s_bf16_layout if False else tP.layout) # placeholder
|
||||
# Actually, lets try make_tmem_copy with tP directly using F32 atom first
|
||||
# If that fails, well try BF16 atoms
|
||||
tP_offset = cute.make_tensor(tP.iterator + self.tmem_p0_offset, tP.layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St16x128bOp(tcgen05.copy.Repetition(8)), self.q_dtype)
|
||||
# tP_offset is F32-typed but we need BF16 for St16x128bOp
|
||||
# Create BF16 view of tP
|
||||
tP_bf16_offset = cute.make_tensor(
|
||||
cute.recast_ptr(tP.iterator, dtype=self.q_dtype),
|
||||
tP.layout)
|
||||
tP_bf16_with_offset = cute.make_tensor(
|
||||
tP_bf16_offset.iterator + self.tmem_p0_offset, tP_bf16_offset.layout)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tP_bf16_with_offset)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tP_bf16_with_offset)
|
||||
cP = cute.make_identity_tensor(tP_bf16_with_offset.shape)
|
||||
tTMEM_STOREcP = thr_store.partition_S(cP)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. DIAGNOSTIC: Just write P=all-ones directly (skip load + F32→BF16)
|
||||
# This tests whether the BF16 store atom + tP A-layout works
|
||||
tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_STOREtP.shape, self.q_dtype)
|
||||
for j in cutlass.range(cute.size(tTMEM_STORErP), unroll_full=True):
|
||||
tTMEM_STORErP[j] = self.q_dtype(1.0)
|
||||
|
||||
# 5. Store into A-fragment TMEM
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErP, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 6. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = torch.ones(128, 128, dtype=torch.float32, device="cuda") @ kvf # P=all-ones @ V
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v12: BF16 store to A-fragment layout')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
@@ -41,6 +41,11 @@ class StageBIdentitySoftmax:
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
@@ -80,7 +85,7 @@ class StageBIdentitySoftmax:
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_p0_offset = 32 # Original
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
@@ -117,6 +122,16 @@ class StageBIdentitySoftmax:
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
# Introspect PV MMA atom
|
||||
print(f"[ATOM] PV MMA type: {type(pv_mma)}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op if hasattr(pv_mma, "op") else "no op"}")
|
||||
print(f"[ATOM] PV MMA trait: {pv_mma._trait if hasattr(pv_mma, "_trait") else "no trait"}")
|
||||
print(f"[ATOM] PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f"[ATOM] QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
# Check a_src
|
||||
print(f"[ATOM] PV MMA op.a_src: {pv_mma.op.a_src}")
|
||||
print(f"[ATOM] QK MMA op.a_src: {qk_mma.op.a_src}")
|
||||
print(f"[ATOM] PV MMA op: {pv_mma.op}")
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
@@ -205,6 +220,10 @@ class StageBIdentitySoftmax:
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
@@ -220,16 +239,73 @@ class StageBIdentitySoftmax:
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
|
||||
445
tests/test_stage_b_v7_rep128.py
Normal file
445
tests/test_stage_b_v7_rep128.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
|
||||
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
|
||||
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
|
||||
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
|
||||
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. IDENTITY: F32 → BF16
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Store into A-layout
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
445
tests/test_stage_b_v7_rep16.py
Normal file
445
tests/test_stage_b_v7_rep16.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
|
||||
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
|
||||
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
|
||||
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
|
||||
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. IDENTITY: F32 → BF16
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Store into A-layout
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
445
tests/test_stage_b_v7_rep64.py
Normal file
445
tests/test_stage_b_v7_rep64.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
|
||||
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
|
||||
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
|
||||
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
|
||||
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. IDENTITY: F32 → BF16
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Store into A-layout
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
445
tests/test_stage_b_v7_rep8.py
Normal file
445
tests/test_stage_b_v7_rep8.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Stage B v7: Two MMAs + Identity Softmax with COMPUTED TMEM offsets.
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
print(f"[StageB] qk_mma.shape_mnk = {qk_mma.shape_mnk}")
|
||||
print(f"[StageB] pv_mma.shape_mnk = {pv_mma.shape_mnk}")
|
||||
print(f"[StageB] qk_mma_tiler = {self.qk_mma_tiler}")
|
||||
print(f"[StageB] pv_mma_tiler = {self.pv_mma_tiler}")
|
||||
print(f"[StageB] qk_inst_k = {qk_inst_k}, pv_inst_k = {pv_inst_k}")
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
print(f"[DIAG] tCrV.size = {cute.size(tCrV)}")
|
||||
print(f"[DIAG] tCrA.size = {cute.size(tCrA)}")
|
||||
print(f"[DIAG] tCrB.size = {cute.size(tCrB)}")
|
||||
print(f"[DIAG] nblk_qk (tCrA mode 2) = {cute.size(tCrA, mode=[2])}")
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
print(f'[TMEM] p_tmem_s: {p_tmem_s}')
|
||||
print(f'[TMEM] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[TMEM] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
# Check SMEM layout compatibility: K (b_smem_s) vs V (v_smem_s)
|
||||
print(f'[SMEM] b_smem_s.outer: {self.b_smem_s.outer}')
|
||||
print(f'[SMEM] v_smem_s.outer: {self.v_smem_s.outer}')
|
||||
print(f'[SMEM] b_smem_s.inner: {self.b_smem_s.inner}')
|
||||
print(f'[SMEM] v_smem_s.inner: {self.v_smem_s.inner}')
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
print(f'[DIAG] tStS.layout: {tStS.layout}')
|
||||
print(f'[DIAG] tStS.size: {cute.size(tStS)}')
|
||||
print(f'[DIAG] p_tmem_s.outer: {p_tmem_s.outer}')
|
||||
print(f'[DIAG] p_tmem_s.inner: {p_tmem_s.inner}')
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# Compute nblk_pv for diagnostics
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
nblk_qk = cute.size(tCrA, mode=[2])
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
# COMPREHENSIVE LAYOUT DUMP
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
|
||||
print(f'[LAYOUT] QK C-fragment tStS.layout: {tStS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS cosize: {cute.cosize(tStS.layout)}')
|
||||
print(f'[LAYOUT] QK C-fragment tStS.size: {cute.size(tStS)}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS.layout: {tScS.layout}')
|
||||
print(f'[LAYOUT] QK C-fragment tScS cosize: {cute.cosize(tScS.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.layout: {tOrP.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP cosize: {cute.cosize(tOrP.layout)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[LAYOUT] PV A-fragment tOrP0 cosize: {cute.cosize(tOrP0.layout)}')
|
||||
print(f'[LAYOUT] tP.layout: {tP.layout}')
|
||||
print(f'[LAYOUT] tP cosize: {cute.cosize(tP.layout)}')
|
||||
print(f'[LAYOUT] tStS_P (composed) layout: {tStS_P.layout}')
|
||||
print(f'[LAYOUT] tStS_P (composed) cosize: {cute.cosize(tStS_P.layout)}')
|
||||
print(f'[LAYOUT] tScS_P (composed) layout: {tScS_P.layout}')
|
||||
print(f'[LAYOUT] tScS_P (composed) cosize: {cute.cosize(tScS_P.layout)}')
|
||||
print(f'[LAYOUT] tOtO.layout: {tOtO.layout}')
|
||||
print(f'[LAYOUT] tOtO cosize: {cute.cosize(tOtO.layout)}')
|
||||
print(f'[LAYOUT] pv_mma_tiler: {self.pv_mma_tiler}')
|
||||
print(f'[LAYOUT] qk_mma_tiler: {self.qk_mma_tiler}')
|
||||
print(f'[LAYOUT] tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# DIAGNOSTIC: Compare tP (A-layout) vs tStS_P (composition)
|
||||
tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
print(f'[DIAG] tP.layout: {tP.layout}')
|
||||
print(f'[DIAG] tP.size: {cute.size(tP)}')
|
||||
print(f'[DIAG] tP.element_type: {tP.element_type if hasattr(tP, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'[DIAG] tStS_P.size: {cute.size(tStS_P)}')
|
||||
print(f'[DIAG] tStS_P.element_type: {tStS_P.element_type if hasattr(tStS_P, 'element_type') else 'N/A'}')
|
||||
print(f'[DIAG] tilePlikeFP32: {tilePlikeFP32}')
|
||||
print(f'[DIAG] tP and tStS_P same iterator? {tP.iterator == tStS_P.iterator if hasattr(tP, 'iterator') else 'cant compare'}')
|
||||
|
||||
print(f'[DIAG] nblk_pv = {nblk_pv}, nblk_qk = {nblk_qk}')
|
||||
print(f'[DIAG] tCrV.size = {cute.size(tCrV)}')
|
||||
print(f'[DIAG] tOrP0.size = {cute.size(tOrP0)}')
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── Identity softmax: C-layout → A-layout transform ──
|
||||
# 1. LOAD pipeline (reads scores from QK C-layout)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# 2. STORE pipeline (writes P in A-layout — same as fmha.py softmax_step)
|
||||
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
|
||||
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
|
||||
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
|
||||
|
||||
# 3. Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# 4. Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# 5. IDENTITY: F32 → BF16
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErS_x4_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# 6. Store into A-layout
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# 7. Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v7: (Q @ K^T) @ V with identity softmax (computed TMEM offsets)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
403
tests/test_stage_b_v8.py
Normal file
403
tests/test_stage_b_v8.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
Stage B v8: Fix the identity softmax by using the A-fragment layout
|
||||
for the TMEM store target instead of the C-fragment composition.
|
||||
|
||||
The bug in v7: tStS_P uses composition(tStS.layout, (128, tilePlikeFP32))
|
||||
which gives layout (128,64):(65536,1) — C-fragment strides.
|
||||
But the PV MMA reads from TMEM using the A-fragment layout
|
||||
((128,16),1,4):((64,1),0,16) — physical TMEM strides.
|
||||
|
||||
For the store to be read correctly by the PV MMA, the store target
|
||||
must use the same physical TMEM addressing as the A-fragment.
|
||||
|
||||
Key insight from CUTLASS source:
|
||||
Physical TMEM for M=128, BK=64, BF16 A from TMEM (K-major):
|
||||
tmem[dp=m, col=base_col + 16*mma_k + k_inner] for mma_k in 0..3, k_inner in 0..15
|
||||
|
||||
This means: A-fragment address = 64*m + k_inner + 16*mma_k
|
||||
C-fragment address for (m, col) = ??? (virtual layout, not physical)
|
||||
|
||||
The St32x32b copy atom with tStS_P (C-composition) writes to C-layout addresses.
|
||||
The PV MMA reads from A-layout addresses. These are different physical locations.
|
||||
|
||||
Fix: Use tP (from p_tmem_s, the A-fragment source layout) as the store target
|
||||
instead of tStS_P (the C-fragment composition). This ensures the store writes
|
||||
to physical TMEM addresses that the PV MMA's A-fragment will read correctly.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmaxV8:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
# TMEM tensors
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P tensor for PV MMA A-fragment
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# ── KEY FIX: Store target uses A-fragment layout (tP) not C-fragment composition ──
|
||||
# The store must write to physical TMEM addresses that the PV MMA reads via A-fragment.
|
||||
# tP has layout ((128,16),1,4,1):((64,1),0,16,0) — the A-fragment's physical TMEM layout.
|
||||
# We need the store target at tmem_p0_offset = 32 columns into the S region.
|
||||
# tP's iterator starts at tStS.iterator (base of S region).
|
||||
# tOrP0 starts at tStS.iterator + 2*32 (scaled by F32/BF16 width ratio for the A-fragment).
|
||||
# The store target should use the SAME layout as tP but with the p0 offset applied.
|
||||
tP_store = cute.make_tensor(
|
||||
tStS.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
p_tmem_s.outer)
|
||||
|
||||
print(f'[v8] tP.layout: {tP.layout}')
|
||||
print(f'[v8] tP_store.layout: {tP_store.layout}')
|
||||
print(f'[v8] tOrP0.layout: {tOrP0.layout}')
|
||||
print(f'[v8] tStS.layout: {tStS.layout}')
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── LOAD from C-layout (reading QK scores) ──
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# ── STORE to A-layout (writing P for PV MMA) ──
|
||||
# Use tP_store (A-fragment physical layout) as the store target
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tP_store)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tP_store)
|
||||
# Identity tensor for the A-layout shape
|
||||
cS_P = cute.make_identity_tensor((128, self.qk_mma_tiler[1])) # (M, K) for A-layout
|
||||
# Wait, A-layout is (M, K) but the partition is different...
|
||||
# make_fragment_A uses p_tmem_s which is an smem_layout_a — that's (M, K, L) shape
|
||||
# tP has shape ((128,16),1,4,1) — the 4 is MMA_K, 16 is K_inner
|
||||
# The identity tensor for partition_S should match tP's logical shape
|
||||
# But tP is a TMEM tensor, not SMEM. The partition_S for the store uses tP_store
|
||||
# as the destination. We need an identity tensor for the source.
|
||||
# Actually in fmha, partition_S uses tScS_P (composed from C-fragment identity).
|
||||
# Let me try the same approach: partition the store's source side with an
|
||||
# identity tensor that matches the LOAD output (C-layout).
|
||||
|
||||
# The store takes register data (from the load) and writes to TMEM (A-layout).
|
||||
# The register data has C-layout ordering (from the load).
|
||||
# The store target has A-layout addresses (tP_store).
|
||||
# We need the source partition of the store to match the register layout.
|
||||
# Since the register layout came from the load's destination (tTMEM_LOADcS),
|
||||
# and the store's source partition should match...
|
||||
|
||||
# Actually, the correct approach from fmha:
|
||||
# thr_store.partition_S maps the source (register) identity tensor
|
||||
# thr_store.partition_D maps the destination (TMEM) tensor
|
||||
# The copy operation copies from register[partition_S] to TMEM[partition_D]
|
||||
# The register layout must match what we computed (from the load)
|
||||
|
||||
# fmha uses tScS_P (composition of C-fragment identity) for partition_S
|
||||
# But if our store target is A-layout, the partition_S should use A-layout identity
|
||||
|
||||
# Let me just try using the A-layout identity tensor
|
||||
# tP_store has shape ((128,16),1,4,1) — need identity of (128, 64) for (M, K)
|
||||
# Actually tP's outer shape is ((128,16),1,4,1):((64,1),0,16,0)
|
||||
# The logical (M, K) identity: cS_A = make_identity_tensor((128, 64))
|
||||
cS_A = cute.make_identity_tensor((128, 64))
|
||||
tScS_A = tiled_tmem_store.get_slice(sfw_idx).partition_S(cS_A)
|
||||
|
||||
# Wait for scores
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# Load from C-layout
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Identity: F32 -> BF16
|
||||
tTMEM_STORErP = cute.make_rmem_tensor(tScS_A.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErP_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErP.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
s_vec = tTMEM_LOADrS.load()
|
||||
tTMEM_STORErP_e.store(s_vec.to(self.q_dtype))
|
||||
|
||||
# Store to A-layout TMEM
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErP, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmaxV8(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v8: (Q @ K^T) @ V with identity softmax (A-layout store)')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
354
tests/test_stage_b_v8b.py
Normal file
354
tests/test_stage_b_v8b.py
Normal file
@@ -0,0 +1,354 @@
|
||||
"""
|
||||
Stage B v8b: BF16 store directly to tOrP0 (diagnostic)
|
||||
|
||||
Key fixes over v6:
|
||||
- TMEM offsets computed via find_tmem_tensor_col_offset (same API as get_num_tmem_alloc_cols)
|
||||
- P tensor constructed from p_tmem_s.outer (matching fmha.py pattern exactly)
|
||||
- tilePlikeFP32 computed from qk_mma_tiler and dtype widths
|
||||
- tmem_alloc_cols from get_num_tmem_alloc_cols with all fragments
|
||||
- JIT-time diagnostic prints of all TMEM sizes
|
||||
|
||||
Architecture (matches fmha.py exactly):
|
||||
MMA1: Q @ K^T → tmem_scores (a_source=SMEM, accumulate=False)
|
||||
Identity softmax: tcgen05.ld C-layout → F32→BF16 → tcgen05.st A-layout
|
||||
MMA2: P @ V → tmem_output (a_source=TMEM, accumulate=True)
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
|
||||
class StageBIdentitySoftmax:
|
||||
def __init__(self, mma_tiler_mn, use_2cta_instrs=False, use_tma_store=True):
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.use_2cta_instrs = use_2cta_instrs; self.use_tma_store = use_tma_store
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.epilog_sync_bar_id = 1; self.tmem_alloc_sync_bar_id = 2; self.tmem_dealloc_sync_bar_id = 3
|
||||
self.num_c_stage = 2
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1],
|
||||
self.qk_mma_tiler[2],
|
||||
)
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.a_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.b_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.b_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
|
||||
# ── COMPUTE TMEM OFFSETS USING find_tmem_tensor_col_offset ──
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
# tilePlikeFP32 for the store-side composition
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
|
||||
# ── TMEM LAYOUT (matching fmha.py) ──
|
||||
# P OVERLAPS S — after softmax, P (A-layout) is written on top of scores (C-layout)
|
||||
# in the same TMEM region. The A-layout view starts partway through the S region.
|
||||
# fmha.py: S0=0, P0=32, O0=256 (with S1=128, P1=160 double-buffered)
|
||||
# The P offset of 32 means the A-layout starts at column 32 within the S region.
|
||||
# This is because the C-layout and A-layout partition TMEM differently per-thread;
|
||||
# the first 32 C-layout columns are "dead space" in the A-layout mapping.
|
||||
#
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32 # Same as fmha.py
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
# Also compute via get_num_tmem_alloc_cols for the full allocation
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
print(f"[StageB] s_cols (QK accumulator) = {s_cols}")
|
||||
print(f"[StageB] o_cols (PV accumulator) = {o_cols}")
|
||||
print(f"[StageB] tilePlikeFP32 = {self.tilePlikeFP32}")
|
||||
print(f"[StageB] tmem_s0_offset = {self.tmem_s0_offset}")
|
||||
print(f"[StageB] tmem_p0_offset = {self.tmem_p0_offset}")
|
||||
print(f"[StageB] tmem_o0_offset = {self.tmem_o0_offset}")
|
||||
print(f"[StageB] tmem_alloc_cols (computed) = {self.tmem_alloc_cols}")
|
||||
print(f"[StageB] num_tmem_alloc_cols (via utils) = {self.num_tmem_alloc_cols}")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.a_dtype, a_smem) + cute.size_in_bytes(self.b_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
self.a_dtype = a.element_type; self.b_dtype = b.element_type; self.c_dtype = c.element_type
|
||||
self.a_major = LayoutEnum.from_tensor(a).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(b).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, self.a_major, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.a_dtype, self.b_dtype, cute.nvgpu.OperandMajorMode.K, self.b_major,
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=self.tmem_alloc_sync_bar_id,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.a_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.b_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
# V shares the same SMEM as B (same data, different layout for PV MMA)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV) # V fragment from V SMEM layout
|
||||
|
||||
# ── TMEM tensors with computed offsets (matching fmha.py pattern) ──
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# P fragment: construct from p_tmem_s layout (matching fmha.py exactly)
|
||||
# fmha.py: tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer)
|
||||
# tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0]
|
||||
# tOrP0 = cute.make_tensor(tOrP.iterator + dtype_width_ratio * tmem_p0_offset, tOrP.layout)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None, None, None, 0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# ── DIAGNOSTIC: BF16 store directly to tOrP0 ──
|
||||
# Skip C→A transform entirely. Write known BF16 values to the exact
|
||||
# TMEM addresses MMA2 will read as P (the A-fragment).
|
||||
# This tests: does writing to the A-fragment's TMEM addresses work?
|
||||
|
||||
# Create a BF16 store atom matching the A-fragment type
|
||||
tmem_store_atom_bf16 = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype)
|
||||
tiled_tmem_store_bf16 = tcgen05.make_tmem_copy(tmem_store_atom_bf16, tOrP0)
|
||||
thr_store_bf16 = tiled_tmem_store_bf16.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store_bf16.partition_D(tOrP0)
|
||||
|
||||
# Wait for MMA1 scores (MMA1 must complete before we touch TMEM)
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# Fill register buffer with 1.0 in BF16, store to A-fragment TMEM
|
||||
tTMEM_STORErP = cute.make_rmem_tensor(tTMEM_STOREtP.shape, self.q_dtype)
|
||||
for j in cutlass.range(cute.size(tTMEM_STORErP), unroll_full=True):
|
||||
tTMEM_STORErP[j] = self.q_dtype(1.0)
|
||||
|
||||
cute.copy(tiled_tmem_store_bf16, tTMEM_STORErP, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Release back to MMA warp
|
||||
si_handle.release()
|
||||
|
||||
# ── Epilogue ──
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = torch.ones(128, 128, dtype=torch.float32, device="cuda") @ kvf # P=all-ones @ V
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBIdentitySoftmax(mma_tiler_mn=(128, 128), use_2cta_instrs=False, use_tma_store=True)
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
max_err = (out - ref).abs().max().item()
|
||||
print('Stage B v8b: BF16 store to tOrP0, P=all-ones')
|
||||
print(' Cosine: {:.6f}, Max error: {:.6f}'.format(cos, max_err))
|
||||
print(' {}'.format('PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
348
tests/test_stage_b_v9.py
Normal file
348
tests/test_stage_b_v9.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""
|
||||
Stage B v9: Identity softmax matching fmha.py softmax_step EXACTLY.
|
||||
|
||||
Line by line match of the fmha softmax_step, but with identity softmax:
|
||||
- No masking
|
||||
- scale = 1 (no log2 scaling)
|
||||
- row_max = 0 (skip max, just do exp(x) = 1 for identity)
|
||||
- Actually for identity softmax, P = S (scores). So we just ld S, st as P.
|
||||
- But we need to go through the C→A layout transform properly.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StageBv9:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2
|
||||
self.acc_dtype = Float32; self.epilog_sync_bar_id = 1
|
||||
self.use_2cta_instrs = False
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = 128
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (*self.mma_tiler_mn, pv_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.c_layout = c_layout
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
self.tmem_o0_offset = s_cols # 128
|
||||
self.tmem_alloc_cols = s_cols + o_cols # 256
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, 1))
|
||||
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
cute.nvgpu.OperandMajorMode.K,
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, pv_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(
|
||||
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sV_ptr = cute.recast_ptr(sB.iterator, v_smem_s.inner)
|
||||
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_acc_shape = pv_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
# ── TMEM copy atoms (matching fmha softmax_step exactly) ──
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_tmem_load = tiled_tmem_load.get_slice(tidx % (32 * len(self.epilogue_warp_id)))
|
||||
tTMEM_LOADtS = thr_tmem_load.partition_S(tStS0)
|
||||
|
||||
# Store target: composition of C-fragment (matching fmha)
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset,
|
||||
cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32))))
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_tmem_store = tiled_tmem_store.get_slice(tidx % (32 * len(self.epilogue_warp_id)))
|
||||
tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P)
|
||||
|
||||
# Identity tensors
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tScS_P = cute.make_tensor(tScS.iterator,
|
||||
cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32))))
|
||||
|
||||
tTMEM_LOADcS = thr_tmem_load.partition_D(tScS)
|
||||
tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P)
|
||||
|
||||
print(f'[v9] tTMEM_LOADcS.shape: {tTMEM_LOADcS.shape}')
|
||||
print(f'[v9] tTMEM_STOREcS.shape: {tTMEM_STOREcS.shape}')
|
||||
print(f'[v9] LOAD size: {cute.size(tTMEM_LOADcS)}')
|
||||
print(f'[v9] STORE size: {cute.size(tTMEM_STOREcS)}')
|
||||
print(f'[v9] tilePlikeFP32: {self.tilePlikeFP32}')
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
tCrV_s = tCrV[(None, None, None, 0)]
|
||||
nblk_pv = cute.size(tOrP0, mode=[2])
|
||||
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX / EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# Load scores (matching fmha exactly)
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# Identity softmax: P = S (no transform, just copy C→A)
|
||||
# Match fmha pattern: x4 register, BF16 view, .load()/.store()
|
||||
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout,
|
||||
)
|
||||
|
||||
# For identity softmax: exp(0) = 1 for all, so P[i] = exp(S[i] - max) = 1
|
||||
# But identity means P = S. So just copy: FP32→BF16, no math.
|
||||
# Use the fmha fragment pattern
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# Epilogue
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T @ kvf
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StageBv9(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Stage B v9 (fmha-matched identity softmax): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
285
tests/test_store_verify.py
Normal file
285
tests/test_store_verify.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Test: Q@K^T → TMEM (scores), ld scores, st to P region,
|
||||
then epilogue reads P region as C-fragment (not PV MMA).
|
||||
|
||||
If the ld/st roundtrip preserves the data, the epilogue should
|
||||
output the same as Stage A (Q@K^T result).
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StoreVerify:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
self.tmem_s0_offset = 0
|
||||
self.tmem_p0_offset = 32
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.tilePlikeFP32 = self.qk_mma_tiler[1] * self.q_dtype.width // 32
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR
|
||||
self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = s_cols # Only need scores region, no O region
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
|
||||
# TMEM copy atoms for ld/st
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_tmem_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_tmem_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_tmem_load.partition_D(tScS)
|
||||
|
||||
# Store target: P region (composition of C-fragment, offset 32)
|
||||
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset,
|
||||
cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32))))
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
|
||||
thr_tmem_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_tmem_store.partition_D(tStS_P)
|
||||
tScS_P = cute.make_tensor(tScS.iterator,
|
||||
cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32))))
|
||||
tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX WARPS: ld from S, st to P, then epilogue reads P ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# Load scores from S region
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# Identity: copy FP32→BF16, matching fmha pattern
|
||||
tTMEM_STORErP_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErP_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErP_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErP_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErP_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErP_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErP_x4, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# Epilogue: read from P region (offset 32) instead of S region (offset 0)
|
||||
# This tests if the store wrote to the correct physical TMEM locations
|
||||
# that the C-fragment epilogue can read back
|
||||
tCtP_base = cute.make_tensor(tmem_ptr + self.tmem_p0_offset, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtP_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T # Just Q@K^T — the ld/st roundtrip should produce this
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StoreVerify(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Store verify (ld S → st P → epilogue P): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
272
tests/test_store_verify2.py
Normal file
272
tests/test_store_verify2.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""Store verify v2: ld S (full 128x128), st P (full 128x128 at offset 128),
|
||||
then epilogue reads P region. No subview, no composition."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class StoreVerify2:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = s_cols * 2 # Two full regions
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout) # offset 0
|
||||
tStS1 = cute.make_tensor(tStS.iterator + s_cols, tStS.layout) # offset 128 (second region)
|
||||
|
||||
# Load from S0, Store to S1 (same layout, just different offset)
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS1)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS1 = thr_store.partition_D(tStS1)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── SOFTMAX WARPS: ld from S0, st to S1, epilogue reads S1 ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# Load from S0
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
|
||||
# Identity: FP32→BF16→FP32 via recast
|
||||
tTMEM_STORErS1_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
tTMEM_STORErS1_x4_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS1_x4.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS1_x4_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS1_x4_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS1_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
# Store to S1
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS1_x4, tTMEM_STOREtS1)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# Epilogue reads from S1 (offset 128)
|
||||
tCtS1_base = cute.make_tensor(tmem_ptr + s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = StoreVerify2(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Store verify v2 (ld S0, st S1, epi S1): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
271
tests/test_tmem_copy_roundtrip.py
Normal file
271
tests/test_tmem_copy_roundtrip.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Minimal TMEM ld→st roundtrip. FP32 only, no BF16 cast.
|
||||
Uses the fmha pattern: load to register, store via make_rmem_tensor + recast + load/store."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class TMEMCopyRoundtrip:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols * 2
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
|
||||
|
||||
# Copy atoms
|
||||
tmem_load_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS0 = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
tmem_store_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS1)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtS1 = thr_store.partition_D(tStS1)
|
||||
tTMEM_STOREcS = thr_store.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ── TMA WARP ──
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# ── MMA WARP ──
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# ── EPILOGUE WARPS ──
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# Load from S0
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS0, tTMEM_LOADrS)
|
||||
|
||||
# Store to S1 using fmha recast pattern (but FP32→FP32, so identity)
|
||||
tTMEM_STORErS1 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
|
||||
# BF16 view of store register, with load layout (fp32→bf16 packing)
|
||||
tTMEM_STORErS1_e = cute.make_tensor(
|
||||
cute.recast_ptr(tTMEM_STORErS1.iterator, dtype=self.q_dtype),
|
||||
tTMEM_LOADrS.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
tTMEM_STORErS1_e_frg = cute.logical_divide(
|
||||
tTMEM_STORErS1_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
tTMEM_STORErS1_e_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, tTMEM_STORErS1, tTMEM_STOREtS1)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# Epilogue reads from S1
|
||||
tCtS1_base = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1_base, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref = qf @ kvf.T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = TMEMCopyRoundtrip(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('TMEM copy roundtrip (FP32 ld→BF16 recast→st→epi): cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
255
tests/test_tmem_fp32_roundtrip.py
Normal file
255
tests/test_tmem_fp32_roundtrip.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Minimal: Q@K^T → TMEM, ld FP32 from S0, st FP32 to S1, epi reads S1.
|
||||
NO bf16 cast at all. Pure FP32 ld→st roundtrip."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class FP32Roundtrip:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols * 2
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
|
||||
|
||||
# Single tiled copy that can both ld and st
|
||||
tmem_ld_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld_atom, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st_atom, tStS1)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
|
||||
tTMEM_LDtS = thr_ld.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LDcS = thr_ld.partition_D(tScS)
|
||||
tTMEM_STtS = thr_st.partition_D(tStS1)
|
||||
tTMEM_STcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# TMA
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
# MMA
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
# EPILOGUE WARPS
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld → register → FP32 st (NO bf16)
|
||||
rS = cute.make_rmem_tensor(tTMEM_LDcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tTMEM_LDtS, rS)
|
||||
|
||||
# FP32 → FP32, same identity tensor, same layout
|
||||
# Use the fmha x4 pattern but with FP32→BF16→FP32 packing
|
||||
rS1 = cute.make_rmem_tensor(tTMEM_STcS.shape, self.qk_acc_dtype)
|
||||
rS1_e = cute.make_tensor(cute.recast_ptr(rS1.iterator, dtype=self.q_dtype), rS.layout)
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(rS) // frg_cnt
|
||||
rS_frg = cute.logical_divide(rS, cute.make_layout(frg_tile))
|
||||
rS1_e_frg = cute.logical_divide(rS1_e, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
v = rS_frg[None, j].load()
|
||||
rS1_e_frg[None, j].store(v.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_st, rS1, tTMEM_STtS)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# epi reads S1
|
||||
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = FP32Roundtrip(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('FP32 ld→BF16 recast→st roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
168
tests/test_tmem_layout_diag.py
Normal file
168
tests/test_tmem_layout_diag.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Diagnostic: understand the C-layout vs A-layout TMEM address mapping.
|
||||
|
||||
After Stage A writes Q@K^T results to TMEM via C-fragment (tStS0),
|
||||
read the same TMEM data back using the A-fragment (tOrP0) layout
|
||||
and compare against reading it via the C-fragment layout.
|
||||
|
||||
This will tell us whether the composition-based store target
|
||||
produces the right physical TMEM addresses for the A-fragment to read.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
qf = q[:,:,0].float(); kvf = kv[:,:,0].float()
|
||||
ref_qk = qf @ kvf.T # (128, 128)
|
||||
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
|
||||
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()
|
||||
c_layout = LayoutEnum.from_tensor(mC)
|
||||
mma_tiler_mn = (128, 128)
|
||||
mma_tiler = (*mma_tiler_mn, 1)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, a_major, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
BFloat16, BFloat16, cute.nvgpu.OperandMajorMode.K, b_major,
|
||||
Float32, tcgen05.CtaGroup.ONE, mma_tiler_mn, tcgen05.OperandSource.TMEM)
|
||||
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
qk_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
|
||||
pv_mma_tiler = (*mma_tiler_mn, qk_inst_k * 4)
|
||||
|
||||
# Build the fragments
|
||||
qk_thr = qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(mma_tiler_mn)
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP = pv_thr.make_fragment_A(tP)[None, None, None, 0]
|
||||
|
||||
tilePlikeFP32 = qk_mma_tiler[1] * BFloat16.width // 32
|
||||
tStS_P = cute.make_tensor(tStS.iterator + 32, cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32))))
|
||||
|
||||
# Print the key layouts
|
||||
print(f'tStS.layout: {tStS.layout}')
|
||||
print(f'tStS.size: {cute.size(tStS)}')
|
||||
print(f'tOrP.layout: {tOrP.layout}')
|
||||
print(f'tOrP.size: {cute.size(tOrP)}')
|
||||
print(f'tStS_P.layout: {tStS_P.layout}')
|
||||
print(f'tP.layout: {tP.layout}')
|
||||
print(f'tilePlikeFP32: {tilePlikeFP32}')
|
||||
|
||||
# For the C-fragment layout ((128,128),1,1):((65536,1),0,0)
|
||||
# element at logical (m, n) maps to address 65536*m + n
|
||||
# For the A-fragment layout ((128,16),1,4):((64,1),0,16)
|
||||
# element at logical ((m_inner, k_inner), 1, mma_k) maps to address 64*m_inner + k_inner + 16*mma_k
|
||||
|
||||
# In the C-fragment, logical (m, col) = (m, col) with stride (65536, 1)
|
||||
# In the A-fragment, logical ((m, k16), 1, block4) with stride ((64, 1), 0, 16)
|
||||
# So A[m, k16, block4] -> address 64*m + k16 + 16*block4
|
||||
# C[m, col] -> address 65536*m + col
|
||||
|
||||
# The KEY question: for the same physical TMEM column c and row m,
|
||||
# what does C-layout say the logical index is, vs A-layout?
|
||||
|
||||
# C-layout: (m, col) -> addr = 65536*m + col
|
||||
# A-layout: ((m, k16), 1, blk) -> addr = 64*m + k16 + 16*blk
|
||||
|
||||
# If both refer to the same physical TMEM, then:
|
||||
# 65536*m_c + col = 64*m_a + k16 + 16*blk
|
||||
# These can only be equal if the address spaces are different (C uses a different base)
|
||||
# OR if the C-layout addresses wrap around modulo the TMEM size
|
||||
|
||||
# C-fragment cosize = 8323200 which is >> 16384 (actual data size)
|
||||
# The C-fragment uses a STRIDED layout where stride-0 (M) = 65536
|
||||
# This means the C-fragment is NOT a dense layout in TMEM address space
|
||||
# The actual TMEM addresses used by the MMA hardware follow a different mapping
|
||||
|
||||
# Let's check: for M=128, N=128 C-fragment with ((128,128),1,1):((65536,1),0,0)
|
||||
# The max address = 65536*127 + 127 = 8,321,023
|
||||
# But TMEM only has 1MB = 256 columns * 4096 rows
|
||||
# So addresses 65536*m must be modded or the layout is virtual
|
||||
|
||||
# AH. The C-fragment layout is VIRTUAL. cute.compile maps it to physical TMEM
|
||||
# addresses when the actual MMA operation runs. The strides in the fragment layout
|
||||
# are logical, not physical. The MMA hardware knows the physical column mapping.
|
||||
|
||||
# For the store, make_tmem_copy(St32x32bOp, tStS_P) creates a copy plan.
|
||||
# The copy writes to tStS_P's layout addresses. These ARE the physical TMEM addresses
|
||||
# that the MMA hardware will read from for the A-fragment.
|
||||
|
||||
# So the question becomes: is the COMPOSITION tStS_P correct?
|
||||
# tStS.layout = ((128,128),1,1):((65536,1),0,0)
|
||||
# composition with (128, 64) produces (128,64):(65536,1)
|
||||
# This takes the first 64 columns of the C-layout.
|
||||
|
||||
# tOrP.layout = ((128,16),1,4):((64,1),0,16)
|
||||
# This is the A-layout for 4 K=16 blocks.
|
||||
|
||||
# Let me compute what TMEM addresses each layout generates for the same (m, k) pair.
|
||||
# For m in [0..127], for k in [0..63]:
|
||||
# C-layout: logical (m, k) -> addr = 65536*m + k
|
||||
# A-layout: k = 16*blk + k_inner
|
||||
# logical ((m, k_inner), 1, blk) -> addr = 64*m + k_inner + 16*blk = 64*m + k
|
||||
|
||||
# So C-layout says addr = 65536*m + k
|
||||
# A-layout says addr = 64*m + k
|
||||
# These are DIFFERENT for m > 0.
|
||||
|
||||
# The C-fragment stride of 65536 is NOT the physical TMEM stride.
|
||||
# Physical TMEM for M=128 has stride 64 (one 32B row per datapath, 64 FP32 columns)
|
||||
# Wait, 128 BF16 = 256 bytes = 8 columns of 32B? No...
|
||||
|
||||
# TMEM is organized as columns of 128 rows of 32 bits each.
|
||||
# For FP32 accumulator: each column holds 128 FP32 values (one per DP lane)
|
||||
# For BF16 view: each FP32 column = 2 BF16 values packed? Or separate columns?
|
||||
|
||||
# The 64 stride in A-layout: 128 BF16 values need 64 FP32 columns (2 BF16 per FP32)
|
||||
# stride-1 in the (128,16) sub-shape: (64,1) means m has stride 64, k_inner has stride 1
|
||||
# 64 FP32 columns for 128 BF16 rows makes sense: each column holds 2 BF16, so 128/2=64 columns
|
||||
|
||||
# But wait, TMEM is 32-bit per entry per DP lane. For FP32, one column = 128 FP32 values.
|
||||
# For BF16 A-fragment, each column stores 2 BF16 (high+low), so 128 BF16 = 64 columns.
|
||||
# That matches the stride of 64 for m in the A-layout.
|
||||
|
||||
# For the C-fragment with FP32 accumulator, each column = 128 FP32 values.
|
||||
# The C-layout for (128,128) FP32: stride (65536, 1)
|
||||
# 65536 = 512 * 128. That's weird.
|
||||
# Actually 65536 = 2^16. For 128 rows in FP32, each row in a column needs... hmm.
|
||||
# The C-fragment stores FP32 values. stride-0 = 65536 means elements in the M direction
|
||||
# are 65536 addresses apart. But we have 128*128 = 16384 elements total.
|
||||
# cosize = 8323200 >> 16384. The layout is very sparse.
|
||||
|
||||
# I think the C-fragment layout addresses are NOT physical TMEM addresses.
|
||||
# The MMA hardware and the ld/st copy atoms know how to map between logical and physical.
|
||||
# The copy atom's partition_S/partition_D functions create the right physical mappings.
|
||||
|
||||
# The real question for Stage B is simpler:
|
||||
# After Q@K^T writes to tStS0 (C-fragment), and we ld the scores,
|
||||
# apply identity softmax, and st them to tStS_P (composed C-fragment)...
|
||||
# does the PV MMA read the same data from tOrP0 (A-fragment)?
|
||||
|
||||
# fmha.py proves this works for F16. The composition pattern is identical.
|
||||
# So either: (1) something about BF16 vs F16, or (2) something else in my code is wrong.
|
||||
|
||||
print('\nDiagnostic complete. Need to check if BF16 vs F16 affects the TMEM layout mapping.')
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
237
tests/test_tmem_pure_fp32.py
Normal file
237
tests/test_tmem_pure_fp32.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Absolute minimal: ld FP32 from S0, st FP32 to S1, epi reads S1.
|
||||
No recast, no BF16, no packing. Pure FP32 copy between TMEM regions."""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
|
||||
class PureFP32Copy:
|
||||
def __init__(self, mma_tiler_mn):
|
||||
self.qk_acc_dtype = Float32; self.q_dtype = BFloat16; self.o_dtype = BFloat16
|
||||
self.c_dtype = BFloat16; self.acc_dtype = Float32
|
||||
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
|
||||
self.cluster_shape_mn = (1, 1)
|
||||
self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0, 1, 2, 3)
|
||||
self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192
|
||||
self.num_c_stage = 2; self.use_2cta_instrs = False
|
||||
self.epilog_sync_bar_id = 1
|
||||
|
||||
def _setup(self, qk_mma):
|
||||
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
|
||||
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
|
||||
c_layout = LayoutEnum.ROW_MAJOR; self.c_layout = c_layout
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
|
||||
self.cta_tile_shape_mnk, False, c_layout, self.o_dtype)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, c_layout, self.epi_tile, 2)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
self.s_cols = find_tmem_tensor_col_offset(tStS)
|
||||
self.tmem_alloc_cols = self.s_cols * 2
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
self.num_tma_load_bytes = (
|
||||
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem)
|
||||
) * cute.size(qk_mma.thr_id.shape)
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, a: cute.Tensor, b: cute.Tensor, c: cute.Tensor, stream: cuda.CUstream):
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
self.q_dtype, self.q_dtype,
|
||||
LayoutEnum.from_tensor(a).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(b).mma_major_mode(),
|
||||
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
|
||||
tcgen05.OperandSource.SMEM)
|
||||
self._setup(qk_mma)
|
||||
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
|
||||
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
|
||||
tma_a, tma_ta = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
a, a_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
tma_b, tma_tb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
|
||||
b, b_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
|
||||
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
|
||||
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
|
||||
self._kernel(qk_mma, tma_a, tma_ta, tma_b, tma_tb, tma_c, tma_tc,
|
||||
self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.c_smem_s, self.epi_tile
|
||||
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, tma_a, mA, tma_b, mB, tma_c, mC, cl_vmnk,
|
||||
a_smem_s, b_smem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx, _, _ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_a); cpasync.prefetch_descriptor(tma_b); cpasync.prefetch_descriptor(tma_c)
|
||||
@cute.struct
|
||||
class SS:
|
||||
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
|
||||
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
||||
tmem_dealloc: cutlass.Int64
|
||||
holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
|
||||
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
||||
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True
|
||||
).make_participants()
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
||||
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
||||
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
||||
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, len(self.epilogue_warp_id)),
|
||||
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
|
||||
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
||||
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=False,
|
||||
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
||||
sA = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
|
||||
sB = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
||||
gA = cute.local_tile(mA, cute.slice_(self.mma_tiler, (None,0,None)), (None,None,None))
|
||||
gB = cute.local_tile(mB, cute.slice_(self.mma_tiler, (0,None,None)), (None,None,None))
|
||||
gC = cute.local_tile(mC, cute.slice_(self.mma_tiler, (None,None,0)), (None,None,None))
|
||||
k_cnt = cute.size(gA, mode=[3])
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
tCgA = qk_thr.partition_A(gA); tCgB = qk_thr.partition_B(gB); tCgC = qk_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
|
||||
tAsA, tAgA = cpasync.tma_partition(tma_a, 0, a_lay, cute.group_modes(sA,0,3), cute.group_modes(tCgA,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
|
||||
tBsB, tBgB = cpasync.tma_partition(tma_b, 0, b_lay, cute.group_modes(sB,0,3), cute.group_modes(tCgB,0,3))
|
||||
tAgA = tAgA[(None,0,None,0)]; tBgB = tBgB[(None,0,None,0)]
|
||||
tCrA = qk_mma.make_fragment_A(sA); tCrB = qk_mma.make_fragment_B(sB)
|
||||
qk_acc_shape = qk_thr.partition_shape_C(self.mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
||||
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
|
||||
tStS1 = cute.make_tensor(tStS.iterator + self.s_cols, tStS.layout)
|
||||
|
||||
# LD and ST on same layout
|
||||
tmem_ld = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tmem_st = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_ld = tcgen05.make_tmem_copy(tmem_ld, tStS0)
|
||||
tiled_st = tcgen05.make_tmem_copy(tmem_st, tStS1)
|
||||
sfw = tidx % (32 * len(self.epilogue_warp_id))
|
||||
thr_ld = tiled_ld.get_slice(sfw)
|
||||
thr_st = tiled_st.get_slice(sfw)
|
||||
tLdS = thr_ld.partition_S(tStS0)
|
||||
tStS = thr_st.partition_D(tStS1)
|
||||
cS_id = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS_id)
|
||||
tLdcS = thr_ld.partition_D(tScS)
|
||||
tStcS = thr_st.partition_S(tScS)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, 1))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
if warp_idx == self.tma_warp_id:
|
||||
ab_p.reset(); peek = ab_p.try_acquire()
|
||||
for kt in cutlass.range(k_cnt, unroll=1):
|
||||
h = ab_p.acquire_and_advance(peek)
|
||||
cute.copy(tma_a, tAgA[(None,h.count)], tAsA[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
cute.copy(tma_b, tBgB[(None,h.count)], tBsB[(None,h.index)], tma_bar_ptr=h.barrier)
|
||||
peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_p.try_acquire()
|
||||
ab_p.tail()
|
||||
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
ab_c.reset(); peek = ab_c.try_wait()
|
||||
s0_handle = mma_si_prod.acquire_and_advance()
|
||||
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_prod_st)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kt in range(k_cnt):
|
||||
h = ab_c.wait_and_advance(peek)
|
||||
nblk = cute.size(tCrA, mode=[2])
|
||||
for kb in cutlass.range(nblk, unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrA[(None,None,kb,h.index)], tCrB[(None,None,kb,h.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
h.release(); peek = cutlass.Boolean(1)
|
||||
if h.count+1<k_cnt: peek = ab_c.try_wait()
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
s0_handle.commit()
|
||||
acc_pipe.producer_commit(acc_prod_st)
|
||||
acc_prod_st.advance()
|
||||
acc_pipe.producer_tail(acc_prod_st)
|
||||
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
|
||||
si_handle = mma_si_cons.wait_and_advance()
|
||||
|
||||
# FP32 ld → FP32 st, NO recast
|
||||
rLd = cute.make_rmem_tensor(tLdcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_ld, tLdS, rLd)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Direct copy: ld register → st register (same shape since same layout)
|
||||
rSt = cute.make_rmem_tensor(tStcS.shape, self.qk_acc_dtype)
|
||||
# Since ld and st have the same C-fragment layout and same identity tensor,
|
||||
# the register shapes should match. Copy element by element.
|
||||
for i in cutlass.range(cute.size(rLd), vectorize=True):
|
||||
rSt[i] = rLd[i]
|
||||
|
||||
cute.copy(tiled_st, rSt, tStS)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
si_handle.release()
|
||||
|
||||
# epi reads S1
|
||||
tCtS1 = cute.make_tensor(tmem_ptr + self.s_cols, tCtS_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtS1, sC, tCgC,
|
||||
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
|
||||
c_pipe.producer_tail()
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
m, n, k = 128, 128, 128
|
||||
q = torch.randn(m, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
kv = torch.randn(n, k, 1, dtype=torch.bfloat16, device='cuda')
|
||||
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
|
||||
ref = q[:,:,0].float() @ kv[:,:,0].float().T
|
||||
import cutlass.torch as ct
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(kv).mark_layout_dynamic(leading_dim=ct.get_leading_dim(kv))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
kernel = PureFP32Copy(mma_tiler_mn=(128, 128))
|
||||
print('Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mC, stream)
|
||||
print('Running...', flush=True)
|
||||
compiled(mQ, mK, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
out = c[:,:,0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
|
||||
print('Pure FP32 ld→st copy roundtrip: cos={:.6f} {}'.format(cos, 'PASS' if cos >= 0.99 else 'FAIL'))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user