D1.5: Dynamic slicing for tTR_gC (variable rest dims)
This commit is contained in:
@@ -456,8 +456,13 @@ class FmhaKernel:
|
||||
# Partition source (TMEM) and destination (GMEM-derived register shape)
|
||||
tTR_tAcc = thr_t2r.partition_S(tOtO_epi)
|
||||
tTR_gC = thr_t2r.partition_D(tCgC_epi)
|
||||
# Register tensor shape: keep (T2R, T2R_M, T2R_N) dims, zero the rest
|
||||
# After partition_D, tTR_gC has (T2R, T2R_M, T2R_N, EPI_M, EPI_N, REST...)
|
||||
# We slice to get just the per-subtile register count.
|
||||
_tTR_gC_shape = tTR_gC.shape
|
||||
_n_rest = len(_tTR_gC_shape) - 3 # 3 leading dims: T2R, T2R_M, T2R_N
|
||||
tTR_rAcc = cute.make_rmem_tensor(
|
||||
tTR_gC[(None, None, None, 0, 0)].shape, self.acc_dtype)
|
||||
tTR_gC[(None, None, None) + (0,) * _n_rest].shape, self.acc_dtype)
|
||||
|
||||
# Step 3: SMEM store copy (epilogue_smem_copy_and_partition pattern)
|
||||
smem_copy_atom = get_smem_store_op(
|
||||
|
||||
Reference in New Issue
Block a user