D1.3: Use const_expr for tOrP0 offset (compile-time conditional)
This commit is contained in:
@@ -175,15 +175,11 @@ class FmhaKernel:
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tCrP = pv_mma.make_fragment_A(sP)
|
||||
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
|
||||
# Same pattern as Stage C kernel: uses MLIR-compatible arithmetic.
|
||||
# tmem_p0_offset is in FP32 columns; tOrP uses BF16 elements.
|
||||
# Offset = acc_width / q_width * tmem_p0_offset.
|
||||
# For SMEM-P (tmem_p0_offset=-1), tOrP0 is unused by the MMA warp
|
||||
# but must be valid for CuTeDSL compilation.
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
||||
tOrP.layout,
|
||||
)
|
||||
# Use const_expr to handle TMEM-P vs SMEM-P at compile time.
|
||||
# For TMEM-P: apply the P0 offset, converted from FP32 columns to BF16 elements (tmem_p0_offset * 2).
|
||||
# For SMEM-P: no offset needed (tOrP0 unused by MMA warp)
|
||||
_p0_bf16_offset = max(self.tmem_p0_offset, 0) * 2 # Python int, computed at JIT time
|
||||
tOrP0 = const_expr(lambda: cute.make_tensor(tOrP.iterator + _p0_bf16_offset, tOrP.layout) if _p0_bf16_offset > 0 else tOrP)
|
||||
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
Reference in New Issue
Block a user