From f80f8eb38fa3c8d373c0ae42219c600fd0e95708 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 17:39:27 +0000 Subject: [PATCH] Clean up debug prints, set kv_coord as Int32(0) Key findings to relay to CUTLASS LLM: - kv_coord=Int32(1) hardcode CHANGES the output (TMA CAN load from different tiles) - kv_coord=Int32(0) + kv_coord += 1 does NOT increment at runtime (all multi-tile outputs identical to kv_coord=0) - kv_coord=0 (plain Python int) also doesn't work - Pipeline handle .count doesn't work either - The TMA GMEM tile coordinate must be dynamic at kernel runtime, but CuTeDSL appears to constant-fold or not propagate the increment --- tests/unit/test_fmha_v3_stage_c_full.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 93bd5bbe..aeee8739 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -167,17 +167,12 @@ class FmhaV3StageC: b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape) tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3)) tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3)) - # GMEM slices: keep the GMEM iteration mode free for kv_coord indexing. - # CUTLASS reference: tKgK = tKgK_kdl[None, None, 0, 0] (keeps TMA_atom + GMEM_iter) - # tVgV = tVgV_dkl[None, 0, None, 0] (keeps TMA_atom + GMEM_iter at mode 2) - # SMEM tensors from tma_partition are already 2D — don't re-slice them. - print(f"DEBUG tBgK shape before slice: {tBgK.shape}") - print(f"DEBUG tVgV shape before slice: {tVgV.shape}") + # GMEM slices: (None,0,None,0) keeps mode 2 free for both K and V. + # Debug shapes showed tBgK modes 1,2 are Int32(?), tVgV mode 2 grows with n. + # Both K and V have GMEM iteration at mode 2 in our tma_partition output. tAgQ = tAgQ[(None,0,None,0)] # Q: 1 tile, hardcode is fine - tBgK = tBgK[(None,0,None,0)] # K: try keeping mode 2 free (like V) - tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 (GMEM iter) free - print(f"DEBUG tBgK shape after slice: {tBgK.shape}") - print(f"DEBUG tVgV shape after slice: {tVgV.shape}") + tBgK = tBgK[(None,0,None,0)] # K: keep mode 2 free (GMEM iter) + tVgV = tVgV[(None,0,None,0)] # V: keep mode 2 free (GMEM iter) tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK) tCrV = pv_mma.make_fragment_B(sV) @@ -209,7 +204,7 @@ class FmhaV3StageC: cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier) qp.tail() kvp.reset(); pk = kvp.try_acquire() - kv_coord = 0 # Plain Python int, like CUTLASS reference + kv_coord = Int32(0) # MUST be Int32 for TMA addressing for kt in cutlass.range(n_kv_tiles, unroll=1): kvh = kvp.acquire_and_advance(pk) cute.copy(tma_k, tBgK[(None, kv_coord)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)