D1.3: Use const_expr if for tOrP0 compile-time selection
This commit is contained in:
@@ -179,14 +179,11 @@ class FmhaKernel:
|
||||
tCrP = pv_mma.make_fragment_A(sP)
|
||||
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
|
||||
# self.tOrP0_offset is pre-computed in _setup as a Python int.
|
||||
# We must avoid tOrP.iterator + 0 (MLIR OpResult + int not supported).
|
||||
# Use const_expr to select the right construction at compile time.
|
||||
@const_expr
|
||||
def _make_tOrP0():
|
||||
if self.tOrP0_offset > 0:
|
||||
return cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
|
||||
return tOrP
|
||||
tOrP0 = _make_tOrP0()
|
||||
# Use const_expr if/else for compile-time conditional.
|
||||
if const_expr(self.tOrP0_offset > 0):
|
||||
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
|
||||
else:
|
||||
tOrP0 = tOrP
|
||||
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
Reference in New Issue
Block a user