D2: hardcode a_major=MN for multi-CTA (Q is always MN-major in FMHA)

This commit is contained in:
2026-05-24 23:35:49 +00:00
parent e809e71253
commit df3146eb53

View File

@@ -108,14 +108,13 @@ class FmhaKernel:
def __call__(self, q, k, v, c, stream, lse=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
# For multi-CTA, q has shape (n_h, T, hd, 1). Layout depends on inner (T, hd) dims only.
# LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, we extract the
# inner dimensions by slicing away the head dim.
# LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, compute layout manually.
if const_expr(self.num_ctas > 1):
# q shape: (n_h, T, hd, 1). Slice head dim to get (T, hd, 1) view.
q_inner = q[0] # select head 0 → (T, hd, 1)
self.a_major = LayoutEnum.from_tensor(q_inner).mma_major_mode()
c_inner = c[0] # select head 0 → (T, hd, 1)
self.c_layout = LayoutEnum.from_tensor(c_inner)
# q shape: (n_h, T, hd, 1). The inner (T, hd, 1) dims determine the MMA major mode.
# We know Q is MN-major (row-major in the attention M×K layout), so a_major = MN.
# This is always true for our FMHA kernel's Q layout.
self.a_major = cute.nvgpu.OperandMajorMode.MN
self.c_layout = LayoutEnum.ROW_MAJOR # O is always row-major
else:
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)