From 044c230760ac8f93ba9158f3e3ca191eb9454404 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 01:40:44 +0000 Subject: [PATCH] D1.5: Dynamic slicing for tTR_gC (variable rest dims) --- dsv4/kernels/attention/fmha.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 6c6127ac..be2b6085 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -456,8 +456,13 @@ class FmhaKernel: # Partition source (TMEM) and destination (GMEM-derived register shape) tTR_tAcc = thr_t2r.partition_S(tOtO_epi) tTR_gC = thr_t2r.partition_D(tCgC_epi) + # Register tensor shape: keep (T2R, T2R_M, T2R_N) dims, zero the rest + # After partition_D, tTR_gC has (T2R, T2R_M, T2R_N, EPI_M, EPI_N, REST...) + # We slice to get just the per-subtile register count. + _tTR_gC_shape = tTR_gC.shape + _n_rest = len(_tTR_gC_shape) - 3 # 3 leading dims: T2R, T2R_M, T2R_N tTR_rAcc = cute.make_rmem_tensor( - tTR_gC[(None, None, None, 0, 0)].shape, self.acc_dtype) + tTR_gC[(None, None, None) + (0,) * _n_rest].shape, self.acc_dtype) # Step 3: SMEM store copy (epilogue_smem_copy_and_partition pattern) smem_copy_atom = get_smem_store_op(