diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 930bca32..5c5c045a 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -107,7 +107,16 @@ class FmhaKernel: @cute.jit def __call__(self, q, k, v, c, stream, lse=None): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype - self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + # For multi-CTA, q has shape (n_h, T, hd, 1). Layout depends on inner (T, hd) dims only. + # Create a view of q without the head dimension for layout computation. + if const_expr(self.num_ctas > 1): + q_inner = cute.local_tile(q, cute.make_layout(1, stride=1), (0,)) # select head 0's view + self.a_major = LayoutEnum.from_tensor(q_inner).mma_major_mode() + c_inner = cute.local_tile(c, cute.make_layout(1, stride=1), (0,)) + self.c_layout = LayoutEnum.from_tensor(c_inner) + else: + self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() + self.c_layout = LayoutEnum.from_tensor(c) self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() v_fmha = cute.make_tensor( v.iterator,