D1.3: Fix tOrP0 offset - scale FP32 columns to BF16 elements

tmem_p0_offset is measured in FP32 (32-bit) TMEM columns, but the tOrP
fragment is indexed in BF16 (16-bit) elements.
BF16 offset = tmem_p0_offset * (32 / 16) = tmem_p0_offset * 2.
2026-05-23 21:02:04 +00:00
parent f3f2ab4b50
commit 295e5a8c2f

@@ -175,11 +175,13 @@ class FmhaKernel:
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP)
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
- # The softmax warps store P at tmem_p0_offset columns. PV MMA must read
- # from the same offset. For SMEM-P, tOrP is bound to sP (not TMEM).
- # Must be defined unconditionally (CuTeDSL scoping: no if/else for variables).
- p0_col_offset = self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0)
- tOrP0 = cute.make_tensor(tOrP.iterator + p0_col_offset, tOrP.layout)
+ # The softmax warps store P at tmem_p0_offset FP32 columns. PV MMA's
+ # tOrP fragment uses BF16 elements. Offset in BF16 elements =
+ # tmem_p0_offset * (FP32_width / BF16_width) = offset * 2.
+ # For SMEM-P, offset is 0 (P not in TMEM).
+ # Must be defined unconditionally (CuTeDSL scoping).
+ _p0_bf16_offset = max(self.tmem_p0_offset, 0) * (32 // 16) # Python int
+ tOrP0 = cute.make_tensor(tOrP.iterator + _p0_bf16_offset, tOrP.layout)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
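
The unit conversion described in the commit message can be checked in isolation. The following is a minimal sketch in plain Python, not the kernel code: the helper name `cols_to_elem_offset` and its parameters are hypothetical, and it assumes that a negative `tmem_p0_offset` means "P is not in TMEM", which is how the `max(..., 0)` clamp in the diff appears to be used.

```python
def cols_to_elem_offset(col_offset: int, col_bits: int = 32, elem_bits: int = 16) -> int:
    """Convert an offset in col_bits-wide columns to an offset in elem_bits-wide elements.

    Hypothetical helper for illustration; the names are not from the kernel.
    """
    if col_bits % elem_bits != 0:
        raise ValueError("column width must be a multiple of element width")
    # Clamp negative offsets to 0, mirroring max(self.tmem_p0_offset, 0) in the diff.
    # Assumption: a negative value signals the SMEM-P path (P not in TMEM).
    clamped = max(col_offset, 0)
    # One 32-bit column spans col_bits // elem_bits narrower elements (2 for BF16).
    return clamped * (col_bits // elem_bits)


if __name__ == "__main__":
    # 128 FP32 columns correspond to 256 BF16 elements.
    print(cols_to_elem_offset(128))
    # A negative offset is treated as "no TMEM offset".
    print(cols_to_elem_offset(-1))
```

With FP32 columns and BF16 elements the scale factor is always 2, which is why the commit can hard-code `32 // 16` and keep the result a plain Python int.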