D1.1: Fix PV A-operand construction — compile-time branch for TMEM vs SMEM
This commit is contained in:
@@ -166,14 +166,20 @@ class FmhaKernel:
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# PV A-operand: always define both TMEM and SMEM paths (CuTeDSL scoping)
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0),
|
||||
tOrP.layout)
|
||||
tCrP = pv_mma.make_fragment_A(sP)
|
||||
# PV A-operand: constructed based on use_smem_p (compile-time branch)
|
||||
if not self.use_smem_p:
|
||||
# TMEM-P: PV reads P from TMEM alias of QK C-fragment
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tOrP_base = pv_thr.make_fragment_A(tP)
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout)
|
||||
tCrP = pv_mma.make_fragment_A(sP) # dummy, never used
|
||||
else:
|
||||
# SMEM-P: PV reads P from SMEM
|
||||
tOrP0 = cute.make_tensor(tStS.iterator, tStS.layout) # dummy, never used
|
||||
tCrP = pv_mma.make_fragment_A(sP)
|
||||
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
Reference in New Issue
Block a user