From a3559538cf09565e4162f23ff67177db3861803d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 23:43:23 +0000 Subject: [PATCH] D2: try 6-mode coordinate for flat_divide result --- dsv4/kernels/attention/fmha.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index bd2afa7d..dc71a968 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -195,17 +195,17 @@ class FmhaKernel: # Q: if num_ctas > 1, mQ has a head dimension. Use flat_divide for runtime coordinate. # K/V: shared (MQA), always coordinate 0. if const_expr(self.num_ctas > 1): - # flat_divide creates a tiled view where the tiler dims become inner modes - # and the rest dims become outer modes. Then we index the outer modes with block_idx. + # flat_divide creates (tile_M, tile_K, rest...) layout. + # For mQ=(n_h, T, hd, 1) with tiler=(128, hd): + # tC_gQ shape = (128, hd, n_h, T/128, 1, 1) + # head_cta_idx indexes mode 2 (n_h) q_tiler = cute.slice_(self.qk_mma_tiler, (None, 0, None)) # (128, hd) tC_gQ = cute.flat_divide(mQ, q_tiler) print(f"D2 DEBUG: tC_gQ shape={cute.shape(tC_gQ)}") - # tC_gQ has modes: (128, hd, rest...). Rest modes include n_h and trailing dims. - # Need to figure out correct coordinate for head indexing. - gQ = tC_gQ[None, None, head_cta_idx, None] # try different coordinate order + gQ = tC_gQ[None, None, head_cta_idx, None, None, None] c_tiler = cute.slice_(self.pv_mma_tiler, (None, None, 0)) # (128, hd) tC_gC = cute.flat_divide(mC, c_tiler) - gC = tC_gC[None, None, head_cta_idx, None] + gC = tC_gC[None, None, head_cta_idx, None, None, None] else: gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))