restore tBgK to kh.count indexing (single-tile working), add TODO for multi-tile

CuTeDSL TMA copy API doesn't support dynamic GMEM tile indexing.
kh.count works for single tile. For multi-tile, need to either:
1. Map pipeline count to tile index (kh.count // 2 for interleaved K/V)
2. Separate K and V into non-interleaved TMA loops
3. Use gK/gV layouts that iterate naturally with pipeline count

This is the architectural blocker for multi-tile FMHA.
This commit is contained in:
2026-05-22 15:54:03 +00:00
parent 493d8e817d
commit d83434384e

View File

@@ -214,11 +214,15 @@ class FmhaV3StageCMulti:
for kt in cutlass.range(n_kv_tiles, unroll=1):
kh = kvp.acquire_and_advance(pk)
# GMEM tile: kt (correct K[kt]). SMEM slot: kh.index (ring buffer).
cute.copy(tma_k, tBgK[(None, kt, None, 0)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
# TODO: tBgK[(None, kt)] indexes SMEM dim, NOT GMEM tile (the slice
# (None,0,None,0) fixed gmem_iter to 0). CuTeDSL doesn't support
# dynamic gmem tile indexing via subscript. Need pipeline count mapping
# or separate K/V TMA loops. For now, kh.count works for single tile.
cute.copy(tma_k, tBgK[(None, kh.count)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
# GMEM tile: kt (correct V[kt]). SMEM slot: vh.index (ring buffer).
cute.copy(tma_v, tVgV[(None, kt, None, 0)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
cute.copy(tma_v, tVgV[(None, vh.count)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()