D1.3: Fix critical bug - add TMEM column offset for P0 in PV GEMM
The softmax warps store P at tmem_p0_offset=32. PV MMA must read from the same offset. tOrP0 was missing the offset, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests.
This commit is contained in:
@@ -174,9 +174,16 @@ class FmhaKernel:
|
||||
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tCrP = pv_mma.make_fragment_A(sP)
|
||||
# tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies
|
||||
# the p0 column offset inline when constructing the gemm arguments.
|
||||
tOrP0 = tOrP
|
||||
# tOrP0: apply TMEM column offset for P0 (TMEM-P path only)
|
||||
# The softmax warps store P at tmem_p0_offset columns. PV MMA must read
|
||||
# from the same offset. For SMEM-P, tOrP is bound to sP (not TMEM).
|
||||
if not self.use_smem_p:
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout,
|
||||
)
|
||||
else:
|
||||
tOrP0 = tOrP
|
||||
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
Reference in New Issue
Block a user