From 972fbd48b9eb3dd52aaa718b7d23f1fdacec663e Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 21:06:16 +0000 Subject: [PATCH] D1.3: Use const_expr for tOrP0 offset (compile-time conditional) --- dsv4/kernels/attention/fmha.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 1c376ecd..ca0ad7c0 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -175,15 +175,11 @@ class FmhaKernel: tOrP = tOrP_base[(None,None,None,0)] tCrP = pv_mma.make_fragment_A(sP) # tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path). - # Same pattern as Stage C kernel: uses MLIR-compatible arithmetic. - # tmem_p0_offset is in FP32 columns; tOrP uses BF16 elements. - # Offset = acc_width / q_width * tmem_p0_offset. - # For SMEM-P (tmem_p0_offset=-1), tOrP0 is unused by the MMA warp - # but must be valid for CuTeDSL compilation. - tOrP0 = cute.make_tensor( - tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, - tOrP.layout, - ) + # Use const_expr to handle TMEM-P vs SMEM-P at compile time. + # For TMEM-P: apply P0 offset (BF16 elements = p0_offset * 2) + # For SMEM-P: no offset needed (tOrP0 unused by MMA warp) + _p0_bf16_offset = max(self.tmem_p0_offset, 0) * 2 # Python int, computed at JIT time + tOrP0 = const_expr(lambda: cute.make_tensor(tOrP.iterator + _p0_bf16_offset, tOrP.layout) if _p0_bf16_offset > 0 else tOrP) tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage)) pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)