D1.4: Reduce pv_n_tile to 128 for hd=512 to fit SMEM budget (192KB)
This commit is contained in:
@@ -20,7 +20,12 @@ class FmhaKernel:
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
|
||||
self.pv_n_tile = min(head_dim, 256)
|
||||
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
|
||||
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
|
||||
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
|
||||
if head_dim > 256:
|
||||
self.pv_n_tile = 128
|
||||
self.n_pv_tiles = head_dim // self.pv_n_tile
|
||||
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
|
||||
self.normalize = normalize # D5a: False = emit un-normalized O + lse
|
||||
@@ -155,6 +160,8 @@ class FmhaKernel:
|
||||
|
||||
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
|
||||
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
|
||||
# sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB.
|
||||
# TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput.
|
||||
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
|
||||
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
|
||||
# sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM)
|
||||
|
||||
71
tests/unit/test_smem_budget.py
Normal file
71
tests/unit/test_smem_budget.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Diagnostic: Print exact SMEM budget for various HEAD_DIM values."""
|
||||
import math
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def smem_budget(hd, s_k=128, use_smem_p=None):
|
||||
"""Reproduce FmhaKernel._setup SMEM allocation sizes."""
|
||||
kv_stage = 1 if hd > 128 else 2
|
||||
q_stage = 1
|
||||
k_tile = min(hd, 256)
|
||||
pv_n_tile = min(hd, 256)
|
||||
pv_mma_tiler = (128, pv_n_tile, k_tile)
|
||||
qk_mma_tiler = (128, 128, k_tile)
|
||||
|
||||
# Rough SMEM sizes (actual CuTe layouts have padding/alignment)
|
||||
# sQ: (M, K_tile) BF16 * q_stage
|
||||
sQ = 128 * k_tile * 2 * q_stage
|
||||
# sK: (K_tile, N) BF16 * kv_stage
|
||||
sK = k_tile * 128 * 2 * kv_stage
|
||||
# sV: (K_tile_V, N) BF16 * kv_stage, where K_tile_V = pv_n_tile
|
||||
sV = pv_n_tile * 128 * 2 * kv_stage
|
||||
# sC: (M, N) BF16 * num_c_stage
|
||||
num_c_stage = 1 if hd > 256 else 2
|
||||
sC = 128 * pv_n_tile * 2 * num_c_stage
|
||||
# sP: only if SMEM-P
|
||||
sP = 0
|
||||
if use_smem_p is None:
|
||||
use_smem_p = hd > 64
|
||||
if use_smem_p:
|
||||
sP = 128 * pv_n_tile * 2 * 1 # 1 stage
|
||||
|
||||
total = sQ + sK + sV + sC + sP
|
||||
limit = 232 * 1024
|
||||
|
||||
print(f"hd={hd}: k_tile={k_tile}, kv_stage={kv_stage}, pv_n_tile={pv_n_tile}, "
|
||||
f"num_c_stage={num_c_stage}, use_smem_p={use_smem_p}")
|
||||
print(f" sQ={sQ//1024}KB, sK={sK//1024}KB, sV={sV//1024}KB, sC={sC//1024}KB, sP={sP//1024}KB")
|
||||
print(f" Total={total//1024}KB vs {limit//1024}KB limit → {'✅ FITS' if total <= limit else '❌ OVER'}")
|
||||
print()
|
||||
|
||||
# Now compute with overlap: sQ/sV share same region
|
||||
if not use_smem_p:
|
||||
overlap_total = max(sQ, sV) + sK + sC
|
||||
print(f" With sQ/sV overlap: {overlap_total//1024}KB → {'✅ FITS' if overlap_total <= limit else '❌ OVER'}")
|
||||
else:
|
||||
overlap_total = max(sQ, sV) + sK + sC + sP
|
||||
print(f" With sQ/sV overlap (+sP): {overlap_total//1024}KB → {'✅ FITS' if overlap_total <= limit else '❌ OVER'}")
|
||||
|
||||
# Another option: also overlap sK with sC (K consumed before C written)
|
||||
# After QK GEMM, sK is done. Then PV starts, writes O to TMEM, then epilogue writes C.
|
||||
# But sK is needed for ALL kv_tiles... and sC is written after all PV is done.
|
||||
# Actually sK is needed per kv_tile, so it stays alive through the whole QK loop.
|
||||
# Overlap sK/sC doesn't work.
|
||||
|
||||
# Option: reduce pv_n_tile to 128
|
||||
pv_n_tile_small = 128
|
||||
sV_small = pv_n_tile_small * 128 * 2 * kv_stage
|
||||
sC_small = 128 * pv_n_tile_small * 2 * num_c_stage
|
||||
total_small = sQ + sK + sV_small + sC_small + sP
|
||||
print(f" With pv_n_tile=128: sV={sV_small//1024}KB, sC={sC_small//1024}KB, "
|
||||
f"Total={total_small//1024}KB → {'✅ FITS' if total_small <= limit else '❌ OVER'}")
|
||||
overlap_small = max(sQ, sV_small) + sK + sC_small
|
||||
print(f" With pv_n_tile=128 + sQ/sV overlap: {overlap_small//1024}KB → {'✅ FITS' if overlap_small <= limit else '❌ OVER'}")
|
||||
print()
|
||||
|
||||
|
||||
print("=== SMEM Budget Analysis ===\n")
|
||||
smem_budget(64)
|
||||
smem_budget(128)
|
||||
smem_budget(256)
|
||||
smem_budget(512)
|
||||
Reference in New Issue
Block a user