diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index bd2afa7d..dc71a968 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -195,17 +195,17 @@ class FmhaKernel: # Q: if num_ctas > 1, mQ has a head dimension. Use flat_divide for runtime coordinate. # K/V: shared (MQA), always coordinate 0. if const_expr(self.num_ctas > 1): - # flat_divide creates a tiled view where the tiler dims become inner modes - # and the rest dims become outer modes. Then we index the outer modes with block_idx. + # flat_divide creates a (tile_M, tile_K, rest...) layout. + # For mQ=(n_h, T, hd, 1) with tiler=(128, hd): + # tC_gQ shape = (128, hd, n_h, T/128, 1, 1) -- NOTE(review): assumed shape; confirm against the D2 DEBUG print below + # head_cta_idx indexes mode 2 (assumed to be n_h -- TODO confirm); every other coordinate in the slices below is None q_tiler = cute.slice_(self.qk_mma_tiler, (None, 0, None)) # (128, hd) tC_gQ = cute.flat_divide(mQ, q_tiler) print(f"D2 DEBUG: tC_gQ shape={cute.shape(tC_gQ)}") - # tC_gQ has modes: (128, hd, rest...). Rest modes include n_h and trailing dims. - # Need to figure out correct coordinate for head indexing. - gQ = tC_gQ[None, None, head_cta_idx, None] # try different coordinate order + gQ = tC_gQ[None, None, head_cta_idx, None, None, None] c_tiler = cute.slice_(self.pv_mma_tiler, (None, None, 0)) # (128, hd) tC_gC = cute.flat_divide(mC, c_tiler) - gC = tC_gC[None, None, head_cta_idx, None] + gC = tC_gC[None, None, head_cta_idx, None, None, None] else: gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))