FIX: (None,0,None,0) pre-slice keeps KV tile axis (mode 2) free
After tma_partition, tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?). Mode 2 is the GMEM tile dimension. The old pre-slice (None,None,0,0) kept modes 0 and 1 free and pinned mode 2 to 0, so every TMA copy read tile 0. The 8-None no-op slice FAILS because the tensor has 4 modes, not 8, at the JIT level. Fix: pre-slice with (None,0,None,0), which keeps modes 0 and 2 free and yields a 2D tensor; tBgK[None, kt] then indexes the surviving KV_tiles dimension. This matches the CUTLASS reference FMHA pattern: `tKgK = tKgK_kdl[None, None, 0, batch]` followed by `cute.copy(tma_k, tKgK[None, kv_coord], ...)`.
This commit is contained in:
@@ -14,18 +14,14 @@ Three structural rules learned the hard way:
|
||||
`utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
|
||||
the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
|
||||
|
||||
(C) tma_partition produces a tensor with 8 TMA coordinate dimensions, but only
|
||||
4 are visible in the Python shape. After
|
||||
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay,
|
||||
group_modes(sK,0,3),
|
||||
group_modes(tCgK,0,3))
|
||||
`tBgK` has 8 TMA coord modes: (1,1,1,1,n_kv_tiles,1,1,1).
|
||||
Mode 4 is the GMEM-tile iteration axis.
|
||||
Pre-slicing with `tBgK[(None,None,0,0)]` collapses the GMEM tile axis to
|
||||
coordinate 0, so every TMA copy reads tile 0 regardless of the coord
|
||||
value passed. The 8-None no-op pre-slice opens the full TMA coord space.
|
||||
Fix: tBgK = tBgK[(None,None,None,None,None,None,None,None)], then
|
||||
cute.copy(tma_k, tBgK[None,None,None,None,kt,None,None,None], ...)
|
||||
(C) tma_partition produces a 4-mode tensor: (V_grouped, ?, KV_tiles, ?).
|
||||
Mode 2 is the GMEM-tile iteration axis. Pre-slicing with
|
||||
`tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets mode 2 to 0,
|
||||
collapsing the KV-tile axis so TMA always reads tile 0.
|
||||
Fix: pre-slice with `tBgK[(None,0,None,0)]` to keep modes 0,2 free,
|
||||
then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes mode 1 (the
|
||||
surviving KV_tiles mode from the original mode 2) with kt.
|
||||
This matches the CUTLASS reference FMHA pattern.
|
||||
|
||||
Kernel structure:
|
||||
|
||||
@@ -185,15 +181,17 @@ class FmhaV3StageCMulti:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# NOTE: after tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
|
||||
# Shape is (1,1,1,1,n_kv_tiles,1,1,1) in TMA coord space.
|
||||
# Mode 4 is the GMEM tile-iteration axis.
|
||||
# The no-op 8-None slice opens up the full TMA coordinate space.
|
||||
# Without it, only 4 modes are visible and 8-mode indexing fails.
|
||||
# The old (None,None,0,0) pre-slice collapsed the GMEM tile axis to 0.
|
||||
# NOTE: after tma_partition, tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?).
|
||||
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
|
||||
# The old pre-slice (None,None,0,0) kept modes 0,1 free and set mode 2 to 0,
|
||||
# collapsing the KV tile axis. Fix: keep modes 0,2 free with (None,0,None,0),
|
||||
# then index with [None, kt] in cute.copy. This matches the CUTLASS reference
|
||||
# pattern: 2D pre-slice, 2-mode indexing in copy.
|
||||
# tVgV similarly has KV tiles at mode 2.
|
||||
# tAgQ: (None,0,None,0) works because Q has only 1 tile.
|
||||
tAgQ = tAgQ[(None,0,None,0)]
|
||||
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
|
||||
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
|
||||
tBgK = tBgK[(None,0,None,0)]
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -217,13 +215,10 @@ class FmhaV3StageCMulti:
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp — fully unrolled =====
|
||||
# The old pre-slice tBgK[(None,None,0,0)] collapsed the 8-mode TMA
|
||||
# coordinate space to 4 modes, pinning mode 4 (GMEM tile dim) to 0.
|
||||
# The 8-None no-op pre-slice opens the full TMA coord space so we
|
||||
# can index mode 4 with kt in cute.copy. The pipeline's
|
||||
# acquire/release machinery still tracks the kv_stage ring buffer
|
||||
# dynamically at runtime, so the producer correctly blocks on
|
||||
# consumer release when n_kv_tiles > kv_stage.
|
||||
# The old pre-slice (None,None,0,0) kept modes 0,1 free and set
|
||||
# mode 2 (KV tiles) to 0. Fix: (None,0,None,0) keeps modes 0,2 free.
|
||||
# Then [None, kt] indexes the KV tile axis correctly.
|
||||
# Matches the CUTLASS reference FMHA TMA indexing pattern.
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, 0, 0, 0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
@@ -231,18 +226,17 @@ class FmhaV3StageCMulti:
|
||||
kvp.reset()
|
||||
for kt in cutlass.range_constexpr(self.n_kv_tiles):
|
||||
kvh = kvp.acquire_and_advance()
|
||||
# 8-mode TMA indexing: mode 4 = GMEM tile axis (size n_kv_tiles).
|
||||
# The 8-None pre-slice above opened the full coord space.
|
||||
# kt at mode 4 indexes the correct KV tile in GMEM.
|
||||
# After (None,0,None,0) pre-slice, tBgK is 2D: (V_grouped, KV_tiles).
|
||||
# [None, kt] indexes mode 1 (KV_tiles) with kt.
|
||||
cute.copy(
|
||||
tma_k,
|
||||
tBgK[None, None, None, None, kt, None, None, None],
|
||||
tBgK[None, kt],
|
||||
tBsK[(None, kvh.index)],
|
||||
tma_bar_ptr=kvh.barrier,
|
||||
)
|
||||
cute.copy(
|
||||
tma_v,
|
||||
tVgV[None, None, None, None, kt, None, None, None],
|
||||
tVgV[None, kt],
|
||||
tVsV[(None, kvh.index)],
|
||||
tma_bar_ptr=kvh.barrier,
|
||||
)
|
||||
|
||||
@@ -140,13 +140,13 @@ class FmhaV3Diag:
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# =====================================================================
|
||||
# ⚠️⚠️⚠️ CRITICAL: TMA PARTITION TENSOR MODE ORDERING ⚠️⚠️⚠️
|
||||
# After tma_partition, tBgK/tVgV have 8 modes: (1,1,1,1,n_kv_tiles,1,1,1)
|
||||
# After tma_partition, tBgK/tVgV have 4 modes: (V_grouped, ?, KV_tiles, ?)
|
||||
# Mode 2 is the GMEM tile dimension. Pre-slice with (None,0,None,0) so it stays free; do NOT use (None,None,0,0)!
|
||||
# See README.md for details.
|
||||
# =====================================================================
|
||||
tAgQ = tAgQ[(None,0,None,0)]
|
||||
tBgK = tBgK[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
|
||||
tVgV = tVgV[(None,None,None,None,None,None,None,None)] # 8 modes! No pre-slice!
|
||||
tBgK = tBgK[(None,0,None,0)]
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -179,10 +179,10 @@ class FmhaV3Diag:
|
||||
kv_coord = Int32(0 + 0)
|
||||
for kt in cutlass.range(self.n_kv_tiles, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
# ⚠️ CRITICAL: kv_coord indexes MODE 4 of 8-mode tBgK/tVgV.
|
||||
# kv_coord indexes the surviving mode 1 (from original mode 2) of 2D tBgK/tVgV.
|
||||
# The pre-slice must be (None,0,None,0): the old (None,None,0,0) pre-slice pins the KV-tile mode to 0 and SILENTLY BREAKS multi-tile!
|
||||
cute.copy(tma_k, tBgK[None, None, None, None, kv_coord, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[None, None, None, None, kv_coord, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_k, tBgK[None, kv_coord], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[None, kv_coord], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
kv_coord += 1
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
@@ -179,13 +179,13 @@ class FmhaV3StageCMulti:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# After tma_partition, tBgK/tVgV have 8 TMA coordinate dimensions.
|
||||
# After tma_partition, tBgK has 4 modes: (V_grouped, ?, KV_tiles, ?).
|
||||
# Mode 2 is the GMEM tile iteration axis (size = n_kv_tiles).
|
||||
# Pre-slicing with (None,0,None,0) keeps modes 0,2 free so [None, kt] indexes the KV tile axis.
|
||||
# tAgQ is fine with 4-mode slice (Q has only 1 tile).
|
||||
tAgQ = tAgQ[(None,0,None,0)]
|
||||
tBgK = tBgK[(None,None,None,None,None,None,None,None)]
|
||||
tVgV = tVgV[(None,None,None,None,None,None,None,None)]
|
||||
tBgK = tBgK[(None,0,None,0)]
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -223,8 +223,8 @@ class FmhaV3StageCMulti:
|
||||
kvp.reset(); pk = kvp.try_acquire()
|
||||
for kt in cutlass.range(0, n_kv_tiles, 1, unroll=1):
|
||||
kvh = kvp.acquire_and_advance(pk)
|
||||
cute.copy(tma_k, tBgK[None, None, None, None, kt, None, None, None], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[None, None, None, None, kt, None, None, None], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_k, tBgK[None, kt], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
cute.copy(tma_v, tVgV[None, kt], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user