diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index e9f4b700..4f2cde17 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -191,22 +191,23 @@ class FmhaKernel: # Print mQ shape at trace time to understand mode structure print(f"D2 DEBUG: mQ shape={cute.shape(mQ)}, mK shape={cute.shape(mK)}") - # mQ has 4 modes: (n_h, T, hd, 1). Tiler covers (T, hd) = modes 1,2. - # Rest modes are (n_h, 1) = modes 0,3. Coordinate is (head_idx, 0). - # Q: if num_ctas > 1, mQ has a head dimension. local_tile indexes into it. + # Q: if num_ctas > 1, mQ has a head dimension. Use flat_divide for runtime coordinate. # K/V: shared (MQA), always coordinate 0. - # For single-CTA (num_ctas=1), head_cta_idx=0 and coordinates are the same as before. if const_expr(self.num_ctas > 1): - gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(head_cta_idx,0)) + # flat_divide creates a tiled view where the tiler dims become inner modes + # and the rest dims become outer modes. Then we index the outer modes with block_idx. + q_tiler = cute.slice_(self.qk_mma_tiler, (None, 0, None)) # (128, hd) + tC_gQ = cute.flat_divide(mQ, q_tiler) + gQ = tC_gQ[head_cta_idx, None, None] # index head dim, take first M/K tile + c_tiler = cute.slice_(self.pv_mma_tiler, (None, None, 0)) # (128, hd) + tC_gC = cute.flat_divide(mC, c_tiler) + gC = tC_gC[head_cta_idx, None, None] else: gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None)) + gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None)) gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None)) - if const_expr(self.num_ctas > 1): - gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(head_cta_idx,0)) - else: - gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None)) n_kv_tiles = cute.size(gK, mode=[3]) qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)