From 295e5a8c2f85e908ea0f57d2f6d6c6dbfdf0b4d0 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 21:02:04 +0000 Subject: [PATCH] D1.3: Fix tOrP0 offset - scale FP32 columns to BF16 elements tmem_p0_offset is in FP32 columns, but tOrP uses BF16 elements. Offset = p0_offset * (32/16) = p0_offset * 2. --- dsv4/kernels/attention/fmha.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index cba1b032..36140616 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -175,11 +175,13 @@ class FmhaKernel: tOrP = tOrP_base[(None,None,None,0)] tCrP = pv_mma.make_fragment_A(sP) # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). - # The softmax warps store P at tmem_p0_offset columns. PV MMA must read - # from the same offset. For SMEM-P, tOrP is bound to sP (not TMEM). - # Must be defined unconditionally (CuTeDSL scoping: no if/else for variables). - p0_col_offset = self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0) - tOrP0 = cute.make_tensor(tOrP.iterator + p0_col_offset, tOrP.layout) + # The softmax warps store P at tmem_p0_offset FP32 columns. PV MMA's + # tOrP fragment uses BF16 elements. Offset in BF16 elements = + # tmem_p0_offset * (FP32_width / BF16_width) = offset * 2. + # For SMEM-P, offset is 0 (P not in TMEM). + # Must be defined unconditionally (CuTeDSL scoping). + _p0_bf16_offset = max(self.tmem_p0_offset, 0) * (32 // 16) # Python int + tOrP0 = cute.make_tensor(tOrP.iterator + _p0_bf16_offset, tOrP.layout) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)