D1.3: Use const_expr for tOrP0 offset (compile-time conditional)

This commit is contained in:
2026-05-23 21:06:16 +00:00
parent a762820352
commit 972fbd48b9

View File

@@ -175,15 +175,11 @@ class FmhaKernel:
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP)
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
# Same pattern as Stage C kernel: uses MLIR-compatible arithmetic.
# tmem_p0_offset is in FP32 columns; tOrP uses BF16 elements.
# Offset = acc_width / q_width * tmem_p0_offset.
# For SMEM-P (tmem_p0_offset=-1), tOrP0 is unused by the MMA warp
# but must be valid for CuTeDSL compilation.
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout,
)
# Use const_expr to handle TMEM-P vs SMEM-P at compile time.
# For TMEM-P: apply P0 offset (BF16 elements = p0_offset * 2)
# For SMEM-P: no offset needed (tOrP0 unused by MMA warp)
_p0_bf16_offset = max(self.tmem_p0_offset, 0) * 2 # Python int, computed at JIT time
tOrP0 = const_expr(lambda: cute.make_tensor(tOrP.iterator + _p0_bf16_offset, tOrP.layout) if _p0_bf16_offset > 0 else tOrP)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)