diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 2402a8b8..b2c002e7 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -20,7 +20,12 @@ class FmhaKernel: self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 - self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256 + self.pv_n_tile = min(head_dim, 256) + # At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB, + # making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512 + # (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256. + if head_dim > 256: + self.pv_n_tile = 128 self.n_pv_tiles = head_dim // self.pv_n_tile self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64) self.normalize = normalize # D5a: False = emit un-normalized O + lse @@ -155,6 +160,8 @@ class FmhaKernel: sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner) sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner) + # sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB. + # TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput. sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner) sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner) # sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM) diff --git a/tests/unit/test_smem_budget.py b/tests/unit/test_smem_budget.py new file mode 100644 index 00000000..7acd0ad1 --- /dev/null +++ b/tests/unit/test_smem_budget.py @@ -0,0 +1,71 @@ +"""Diagnostic: Print exact SMEM budget for various HEAD_DIM values.""" +import math +from dsv4.kernels.attention.fmha import FmhaKernel + + +def smem_budget(hd, s_k=128, use_smem_p=None): + """Reproduce FmhaKernel._setup SMEM allocation sizes.""" + kv_stage = 1 if hd > 128 else 2 + q_stage = 1 + k_tile = min(hd, 256) + pv_n_tile = min(hd, 256) + pv_mma_tiler = (128, pv_n_tile, k_tile) + qk_mma_tiler = (128, 128, k_tile) + + # Rough SMEM sizes (actual CuTe layouts have padding/alignment) + # sQ: (M, K_tile) BF16 * q_stage + sQ = 128 * k_tile * 2 * q_stage + # sK: (K_tile, N) BF16 * kv_stage + sK = k_tile * 128 * 2 * kv_stage + # sV: (K_tile_V, N) BF16 * kv_stage, where K_tile_V = pv_n_tile + sV = pv_n_tile * 128 * 2 * kv_stage + # sC: (M, N) BF16 * num_c_stage + num_c_stage = 1 if hd > 256 else 2 + sC = 128 * pv_n_tile * 2 * num_c_stage + # sP: only if SMEM-P + sP = 0 + if use_smem_p is None: + use_smem_p = hd > 64 + if use_smem_p: + sP = 128 * pv_n_tile * 2 * 1 # 1 stage + + total = sQ + sK + sV + sC + sP + limit = 232 * 1024 + + print(f"hd={hd}: k_tile={k_tile}, kv_stage={kv_stage}, pv_n_tile={pv_n_tile}, " + f"num_c_stage={num_c_stage}, use_smem_p={use_smem_p}") + print(f" sQ={sQ//1024}KB, sK={sK//1024}KB, sV={sV//1024}KB, sC={sC//1024}KB, sP={sP//1024}KB") + print(f" Total={total//1024}KB vs {limit//1024}KB limit → {'✅ FITS' if total <= limit else '❌ OVER'}") + print() + + # Now compute with overlap: sQ/sV share same region + if not use_smem_p: + overlap_total = max(sQ, sV) + sK + sC + print(f" With sQ/sV overlap: {overlap_total//1024}KB → {'✅ FITS' if overlap_total <= limit else '❌ OVER'}") + else: + overlap_total = max(sQ, sV) + sK + sC + sP + print(f" With sQ/sV overlap (+sP): {overlap_total//1024}KB → {'✅ FITS' if overlap_total <= limit else '❌ OVER'}") + + # Another option: also overlap sK with sC (K consumed before C written) + # After QK GEMM, sK is done. Then PV starts, writes O to TMEM, then epilogue writes C. + # But sK is needed for ALL kv_tiles... and sC is written after all PV is done. + # Actually sK is needed per kv_tile, so it stays alive through the whole QK loop. + # Overlap sK/sC doesn't work. + + # Option: reduce pv_n_tile to 128 + pv_n_tile_small = 128 + sV_small = pv_n_tile_small * 128 * 2 * kv_stage + sC_small = 128 * pv_n_tile_small * 2 * num_c_stage + total_small = sQ + sK + sV_small + sC_small + sP + print(f" With pv_n_tile=128: sV={sV_small//1024}KB, sC={sC_small//1024}KB, " + f"Total={total_small//1024}KB → {'✅ FITS' if total_small <= limit else '❌ OVER'}") + overlap_small = max(sQ, sV_small) + sK + sC_small + print(f" With pv_n_tile=128 + sQ/sV overlap: {overlap_small//1024}KB → {'✅ FITS' if overlap_small <= limit else '❌ OVER'}") + print() + + +print("=== SMEM Budget Analysis ===\n") +smem_budget(64) +smem_budget(128) +smem_budget(256) +smem_budget(512)