fix: always provide valid gP tensor

This commit is contained in:
2026-05-24 03:37:43 +00:00
parent 49f54aef2d
commit 60cabb186d

View File

@@ -91,6 +91,10 @@ class FmhaKernel:
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None, gP=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
# If gP is not provided, create a dummy tensor (for non-SMEM-P paths)
if gP is None:
_gP_dummy = torch.zeros(128, self.s_k, dtype=torch.bfloat16, device='cuda')
gP = ct.from_dlpack(_gP_dummy).mark_layout_dynamic(leading_dim=ct.get_leading_dim(_gP_dummy))
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(