From df3146eb53952d5e53b11528ff45f9a7e72c26f9 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 23:35:49 +0000 Subject: [PATCH] D2: hardcode a_major=MN for multi-CTA (Q is always MN-major in FMHA) --- dsv4/kernels/attention/fmha.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 5f6fdc98..247e58c5 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -108,14 +108,13 @@ class FmhaKernel: def __call__(self, q, k, v, c, stream, lse=None): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype # For multi-CTA, q has shape (n_h, T, hd, 1). Layout depends on inner (T, hd) dims only. - # LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, we extract the - # inner dimensions by slicing away the head dim. + # LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, compute layout manually. if const_expr(self.num_ctas > 1): - # q shape: (n_h, T, hd, 1). Slice head dim to get (T, hd, 1) view. - q_inner = q[0] # select head 0 → (T, hd, 1) - self.a_major = LayoutEnum.from_tensor(q_inner).mma_major_mode() - c_inner = c[0] # select head 0 → (T, hd, 1) - self.c_layout = LayoutEnum.from_tensor(c_inner) + # q shape: (n_h, T, hd, 1). The inner (T, hd, 1) dims determine the MMA major mode. + # We know Q is MN-major (row-major in the attention M×K layout), so a_major = MN. + # This is always true for our FMHA kernel's Q layout. + self.a_major = cute.nvgpu.OperandMajorMode.MN + self.c_layout = LayoutEnum.ROW_MAJOR # O is always row-major else: self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c)