fix: always provide valid gP tensor
This commit is contained in:
@@ -91,6 +91,10 @@ class FmhaKernel:
|
||||
@cute.jit
|
||||
def __call__(self, q, k, v, c, stream, lse=None, gP=None):
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
# If gP is not provided, create a dummy tensor (for non-SMEM-P paths)
|
||||
if gP is None:
|
||||
_gP_dummy = torch.zeros(128, self.s_k, dtype=torch.bfloat16, device='cuda')
|
||||
gP = ct.from_dlpack(_gP_dummy).mark_layout_dynamic(leading_dim=ct.get_leading_dim(_gP_dummy))
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
v_fmha = cute.make_tensor(
|
||||
|
||||
Reference in New Issue
Block a user