diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 5f6fdc98..247e58c5 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -108,14 +108,13 @@ class FmhaKernel: def __call__(self, q, k, v, c, stream, lse=None): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype # For multi-CTA, q has shape (n_h, T, hd, 1). Layout depends on inner (T, hd) dims only. - # LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, we extract the - # inner dimensions by slicing away the head dim. + # LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, compute layout manually. if const_expr(self.num_ctas > 1): - # q shape: (n_h, T, hd, 1). Slice head dim to get (T, hd, 1) view. - q_inner = q[0] # select head 0 → (T, hd, 1) - self.a_major = LayoutEnum.from_tensor(q_inner).mma_major_mode() - c_inner = c[0] # select head 0 → (T, hd, 1) - self.c_layout = LayoutEnum.from_tensor(c_inner) + # q shape: (n_h, T, hd, 1). The inner (T, hd, 1) dims determine the MMA major mode. + # We know Q is MN-major (row-major in the attention M×K layout), so a_major = MN. + # This is always true for our FMHA kernel's Q layout. + self.a_major = cute.nvgpu.OperandMajorMode.MN + self.c_layout = LayoutEnum.ROW_MAJOR # O is always row-major else: self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c)