D2: try 6-mode coordinate for flat_divide result
This commit is contained in:
@@ -195,17 +195,17 @@ class FmhaKernel:
|
||||
# Q: if num_ctas > 1, mQ has a head dimension. Use flat_divide for runtime coordinate.
|
||||
# K/V: shared (MQA), always coordinate 0.
|
||||
if const_expr(self.num_ctas > 1):
|
||||
# flat_divide creates a tiled view where the tiler dims become inner modes
|
||||
# and the rest dims become outer modes. The head mode among them is then indexed with head_cta_idx.
|
||||
# flat_divide creates (tile_M, tile_K, rest...) layout.
|
||||
# For mQ=(n_h, T, hd, 1) with tiler=(128, hd):
|
||||
# tC_gQ shape = (128, hd, n_h, T/128, 1, 1)
|
||||
# head_cta_idx indexes mode 2 (n_h)
|
||||
q_tiler = cute.slice_(self.qk_mma_tiler, (None, 0, None)) # (128, hd)
|
||||
tC_gQ = cute.flat_divide(mQ, q_tiler)
|
||||
print(f"D2 DEBUG: tC_gQ shape={cute.shape(tC_gQ)}")
|
||||
# tC_gQ has modes: (128, hd, rest...). Rest modes include n_h and trailing dims.
|
||||
# Need to figure out correct coordinate for head indexing.
|
||||
gQ = tC_gQ[None, None, head_cta_idx, None] # try different coordinate order
|
||||
gQ = tC_gQ[None, None, head_cta_idx, None, None, None]
|
||||
c_tiler = cute.slice_(self.pv_mma_tiler, (None, None, 0)) # (128, hd)
|
||||
tC_gC = cute.flat_divide(mC, c_tiler)
|
||||
gC = tC_gC[None, None, head_cta_idx, None]
|
||||
gC = tC_gC[None, None, head_cta_idx, None, None, None]
|
||||
else:
|
||||
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
|
||||
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
|
||||
|
||||
Reference in New Issue
Block a user