D1.4: Reduce pv_n_tile to 128 for hd=512 to fit SMEM budget (192KB)

This commit is contained in:
2026-05-24 08:07:32 +00:00
parent bc8331c9eb
commit 090acfc0ce
2 changed files with 79 additions and 1 deletions

View File

@@ -20,7 +20,12 @@ class FmhaKernel:
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
self.pv_n_tile = min(head_dim, 256)
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
if head_dim > 256:
self.pv_n_tile = 128
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.normalize = normalize # D5a: False = emit un-normalized O + lse
@@ -155,6 +160,8 @@ class FmhaKernel:
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
# sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB.
# TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput.
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
# sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM)

View File

@@ -0,0 +1,71 @@
"""Diagnostic: Print exact SMEM budget for various HEAD_DIM values."""
import math
from dsv4.kernels.attention.fmha import FmhaKernel
def smem_budget(hd, s_k=128, use_smem_p=None):
"""Reproduce FmhaKernel._setup SMEM allocation sizes."""
kv_stage = 1 if hd > 128 else 2
q_stage = 1
k_tile = min(hd, 256)
pv_n_tile = min(hd, 256)
pv_mma_tiler = (128, pv_n_tile, k_tile)
qk_mma_tiler = (128, 128, k_tile)
# Rough SMEM sizes (actual CuTe layouts have padding/alignment)
# sQ: (M, K_tile) BF16 * q_stage
sQ = 128 * k_tile * 2 * q_stage
# sK: (K_tile, N) BF16 * kv_stage
sK = k_tile * 128 * 2 * kv_stage
# sV: (K_tile_V, N) BF16 * kv_stage, where K_tile_V = pv_n_tile
sV = pv_n_tile * 128 * 2 * kv_stage
# sC: (M, N) BF16 * num_c_stage
num_c_stage = 1 if hd > 256 else 2
sC = 128 * pv_n_tile * 2 * num_c_stage
# sP: only if SMEM-P
sP = 0
if use_smem_p is None:
use_smem_p = hd > 64
if use_smem_p:
sP = 128 * pv_n_tile * 2 * 1 # 1 stage
total = sQ + sK + sV + sC + sP
limit = 232 * 1024
print(f"hd={hd}: k_tile={k_tile}, kv_stage={kv_stage}, pv_n_tile={pv_n_tile}, "
f"num_c_stage={num_c_stage}, use_smem_p={use_smem_p}")
print(f" sQ={sQ//1024}KB, sK={sK//1024}KB, sV={sV//1024}KB, sC={sC//1024}KB, sP={sP//1024}KB")
print(f" Total={total//1024}KB vs {limit//1024}KB limit → {'✅ FITS' if total <= limit else '❌ OVER'}")
print()
# Now compute with overlap: sQ/sV share same region
if not use_smem_p:
overlap_total = max(sQ, sV) + sK + sC
print(f" With sQ/sV overlap: {overlap_total//1024}KB → {'✅ FITS' if overlap_total <= limit else '❌ OVER'}")
else:
overlap_total = max(sQ, sV) + sK + sC + sP
print(f" With sQ/sV overlap (+sP): {overlap_total//1024}KB → {'✅ FITS' if overlap_total <= limit else '❌ OVER'}")
# Another option: also overlap sK with sC (K consumed before C written)
# After QK GEMM, sK is done. Then PV starts, writes O to TMEM, then epilogue writes C.
# But sK is needed for ALL kv_tiles... and sC is written after all PV is done.
# Actually sK is needed per kv_tile, so it stays alive through the whole QK loop.
# Overlap sK/sC doesn't work.
# Option: reduce pv_n_tile to 128
pv_n_tile_small = 128
sV_small = pv_n_tile_small * 128 * 2 * kv_stage
sC_small = 128 * pv_n_tile_small * 2 * num_c_stage
total_small = sQ + sK + sV_small + sC_small + sP
print(f" With pv_n_tile=128: sV={sV_small//1024}KB, sC={sC_small//1024}KB, "
f"Total={total_small//1024}KB → {'✅ FITS' if total_small <= limit else '❌ OVER'}")
overlap_small = max(sQ, sV_small) + sK + sC_small
print(f" With pv_n_tile=128 + sQ/sV overlap: {overlap_small//1024}KB → {'✅ FITS' if overlap_small <= limit else '❌ OVER'}")
print()
print("=== SMEM Budget Analysis ===\n")
smem_budget(64)
smem_budget(128)
smem_budget(256)
smem_budget(512)