D1.5: Dynamic slicing for tTR_gC (variable rest dims)

This commit is contained in:
2026-05-24 01:40:44 +00:00
parent 1bde12782d
commit 044c230760

View File

@@ -456,8 +456,13 @@ class FmhaKernel:
# Partition source (TMEM) and destination (GMEM-derived register shape)
tTR_tAcc = thr_t2r.partition_S(tOtO_epi)
tTR_gC = thr_t2r.partition_D(tCgC_epi)
# Register tensor shape: keep the (T2R, T2R_M, T2R_N) modes and index every remaining mode at 0.
# After partition_D, tTR_gC has modes (T2R, T2R_M, T2R_N, EPI_M, EPI_N, REST...).
# Slicing this way yields just the per-subtile register shape, however many trailing modes there are.
_tTR_gC_shape = tTR_gC.shape
_n_rest = len(_tTR_gC_shape) - 3 # 3 leading dims: T2R, T2R_M, T2R_N
tTR_rAcc = cute.make_rmem_tensor(
tTR_gC[(None, None, None, 0, 0)].shape, self.acc_dtype)
tTR_gC[(None, None, None) + (0,) * _n_rest].shape, self.acc_dtype)
# Step 3: SMEM store copy (epilogue_smem_copy_and_partition pattern)
smem_copy_atom = get_smem_store_op(