restore tBgK to kh.count indexing (single-tile working), add TODO for multi-tile
CuTeDSL TMA copy API doesn't support dynamic GMEM tile indexing. kh.count works for single tile. For multi-tile, need to either:

1. Map pipeline count to tile index (kh.count // 2 for interleaved K/V)
2. Separate K and V into non-interleaved TMA loops
3. Use gK/gV layouts that iterate naturally with pipeline count

This is the architectural blocker for multi-tile FMHA.
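Option 1 can be illustrated with a toy model in plain Python. This is only a sketch under stated assumptions: `ToyPipeline`, `Handle`, and `kv_schedule` are hypothetical names and are not part of CuTeDSL, and the model assumes that K and V share one pipeline whose `count` advances by one on every acquire and whose `index` is `count % num_stages`. Under those assumptions, `count // 2` gives the same tile number for the K load and the V load of each iteration.

```python
from dataclasses import dataclass


@dataclass
class Handle:
    """Hypothetical stand-in for the handle returned by an acquire."""
    count: int  # assumed: total number of acquires issued before this one
    index: int  # assumed: SMEM ring-buffer slot, count % num_stages


class ToyPipeline:
    """Toy model of a shared K/V producer pipeline (NOT the CuTeDSL API)."""

    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self._count = 0

    def acquire_and_advance(self) -> Handle:
        handle = Handle(self._count, self._count % self.num_stages)
        self._count += 1
        return handle


def kv_schedule(n_kv_tiles: int, num_stages: int):
    """Return (operand, gmem_tile, smem_slot) for each simulated TMA load."""
    pipe = ToyPipeline(num_stages)
    schedule = []
    for _ in range(n_kv_tiles):
        kh = pipe.acquire_and_advance()
        # K and V alternate on the shared counter, so halve it to get the tile.
        schedule.append(("K", kh.count // 2, kh.index))
        vh = pipe.acquire_and_advance()
        schedule.append(("V", vh.count // 2, vh.index))
    return schedule


if __name__ == "__main__":
    for operand, tile, slot in kv_schedule(n_kv_tiles=3, num_stages=2):
        print(f"{operand}: gmem tile {tile} -> smem slot {slot}")
```

The model says nothing about whether the real TMA partition accepts such an index; it only shows the arithmetic that the mapping relies on.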
@@ -214,11 +214,15 @@ class FmhaV3StageCMulti:
 for kt in cutlass.range(n_kv_tiles, unroll=1):
     kh = kvp.acquire_and_advance(pk)
-    # GMEM tile: kt (correct K[kt]). SMEM slot: kh.index (ring buffer).
-    cute.copy(tma_k, tBgK[(None, kt, None, 0)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
+    # TODO: tBgK[(None, kt)] indexes SMEM dim, NOT GMEM tile (the slice
+    # (None,0,None,0) fixed gmem_iter to 0). CuTeDSL doesn't support
+    # dynamic gmem tile indexing via subscript. Need pipeline count mapping
+    # or separate K/V TMA loops. For now, kh.count works for single tile.
+    cute.copy(tma_k, tBgK[(None, kh.count)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
     pk = cutlass.Boolean(1)
     vh = kvp.acquire_and_advance(pk)
-    # GMEM tile: kt (correct V[kt]). SMEM slot: vh.index (ring buffer).
-    cute.copy(tma_v, tVgV[(None, kt, None, 0)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
+    cute.copy(tma_v, tVgV[(None, vh.count)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
     pk = cutlass.Boolean(1)
 kvp.tail()