D2: use flat_divide for runtime coordinate indexing (like CUTLASS)

This commit is contained in:
2026-05-24 23:40:37 +00:00
parent 3e340a0eee
commit 7007a9db79

View File

@@ -191,22 +191,23 @@ class FmhaKernel:
# Print mQ shape at trace time to understand mode structure
print(f"D2 DEBUG: mQ shape={cute.shape(mQ)}, mK shape={cute.shape(mK)}")
# mQ has 4 modes: (n_h, T, hd, 1). Tiler covers (T, hd) = modes 1,2.
# Rest modes are (n_h, 1) = modes 0,3. Coordinate is (head_idx, 0).
# Q: if num_ctas > 1, mQ has a head dimension. local_tile indexes into it.
# Q: if num_ctas > 1, mQ has a head dimension. Use flat_divide for runtime coordinate.
# K/V: shared (MQA), always coordinate 0.
# For single-CTA (num_ctas=1), head_cta_idx=0 and coordinates are the same as before.
if const_expr(self.num_ctas > 1):
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(head_cta_idx,0))
# flat_divide creates a tiled view where the tiler dims become inner modes
# and the rest dims become outer modes. Then we index the outer modes with block_idx.
q_tiler = cute.slice_(self.qk_mma_tiler, (None, 0, None)) # (128, hd)
tC_gQ = cute.flat_divide(mQ, q_tiler)
gQ = tC_gQ[head_cta_idx, None, None] # index head dim, take first M/K tile
c_tiler = cute.slice_(self.pv_mma_tiler, (None, None, 0)) # (128, hd)
tC_gC = cute.flat_divide(mC, c_tiler)
gC = tC_gC[head_cta_idx, None, None]
else:
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
if const_expr(self.num_ctas > 1):
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(head_cta_idx,0))
else:
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)