D2: hardcode a_major=MN for multi-CTA (Q is always MN-major in FMHA)
This commit is contained in:
@@ -108,14 +108,13 @@ class FmhaKernel:
|
||||
def __call__(self, q, k, v, c, stream, lse=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
# For multi-CTA, q has shape (n_h, T, hd, 1). Layout depends on inner (T, hd) dims only.
|
||||
# LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, we extract the
|
||||
# inner dimensions by slicing away the head dim.
|
||||
# LayoutEnum.from_tensor needs 2D/3D tensors. For 4D Q, compute layout manually.
|
||||
if const_expr(self.num_ctas > 1):
|
||||
# q shape: (n_h, T, hd, 1). Slice head dim to get (T, hd, 1) view.
|
||||
q_inner = q[0] # select head 0 → (T, hd, 1)
|
||||
self.a_major = LayoutEnum.from_tensor(q_inner).mma_major_mode()
|
||||
c_inner = c[0] # select head 0 → (T, hd, 1)
|
||||
self.c_layout = LayoutEnum.from_tensor(c_inner)
|
||||
# q shape: (n_h, T, hd, 1). The inner (T, hd, 1) dims determine the MMA major mode.
|
||||
# We know Q is MN-major (row-major in the attention M×K layout), so a_major = MN.
|
||||
# This is always true for our FMHA kernel's Q layout.
|
||||
self.a_major = cute.nvgpu.OperandMajorMode.MN
|
||||
self.c_layout = LayoutEnum.ROW_MAJOR # O is always row-major
|
||||
else:
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
|
||||
Reference in New Issue
Block a user