# Files
# nvfp4-megamoe-kernel/tests/unit/test_d1_diag.py
# 2026-05-23 03:25:29 +00:00

# 93 lines
# 3.7 KiB
# Python

"""D1 diagnostic: Print TMEM budget and shapes at hd=256."""
import torch, math
import cutlass, cutlass.cute as cute
import cutlass.torch as ct
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
# Problem size: m query rows, n_kv key/value rows, head dim hd.
hd = 256
n_kv = 128
m = 128

# Seeded bf16 inputs on the GPU. The randn calls stay in q, k, v order so the
# seeded values match across runs.
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)  # (n_kv, hd) -> (n_kv, hd, 1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')

# Operand / accumulator element types and CTA grouping for both MMAs.
q_dtype = cutlass.BFloat16
qk_acc_dtype = cutlass.Float32
cta_group = tcgen05.CtaGroup.ONE


def _as_dynamic_cute_tensor(t):
    """Wrap torch tensor *t* via DLPack and mark its layout dynamic along its leading dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


# Wrapped in the same order as before: q, k, v_kernel, c.
mQ, mK, mV, mC = (_as_dynamic_cute_tensor(t) for t in (q, k, v_kernel, c))
# Operand major modes for the QK GEMM: Q is the A operand, K is the B operand.
a_major = LayoutEnum.from_tensor(mQ).mma_major_mode()
b_major = LayoutEnum.from_tensor(mK).mma_major_mode()

# N tile of the PV MMA, capped at 256 (presumably the tcgen05 MMA N limit --
# confirm against the CUTLASS tcgen05 docs).
pv_n_tile = min(hd, 256)

# Re-view V, stored row-major as (n_kv, hd, 1), as the PV B operand with
# shape (hd, n_kv, 1): hd is the stride-1 mode and consecutive KV rows are hd
# elements apart. The view must span the full head dim and use the real row
# pitch (hd), not the MMA N tile. Using pv_n_tile here was only correct while
# hd <= 256, where the two coincide.
v_fmha = cute.make_tensor(
    mV.iterator,
    cute.make_layout((hd, n_kv, 1), stride=(1, hd, hd * n_kv)),
)
v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
print(f"a_major={a_major}, b_major={b_major}, v_major={v_major}")
print(f"Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}")
print(f"mQ shape: {mQ.shape}, mK shape: {mK.shape}, mV shape: {mV.shape}, mC shape: {mC.shape}")
# QK GEMM: A = Q, B = K, both read from SMEM, 128x128 accumulator tile.
qk_mma = utils.sm100.make_trivial_tiled_mma(
    q_dtype, q_dtype, a_major, b_major, qk_acc_dtype, cta_group,
    (128, 128), tcgen05.OperandSource.SMEM,
)
print(f"\nQK MMA shape_mnk: {qk_mma.shape_mnk}")

# PV GEMM: the A operand (P) is read from TMEM and is K-major; B is the V view.
# OperandMajorMode is taken from the tcgen05 module imported at the top of
# this file. No import in this file brings `cutlass.nvgpu` into scope.
pv_mma = utils.sm100.make_trivial_tiled_mma(
    q_dtype, q_dtype, tcgen05.OperandMajorMode.K, v_major, qk_acc_dtype, cta_group,
    (128, pv_n_tile), tcgen05.OperandSource.TMEM,
)
print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}")

# K extent of a single MMA instruction for each GEMM.
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
# QK CTA tile: K spans 4 MMA instructions (hard-coded choice).
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
# PV K tile: 128 rounded down to a whole number of MMA instructions.
pv_k_inner = pv_ik * (128 // pv_ik)
pv_mma_tiler = (128, pv_n_tile, pv_k_inner)
print(f"qk_ik={qk_ik}, qk_mma_tiler={qk_mma_tiler}")
print(f"pv_ik={pv_ik}, pv_mma_tiler={pv_mma_tiler}")
# GMEM tile views. NOTE: per the CuTe DSL local_tile convention (as used in
# the CUTLASS Blackwell examples), a 2-mode tiler on an (X, Y, L) tensor
# yields (tileX, tileY, RestX, RestY, RestL).
gQ = cute.local_tile(mQ, cute.slice_(qk_mma_tiler, (None, 0, None)), (None, None, None))
gK = cute.local_tile(mK, cute.slice_(qk_mma_tiler, (0, None, None)), (None, None, None))
# The PV B-operand tiler is (N=pv_n_tile, K=pv_k_inner), i.e. (hd, n_kv)
# ordered. It must be applied to the (hd, n_kv, 1) view v_fmha, not to mV,
# whose modes are (n_kv, hd, 1).
gV = cute.local_tile(v_fmha, cute.slice_(pv_mma_tiler, (0, None, None)), (None, None, None))
gC = cute.local_tile(mC, cute.slice_(pv_mma_tiler, (None, None, 0)), (None, None, None))
print(f"\ngQ shape: {gQ.shape}")
print(f"gK shape: {gK.shape}")
print(f"gV shape: {gV.shape}")
print(f"gC shape: {gC.shape}")
# gK is (bN, bK, RestN, RestK, RestL): mode 2 counts tiles along n_kv,
# mode 3 counts QK k-tiles along hd (previously mislabelled as n_kv tiles).
print(f"n_kv_tiles: {cute.size(gK, mode=[2])}")
print(f"qk_k_tiles: {cute.size(gK, mode=[3])}")
# ---- TMEM budget -----------------------------------------------------------
# Accumulator (C) fragments for slice 0 of each tiled MMA; their column
# footprints drive the budget below.
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)

# Column layout: S occupies [0, s_cols). P starts at tmem_p0_offset, packing
# q_dtype values into qk_acc_dtype-wide columns, so it may overlap S. O
# starts after whichever of S / P ends last, rounded up to a multiple of 32.
tmem_p0_offset = 32
p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width
o_cols = find_tmem_tensor_col_offset(tOtO)
s_cols = qk_mma_tiler[1]
p_end = tmem_p0_offset + p_cols_fp32
o_after = s_cols if s_cols >= p_end else p_end
tmem_o0_offset = -(-o_after // 32) * 32  # ceil(o_after / 32) * 32
total = tmem_o0_offset + o_cols

print(f"\n=== TMEM Budget (hd={hd}) ===")
print(f"S={s_cols}, P_fp32={p_cols_fp32}, P_end={p_end}, O_offset={tmem_o0_offset}, O_cols={o_cols}")
print(f"Total={total}, Fits={total<=512}")

# ---- QK / PV C-fragment info -----------------------------------------------
print(f"\ntStS shape: {tStS.shape}, cosize: {cute.cosize(tStS.layout)}")
print(f"tOtO shape: {tOtO.shape}, cosize: {cute.cosize(tOtO.layout)}")