FIX: pre-slice ALL tma_partition outputs with (None,0,None,0) — shapes verified on B200
DIAG OUTPUT (n=256, inside @cute.kernel):
  tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
  tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
  tVgV: (((64,128),1), 1, 1, 1) — 4 modes
Slicing with (None,0,None,0) keeps modes 0 and 2 free, so each tensor becomes 2D:
  tAgQ: (((64,128),1), Int32(?))
  tBgK: (((64,128),1), Int32(?))
  tVgV: (((64,128),1), 1)
Then [None, kt] indexes the surviving mode 1 (originally mode 2, the KV-tile axis). Q uses tAgQ[(None, Int32(0))] because it has a single tile, so its coordinate is always 0.
Removed diag prints from test_fmha_v3.py.
This commit is contained in:
@@ -14,14 +14,14 @@ Three structural rules learned the hard way:
|
||||
`utils.sm100.get_tmem_load_op` + `get_smem_store_op` works and is what
|
||||
the CUTLASS Blackwell FMHA reference uses in `correction_rescale`.
|
||||
|
||||
(C) tma_partition produces a 4-mode tensor: (V_grouped, ?, KV_tiles, ?).
|
||||
Mode 2 is the GMEM-tile iteration axis. Pre-slicing with
|
||||
`tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets mode 2 to 0,
|
||||
collapsing the KV-tile axis so TMA always reads tile 0.
|
||||
(C) tma_partition produces 4-mode tensors: (((64,128),1), ?, ?, ?).
|
||||
Mode 2 is the GMEM-tile iteration axis.
|
||||
Pre-slicing with `tBgK[(None,None,0,0)]` keeps modes 0,1 free but sets
|
||||
mode 2 to 0, collapsing the KV-tile axis so TMA always reads tile 0.
|
||||
Fix: pre-slice with `tBgK[(None,0,None,0)]` to keep modes 0,2 free,
|
||||
then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes mode 1 (the
|
||||
surviving KV_tiles mode from the original mode 2) with kt.
|
||||
This matches the CUTLASS reference FMHA pattern.
|
||||
then `cute.copy(tma_k, tBgK[None, kt], ...)` indexes the surviving
|
||||
KV_tiles mode. Verified shapes on B200: after (None,0,None,0), all
|
||||
three tensors become 2D: (((64,128),1), Int32(?)) or (((64,128),1), 1).
|
||||
|
||||
Kernel structure:
|
||||
|
||||
@@ -181,14 +181,15 @@ class FmhaV3StageCMulti:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
# NOTE: after tma_partition, tAgQ has 2 modes (Q has 1 tile),
|
||||
# tBgK/tVgV have 4 modes: (V_grouped, ?, KV_tiles, ?).
|
||||
# Mode 2 of tBgK/tVgV is the GMEM tile iteration axis.
|
||||
# The old pre-slice (None,None,0,0) kept modes 0,1 free and set
|
||||
# mode 2 to 0, collapsing the KV tile axis.
|
||||
# Fix: (None,0,None,0) keeps modes 0,2 free, then [None, kt]
|
||||
# indexes the surviving KV_tiles dim in cute.copy.
|
||||
tAgQ = tAgQ[(None,0)]
|
||||
# SHAPES (from diag test on B200, n=256):
|
||||
# tAgQ: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
|
||||
# tBgK: (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
|
||||
# tVgV: (((64,128),1), 1, 1, 1) — 4 modes
|
||||
# Mode 2 is the GMEM tile iteration axis.
|
||||
# (None,0,None,0) keeps modes 0 and 2 free → 2D tensor.
|
||||
# Then [None, kt] indexes the surviving KV_tiles dim.
|
||||
# The old (None,None,0,0) kept modes 0,1 free → mode 2 (KV tiles) set to 0.
|
||||
tAgQ = tAgQ[(None,0,None,0)]
|
||||
tBgK = tBgK[(None,0,None,0)]
|
||||
tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
@@ -220,7 +221,7 @@ class FmhaV3StageCMulti:
|
||||
# Matches the CUTLASS reference FMHA TMA indexing pattern.
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[0], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
cute.copy(tma_q, tAgQ[None, Int32(0)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
qp.tail()
|
||||
kvp.reset()
|
||||
for kt in cutlass.range_constexpr(self.n_kv_tiles):
|
||||
|
||||
@@ -134,16 +134,7 @@ class FmhaV3:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
print(f"DIAG tAgQ: shape={cute.shape(tAgQ)} stride={tAgQ.layout.stride}")
|
||||
print(f"DIAG tBgK: shape={cute.shape(tBgK)} stride={tBgK.layout.stride}")
|
||||
print(f"DIAG tVgV: shape={cute.shape(tVgV)} stride={tVgV.layout.stride}")
|
||||
print(f"DIAG tAsQ: shape={cute.shape(tAsQ)} stride={tAsQ.layout.stride}")
|
||||
print(f"DIAG tBsK: shape={cute.shape(tBsK)} stride={tBsK.layout.stride}")
|
||||
print(f"DIAG tVsV: shape={cute.shape(tVsV)} stride={tVsV.layout.stride}")
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
print(f"DIAG tAgQ sliced (None,0,None,0): shape={cute.shape(tAgQ)} stride={tAgQ.layout.stride}")
|
||||
print(f"DIAG tBgK sliced (None,0,None,0): shape={cute.shape(tBgK)} stride={tBgK.layout.stride}")
|
||||
print(f"DIAG tVgV sliced (None,0,None,0): shape={cute.shape(tVgV)} stride={tVgV.layout.stride}")
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
|
||||
Reference in New Issue
Block a user