FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing
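The earlier pre-slice `tBgK[(None, 0, None, 0)]` hardcoded the GMEM tile coordinate to 0, so the load loop always addressed tile 0. Instead, keep the original `tBgK` and index it with `(None, kt, None, 0)` inside the TMA loop, where `kt` selects the correct GMEM tile for each iteration. The indexed GMEM operand is then rank-2, matching the rank of the SMEM tensor it is copied into.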
This commit is contained in:
@@ -177,7 +177,7 @@ class FmhaV3StageCMulti:
|
||||
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
|
||||
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
|
||||
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,None,None,0)]; tVgV = tVgV[(None,None,None,0)]
|
||||
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
|
||||
|
||||
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
|
||||
tCrV = pv_mma.make_fragment_B(sV)
|
||||
@@ -214,11 +214,11 @@ class FmhaV3StageCMulti:
|
||||
for kt in cutlass.range(n_kv_tiles, unroll=1):
|
||||
kh = kvp.acquire_and_advance(pk)
|
||||
# GMEM tile: kt (correct K[kt]). SMEM slot: kh.index (ring buffer).
|
||||
cute.copy(tma_k, tBgK[(None, kt, None)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
|
||||
cute.copy(tma_k, tBgK[(None, kt, None, 0)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
vh = kvp.acquire_and_advance(pk)
|
||||
# GMEM tile: kt (correct V[kt]). SMEM slot: vh.index (ring buffer).
|
||||
cute.copy(tma_v, tVgV[(None, kt, None)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
|
||||
cute.copy(tma_v, tVgV[(None, kt, None, 0)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
|
||||
pk = cutlass.Boolean(1)
|
||||
kvp.tail()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user