Fix README: multi-tile was layout bug not JIT bug, add example10, update status
This commit is contained in:
49
README.md
49
README.md
@@ -136,7 +136,7 @@ Summary
|
||||
|-------|--------|-------------|
|
||||
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
|
||||
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
|
||||
| C | ⚠️ SINGLE-TILE ONLY | Real online softmax works for n=128 (cosine 0.993-0.996). **Multi-tile (n>128) broken.** |
|
||||
| C | ⚠️ MULTI-TILE IN PROGRESS | Single-tile cos 0.999998. TMA fix: n=256 cos 0.9956. Need O rescale + pipeline cycling. |
|
||||
| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
|
||||
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
|
||||
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |
|
||||
@@ -280,35 +280,48 @@ What it does:
|
||||
|
||||
---
|
||||
|
||||
## Stage C: Online Softmax — SINGLE-TILE ONLY
|
||||
## Stage C: Online Softmax — Multi-Tile In Progress
|
||||
|
||||
### What We Have
|
||||
|
||||
**Working real softmax** for single KV tile (n=128) in `test_fmha_v3_stage_c_full.py`: cosine 0.993-0.996.
|
||||
**Multi-tile (n>128) is broken** — see blocker below.
|
||||
**Working real softmax** for single KV tile (n=128): cosine 0.999998.
|
||||
**Multi-tile TMA indexing fixed** (n=256 cosine 0.9956) — was a layout bug, NOT a JIT bug.
|
||||
**Remaining:** O rescale between tiles, pipeline state cycling for n≥384, correction warps.
|
||||
|
||||
### Multi-Tile Blocker: TMA GMEM Tile Indexing
|
||||
### Multi-Tile TMA Fix (RESOLVED — was a LAYOUT bug, not a JIT bug)
|
||||
|
||||
The TMA partition slices `tBgK`/`tVgV` with `(None, 0, None, 0)`. The free mode after slicing is the GMEM iteration dimension. A `kv_coord` variable is used to index it. **Problem: the `kv_coord` increment is not propagating to the TMA at runtime.**
|
||||
After `cpasync.tma_partition()`, the output GMEM tensor has **8 modes**, not 4:
|
||||
|
||||
**Evidence (May 22):**
|
||||
- `kv_coord = Int32(0)` + `kv_coord += 1` in `cutlass.range` loop → all multi-tile outputs identical (TMA loads from tile 0 every iteration)
|
||||
- `kv_coord = 0` (plain Python int) + `kv_coord += 1` → same broken result
|
||||
- `kv_coord = Int32(1)` hardcoded → output **changes** (TMA CAN load from tile 1, the coordinate just isn't being dynamically updated)
|
||||
- Pipeline handle `.count` also doesn't work (it's opaque pipeline state, not a GMEM coordinate)
|
||||
```
|
||||
tBgK shape: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1)
|
||||
0 1 2 3 4 5 6 7
|
||||
```
|
||||
|
||||
**Root cause:** CuTeDSL's JIT appears to constant-fold or not propagate the `kv_coord += 1` increment to the TMA descriptor at runtime. The CUTLASS reference uses the same pattern with a Python int `kv_coord` — unclear why it works there but not here (possibly different CuTeDSL version or loop structure).
|
||||
**Mode 4 is the GMEM tile dimension.** Our old pre-slice `tBgK[(None, None, 0, 0)]` only addressed 4 modes — modes 4-7 were implicitly collapsed to coordinate 0, so TMA always read tile 0. The bug looked like "JIT constant-folding" but was purely a layout error.
|
||||
|
||||
**Debug shape info:**
|
||||
- `tBgK` before slice: `(((64, 128), 1), Int32(?), Int32(?), Int32(?))` — modes 1,2,3 all dynamic
|
||||
- `tVgV` before slice: `(((64, 128), 1), 1, N, 1)` — mode 2 grows with n (confirmed GMEM iter)
|
||||
- After `(None,0,None,0)`: both become `(((64, 128), 1), N_or_Int32(?))` — 2D
|
||||
**The fix:** Do not pre-slice. Index all 8 modes explicitly in `cute.copy`, putting `kt` at mode 4:
|
||||
|
||||
```python
|
||||
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], ...)
|
||||
```
|
||||
|
||||
**Results after fix:**
|
||||
- n=128: cos 0.999998 ✅
|
||||
- n=256: cos 0.9956 ✅ (lower because no O rescale yet)
|
||||
|
||||
### Remaining for Multi-Tile
|
||||
|
||||
1. O rescale between tiles: `O *= exp2(old_max - new_max)` — needed for n=256+ to hit 0.9999
|
||||
2. Pipeline state cycling for n≥384 (3+ tiles with 2 pipeline stages)
|
||||
3. Correction warps for production (separate softmax/correction/epilogue)
|
||||
4. 12-warp layout
|
||||
|
||||
### Files
|
||||
|
||||
| File | Status | Notes |
|
||||
|------|--------|-------|
|
||||
| `test_fmha_v3_stage_c_full.py` | OK n=128 only | Working real softmax + O normalization |
|
||||
| `fmha_v3_stage_c_example10.py` | 🔨 CURRENT | 8-mode TMA, combined K+V pipeline, O rescale, final normalize |
|
||||
| `test_fmha_v3_stage_c_full.py` | OK n=128 | Working real softmax + O normalization |
|
||||
| `fmha_v3_stage_c_example1.py` | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
|
||||
| `fmha_v3_stage_c_example2.py` | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
|
||||
| `test_fmha_v3_stage_c2.py` | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
|
||||
@@ -329,7 +342,7 @@ Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Ep
|
||||
1. `vectorize=True` loops: ONLY load/store/print
|
||||
2. `.reduce(cute.ReductionOp.MAX)`: reduces ENTIRE C-fragment to scalar — global max, not per-row
|
||||
3. `cute.arch.fmax`: impure for vectorizer — use plain `range()` loop
|
||||
4. TMA cute.copy accepts pipeline state values as coordinates but NOT Python int
|
||||
4. `tBgK`/`tVgV` have 8 modes after tma_partition — mode 4 is GMEM tile dim, must index all 8 explicitly
|
||||
5. `tBgK[(None, 0, None, 0)]` hardcodes GMEM iteration to tile 0
|
||||
6. `softmax_done_bar` NamedBarrier is reusable across tiles
|
||||
|
||||
|
||||
@@ -0,0 +1,529 @@
|
||||
"""
|
||||
FMHA v3 Stage-C Multi-Tile (8-mode TMA indexing, paired-atom epilogue).
|
||||
|
||||
Three structural rules learned the hard way:
|
||||
|
||||
(A) Pipeline handle's `.count` is NOT a GMEM tile coordinate. Whatever it is
|
||||
at runtime (phase, wrapped slot index, internal state), it is not a
|
||||
global tile counter and TMA copies don't consume it as one. Use the
|
||||
loop induction variable for GMEM, handle.index for SMEM.
|
||||
|
||||
(B) Hand-constructed TMEM load/store atoms (Ld32x32bOp + St32x32bOp built
|
||||
independently) preserve register tile shape across a round-trip only if
|
||||
they share the same Repetition count. Pair-matching also via
|
||||
`utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
|
||||
the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
|
||||
|
||||
(C) tma_partition produces an 8-mode tensor, not a 4-mode one. After
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay,
|
||||
group_modes(sK,0,3),
|
||||
group_modes(tCgK,0,3))
|
||||
`tBgK` shape is (1, 1, 1, 1, n_kv_tiles, 1, 1, 1). Mode 4 is the
|
||||
GMEM-tile iteration axis. Pre-slicing with `tBgK[(None,None,0,0)]`
|
||||
addresses only 4 modes — the remaining modes (including the KV-tile
|
||||
axis at mode 4) get implicitly collapsed to coord 0, so every TMA copy
|
||||
reads tile 0 regardless of what's passed at index 1. The bug pretends
|
||||
to be a JIT issue: dynamic coords seem to be "constant-folded" because
|
||||
the only axis they could vary along has stride 0.
|
||||
|
||||
Fix: do not pre-slice. Index all 8 modes explicitly in the producer's
|
||||
`cute.copy`, putting `kt` at mode 4 and `None` (or 0) everywhere else.
|
||||
|
||||
Kernel structure:
|
||||
|
||||
1. Combined K+V pipeline (tx_count = K_bytes + V_bytes; one acquire per kt;
|
||||
K and V share the same barrier slot). SMEM slot via kvh.index, GMEM via
|
||||
the loop's Python int kt (producer is fully unrolled at trace time via
|
||||
cutlass.range_constexpr, since self.n_kv_tiles is known from __init__).
|
||||
|
||||
2. Reference-style scaled epilogue: TMEM correction_rescale (O *= 1/row_sum
|
||||
via paired Ld32x32b + St32x32b atoms), then standard epilogue_tma_store
|
||||
to send O from TMEM through SMEM to GMEM. No TMEM round-trip with
|
||||
mismatched atoms.
|
||||
|
||||
3. Per-tile O rescale (O *= exp2(old_max - new_max) before PV[kt]) lives in
|
||||
the softmax warp BEFORE softmax_done_bar.arrive(). Reuses the same
|
||||
paired-atom pattern as the final normalize.
|
||||
|
||||
4. final_o_bar (32 MMA + 128 softmax threads). MMA arrives between
|
||||
acc_pipe.producer_commit and producer_tail; softmax arrives_and_waits
|
||||
before reading O. Order: producer_commit → final_o_bar.arrive() →
|
||||
producer_tail (reverse deadlocks).
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
import math
|
||||
|
||||
HEAD_DIM = 64
|
||||
|
||||
|
||||
class FmhaV3StageCMulti:
|
||||
def __init__(self, s_k=128, scale_softmax=None):
|
||||
# s_k MUST equal actual sequence length n.
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
|
||||
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
|
||||
self.threads_per_cta = 192; self.num_c_stage = 2
|
||||
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
|
||||
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(HEAD_DIM)
|
||||
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
|
||||
self.c_layout = LayoutEnum.ROW_MAJOR
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
|
||||
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
|
||||
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
|
||||
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
|
||||
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
|
||||
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
p_end = self.tmem_p0_offset + p_cols_fp32
|
||||
s_cols = self.qk_mma_tiler[1]
|
||||
o_after = max(s_cols, p_end)
|
||||
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
total = self.tmem_o0_offset + o_cols
|
||||
self.num_tmem_alloc_cols = 1
|
||||
while self.num_tmem_alloc_cols < total:
|
||||
self.num_tmem_alloc_cols *= 2
|
||||
cta = cute.size(qk_mma.thr_id.shape)
|
||||
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
|
||||
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
|
||||
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
|
||||
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
|
||||
# Combined barrier: tx_count covers BOTH K and V transfers per acquire.
|
||||
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
|
||||
cute.size_in_bytes(self.q_dtype, v_s)) * cta
|
||||
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
v_fmha = cute.make_tensor(
|
||||
v.iterator,
|
||||
cute.make_layout(
|
||||
(HEAD_DIM, self.s_k, 1),
|
||||
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
|
||||
),
|
||||
)
|
||||
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
|
||||
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
|
||||
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
|
||||
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
|
||||
epi_s = cute.select(self.c_smem_s,mode=[0,1])
|
||||
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx,_,_ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
|
||||
|
||||
@cute.struct
|
||||
class SS:
|
||||
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
|
||||
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
|
||||
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
||||
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
|
||||
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
|
||||
smem = utils.SmemAllocator(); st = smem.allocate(SS)
|
||||
|
||||
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
|
||||
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
|
||||
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
|
||||
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
|
||||
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
|
||||
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
||||
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
|
||||
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
|
||||
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
|
||||
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
|
||||
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
|
||||
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
|
||||
n_kv_tiles = cute.size(gK, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
|
||||
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
|
||||
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
|
||||
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
|
||||
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# NOTE: after tma_partition, ALL three tensors (tAgQ, tBgK, tVgV)
|
||||
# have 8 modes, e.g. tBgK shape = (1, 1, 1, 1, n_kv_tiles, 1, 1, 1).
|
||||
# Mode 4 is the GMEM tile-iteration axis; all other modes are size 1.
|
||||
# We previously pre-sliced like `tBgK[(None,None,0,0)]` which only
|
||||
# addressed 4 modes — modes 4..7 got swept up into the trailing 0
|
||||
# and the KV-tile axis was effectively collapsed to tile 0 always.
|
||||
# That, not a JIT bug, was why every dynamic coord produced tile-0
|
||||
# data. With no pre-slice, we index all 8 modes explicitly in the
|
||||
# producer warp below, putting `kt` at mode 4 and `None` elsewhere.
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
||||
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
|
||||
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp — fully unrolled =====
|
||||
# Why fully unrolled: original suspicion was a JIT bug propagating
|
||||
# dynamic coords; actual root cause was a layout bug — the pre-slice
|
||||
# collapsed all the mode-4 GMEM-tile axis to 0. After tma_partition,
|
||||
# tBgK and tVgV have 8 modes: (1, 1, 1, 1, n_kv_tiles, 1, 1, 1) where
|
||||
# mode 4 is the KV-tile iteration axis.
|
||||
#
|
||||
# Indexing rule: pass `None` for every mode that's size 1 (passthrough),
|
||||
# and `kt` for the KV-tile axis at mode 4.
|
||||
#
|
||||
# We keep the unroll for correctness confidence. The pipeline's
|
||||
# acquire/release machinery still tracks the kv_stage ring buffer
|
||||
# dynamically at runtime, so the producer correctly blocks on
|
||||
# consumer release when n_kv_tiles > kv_stage. The unroll only
|
||||
# flattens the LOOP control flow, not the synchronization.
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
# Q's 4-mode indexing was confirmed working at n=128 cos 0.999998
|
||||
# before the multi-tile investigation; leave it alone. Only K/V's
|
||||
# tma_partition output is 8 modes (the KV-tile axis introduces
|
||||
# the extra modes — Q has no equivalent tile axis since there's
|
||||
# one Q tile per CTA).
|
||||
cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset()
|
||||
for kt in cutlass.range_constexpr(self.n_kv_tiles):
|
||||
kvh = kvp.acquire_and_advance()
|
||||
# 8-mode indices, mode 4 = KV-tile axis (size n_kv_tiles).
|
||||
# Every other mode is size 1, so None is fine.
|
||||
cute.copy(
|
||||
tma_k,
|
||||
tBgK[None, None, None, None, kt, None, None, None],
|
||||
tBsK[(None, kvh.index)],
|
||||
tma_bar_ptr=kvh.barrier,
|
||||
)
|
||||
cute.copy(
|
||||
tma_v,
|
||||
tVgV[None, None, None, None, kt, None, None, None],
|
||||
tVsV[(None, kvh.index)],
|
||||
tma_bar_ptr=kvh.barrier,
|
||||
)
|
||||
kvp.tail()
|
||||
|
||||
# ===== MMA warp =====
|
||||
# Outer kt loop unrolled to match the producer. The earlier hypothesis
|
||||
# was that CuTeDSL 4.5.1 couldn't propagate dynamic TMA coords; the
|
||||
# actual root cause turned out to be the producer's GMEM-tensor
|
||||
# pre-slice eating the mode-4 KV-tile axis. The unroll is kept here
|
||||
# for symmetry with the producer (single concern: one fewer thing
|
||||
# that could surprise us if the layout assumption shifts), but the
|
||||
# GMEM indexing fix is what actually makes multi-tile work.
|
||||
# Inner GEMM K-block loops stay dynamic. kvh.index correctly tracks
|
||||
# the SMEM ring buffer at runtime.
|
||||
if warp_idx == self.mma_warp_id:
|
||||
tmem.wait_for_alloc()
|
||||
qc.reset(); qh = qc.wait_and_advance(); qh.release()
|
||||
kvc.reset()
|
||||
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
||||
acc_pipe.producer_acquire(acc_st)
|
||||
for kt in cutlass.range_constexpr(self.n_kv_tiles):
|
||||
kvh = kvc.wait_and_advance()
|
||||
sh = s_prod.acquire_and_advance()
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
|
||||
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
|
||||
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
|
||||
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
sh.commit()
|
||||
softmax_done_bar.arrive_and_wait()
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
kvh.release()
|
||||
acc_pipe.producer_commit(acc_st); acc_st.advance()
|
||||
final_o_bar.arrive()
|
||||
acc_pipe.producer_tail(acc_st)
|
||||
|
||||
# ===== SOFTMAX + EPILOGUE warps =====
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
||||
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
|
||||
|
||||
# S load
|
||||
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS0)
|
||||
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
# P store
|
||||
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
|
||||
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
|
||||
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
|
||||
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
|
||||
thr_store = tiled_tmem_store.get_slice(sfw_idx)
|
||||
tTMEM_STOREtP = thr_store.partition_D(tStP0)
|
||||
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# === O rescale path setup (used per-tile AND for final normalize) ===
|
||||
corr_tile_size = 16
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
|
||||
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
tmem_load_o_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
|
||||
self.acc_dtype,
|
||||
)
|
||||
tmem_store_o_atom = cute.make_copy_atom(
|
||||
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
|
||||
self.acc_dtype,
|
||||
)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
|
||||
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
n_corr_tiles = HEAD_DIM // corr_tile_size
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
|
||||
# Per-tile softmax loop with online rescale.
|
||||
# Unrolled for consistency with producer/MMA warps. The `if kt > 0`
|
||||
# rescale guard now becomes a Python-level conditional at trace
|
||||
# time (no rescale block emitted for kt=0; rescale block emitted
|
||||
# in-line for kt=1..N-1).
|
||||
for kt in cutlass.range_constexpr(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
|
||||
# Load S[kt]
|
||||
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
|
||||
# Pass 1: update row_max in log2-domain.
|
||||
old_row_max = row_max
|
||||
frg_cnt = 4
|
||||
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
|
||||
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
|
||||
|
||||
row_max_safe = row_max
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
row_max_safe = Float32(0.0)
|
||||
|
||||
# acc_scale = exp2(old_max - new_max). On first tile this is 0
|
||||
# (old_max = -inf), so row_sum stays 0 and rescale is skipped.
|
||||
# row_max is already in scaled domain, so no extra scale_log2.
|
||||
acc_scale_ = old_row_max - row_max_safe
|
||||
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
if old_row_max == -cutlass.Float32.inf:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
|
||||
# cast to BF16 via FP32-backed register bridge.
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
|
||||
minus_row_max = Float32(0.0) - row_max_safe
|
||||
|
||||
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
|
||||
for j in range(frg_cnt):
|
||||
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
|
||||
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
|
||||
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
|
||||
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
|
||||
s_vec = tTMEM_LOADrS_frg[None, j].load()
|
||||
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
|
||||
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# === Per-tile O rescale: O *= acc_scale for kt > 0 ===
|
||||
# Uses the SAME paired-atom pattern as the final normalize.
|
||||
# Must run BEFORE softmax_done_bar.arrive() so MMA's PV[kt]
|
||||
# reads the rescaled O.
|
||||
# Visibility of MMA's PV[kt-1] writes: provided by
|
||||
# s_cons.wait_and_advance at the top of this iteration, which
|
||||
# acquires on MMA's S[kt] commit. S[kt] is sequenced after
|
||||
# PV[kt-1] in MMA's iteration, so PV[kt-1]'s tmem_store_fence
|
||||
# has been observed by the time we read O here.
|
||||
if kt > 0:
|
||||
for i in range(n_corr_tiles):
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
|
||||
tTMrO[k] = tTMrO[k] * acc_scale
|
||||
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# Wait for MMA's PV[N-1] to commit before reading O for normalize.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
for i in range(n_corr_tiles):
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO)
|
||||
cute.arch.fence_view_async_tmem_load()
|
||||
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
|
||||
tTMrO[k] = tTMrO[k] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Standard epilogue: TMEM → SMEM → GMEM via TMA store.
|
||||
# O in TMEM is now scaled by 1/row_sum.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
|
||||
def test():
|
||||
torch.manual_seed(42)
|
||||
for n in [128, 256, 512, 1024]:
|
||||
torch.manual_seed(42)
|
||||
m, hd = 128, HEAD_DIM
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
|
||||
v_kernel = v.unsqueeze(-1)
|
||||
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
qf = q[:, :, 0].float()
|
||||
kf = k[:, :, 0].float()
|
||||
scale = 1.0 / math.sqrt(hd)
|
||||
attn = qf @ kf.T * scale
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
ref = attn @ v.float()
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = FmhaV3StageCMulti(s_k=n)
|
||||
print(f'n={n}: Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
print(f'n={n}: tmem s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} '
|
||||
f'o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols} '
|
||||
f'kv_tx_bytes={kernel.kv_tx_bytes}', flush=True)
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out = c[:, :, 0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
).item()
|
||||
max_abs = (out - ref).abs().max().item()
|
||||
n_tiles = n // 128
|
||||
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): '
|
||||
f'cos {cos:.6f} max_abs {max_abs:.4f} '
|
||||
f'{"PASS" if cos >= 0.99 else "FAIL"}')
|
||||
if cos < 0.99:
|
||||
print(f' out[0,:4]={out[0,:4].tolist()}')
|
||||
print(f' ref[0,:4]={ref[0,:4].tolist()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test()
|
||||
Reference in New Issue
Block a user