D2: fix LayoutEnum for multi-dim Q (use head-0 view for layout)

This commit is contained in:
2026-05-24 23:33:27 +00:00
parent 2b76b691cb
commit 49c4189195

View File

@@ -107,7 +107,16 @@ class FmhaKernel:
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
# For multi-CTA, q has shape (n_h, T, hd, 1). Layout depends on inner (T, hd) dims only.
# Create a view of q without the head dimension for layout computation.
if const_expr(self.num_ctas > 1):
q_inner = cute.local_tile(q, cute.make_layout(1, stride=1), (0,)) # select head 0's view
self.a_major = LayoutEnum.from_tensor(q_inner).mma_major_mode()
c_inner = cute.local_tile(c, cute.make_layout(1, stride=1), (0,))
self.c_layout = LayoutEnum.from_tensor(c_inner)
else:
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,