D2: fix LayoutEnum for multi-dim Q (use head-0 view for layout)
This commit is contained in:
@@ -107,7 +107,16 @@ class FmhaKernel:
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream, lse=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
# For multi-CTA, q has shape (n_h, T, hd, 1). Layout depends on inner (T, hd) dims only.
|
||||
# Create a view of q without the head dimension for layout computation.
|
||||
if const_expr(self.num_ctas > 1):
|
||||
q_inner = cute.local_tile(q, cute.make_layout(1, stride=1), (0,)) # select head 0's view
|
||||
self.a_major = LayoutEnum.from_tensor(q_inner).mma_major_mode()
|
||||
c_inner = cute.local_tile(c, cute.make_layout(1, stride=1), (0,))
|
||||
self.c_layout = LayoutEnum.from_tensor(c_inner)
|
||||
else:
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
v_fmha = cute.make_tensor(
|
||||
v.iterator,
|
||||
|
||||
Reference in New Issue
Block a user