D2: try 6-mode coordinate for flat_divide result

This commit is contained in:
2026-05-24 23:43:23 +00:00
parent 6f371d6b31
commit a3559538cf

View File

@@ -195,17 +195,17 @@ class FmhaKernel:
# Q: if num_ctas > 1, mQ has a head dimension. Use flat_divide for runtime coordinate.
# K/V: shared (MQA), always coordinate 0.
if const_expr(self.num_ctas > 1):
# flat_divide creates a tiled view where the tiler dims become inner modes
# and the rest dims become outer modes. Then we index the outer modes with block_idx.
# flat_divide creates (tile_M, tile_K, rest...) layout.
# For mQ=(n_h, T, hd, 1) with tiler=(128, hd):
# tC_gQ shape = (128, hd, n_h, T/128, 1, 1)
# head_cta_idx indexes mode 2 (n_h)
q_tiler = cute.slice_(self.qk_mma_tiler, (None, 0, None)) # (128, hd)
tC_gQ = cute.flat_divide(mQ, q_tiler)
print(f"D2 DEBUG: tC_gQ shape={cute.shape(tC_gQ)}")
# tC_gQ has modes: (128, hd, rest...). Rest modes include n_h and trailing dims.
# Need to figure out correct coordinate for head indexing.
gQ = tC_gQ[None, None, head_cta_idx, None] # try different coordinate order
gQ = tC_gQ[None, None, head_cta_idx, None, None, None]
c_tiler = cute.slice_(self.pv_mma_tiler, (None, None, 0)) # (128, hd)
tC_gC = cute.flat_divide(mC, c_tiler)
gC = tC_gC[None, None, head_cta_idx, None]
gC = tC_gC[None, None, head_cta_idx, None, None, None]
else:
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))