D2: use flat_divide for runtime coordinate indexing (like CUTLASS)
This commit is contained in:
@@ -191,22 +191,23 @@ class FmhaKernel:
|
||||
|
||||
# Print mQ and mK shapes at trace time to understand their mode structure
|
||||
print(f"D2 DEBUG: mQ shape={cute.shape(mQ)}, mK shape={cute.shape(mK)}")
|
||||
# mQ has 4 modes: (n_h, T, hd, 1). Tiler covers (T, hd) = modes 1,2.
|
||||
# Rest modes are (n_h, 1) = modes 0,3. Coordinate is (head_idx, 0).
|
||||
|
||||
# Q: if num_ctas > 1, mQ has a head dimension. local_tile indexes into it.
|
||||
# Q: if num_ctas > 1, mQ has a head dimension. Use flat_divide for runtime coordinate.
|
||||
# K/V: shared (MQA), always coordinate 0.
|
||||
# For single-CTA (num_ctas=1), head_cta_idx=0 and coordinates are the same as before.
|
||||
if const_expr(self.num_ctas > 1):
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(head_cta_idx,0))
|
||||
# flat_divide creates a tiled view where the tiler dims become inner modes
|
||||
# and the rest dims become outer modes. Then we index the result with head_cta_idx.
|
||||
q_tiler = cute.slice_(self.qk_mma_tiler, (None, 0, None)) # (128, hd)
|
||||
tC_gQ = cute.flat_divide(mQ, q_tiler)
|
||||
gQ = tC_gQ[head_cta_idx, None, None] # index head dim, take first M/K tile
|
||||
c_tiler = cute.slice_(self.pv_mma_tiler, (None, None, 0)) # (128, hd)
|
||||
tC_gC = cute.flat_divide(mC, c_tiler)
|
||||
gC = tC_gC[head_cta_idx, None, None]
|
||||
else:
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
|
||||
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
|
||||
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
|
||||
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
|
||||
if const_expr(self.num_ctas > 1):
|
||||
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(head_cta_idx,0))
|
||||
else:
|
||||
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
|
||||
n_kv_tiles = cute.size(gK, mode=[3])
|
||||
|
||||
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
|
||||
|
||||
Reference in New Issue
Block a user