diff --git a/tests/unit/test_d1_3_smem_vs_tmem.py b/tests/unit/test_d1_3_smem_vs_tmem.py new file mode 100644 index 00000000..ff956e09 --- /dev/null +++ b/tests/unit/test_d1_3_smem_vs_tmem.py @@ -0,0 +1,88 @@ +""" +D1.3 SMEM-P: Focused test of coordinate-indexed SMEM-P write. +Tests at hd=64 with SMEM-P, compares to TMEM-P result. +Also tests hd=128 and hd=256. +""" +import torch, math +import cutlass, cutlass.cute as cute, cutlass.utils as utils +from cutlass.cute.nvgpu import tcgen05 +from cutlass import Float32, BFloat16 +from cutlass.utils import LayoutEnum +import cutlass.torch as ct +import cuda.bindings.driver as cuda + +from dsv4.kernels.attention.fmha import FmhaKernel + + +def test_hd(hd, use_smem_p, s_k=128): + pv_n = min(hd, 256) + q = torch.randn(128, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(s_k, pv_n, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(128, pv_n, 1, dtype=torch.bfloat16, device='cuda') + + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + vf = v.float() + scale = 1.0 / math.sqrt(hd) + attn = qf @ kf.T * scale + ref = torch.softmax(attn, dim=-1) @ vf + + kern = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=True) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + v_tile = v.unsqueeze(-1) + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + + mode = "SMEM-P" if use_smem_p else "TMEM-P" + print(f'Compiling hd={hd} {mode}...', flush=True) + compiled = cute.compile(kern, mQ, mK, mV, mC, stream) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + + out = c[:, :, 0].float() + cos = torch.nn.functional.cosine_similarity( + out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) + ).item() + max_abs = (out - ref).abs().max().item() + has_nan = torch.isnan(out).any().item() + has_inf = torch.isinf(out).any().item() + + print(f' hd={hd} {mode}: cos={cos:.6f} max_abs={max_abs:.6f} NaN={has_nan} Inf={has_inf}') + + # Print first few values for comparison + print(f' out[0,:4]={out[0,:4].tolist()}') + print(f' ref[0,:4]={ref[0,:4].tolist()}') + print() + + return cos + + +if __name__ == '__main__': + print("=== SMEM-P vs TMEM-P Comparison ===\n") + + # Baseline: TMEM-P at hd=64 (proven path) + cos_tmem = test_hd(64, use_smem_p=False) + + # Test: SMEM-P at hd=64 (should match TMEM-P if correct) + cos_smem = test_hd(64, use_smem_p=True) + + # Test: SMEM-P at hd=128 + cos_128 = test_hd(128, use_smem_p=True) + + # Test: SMEM-P at hd=256 + cos_256 = test_hd(256, use_smem_p=True) + + print("=== Summary ===") + print(f"hd=64 TMEM-P: cos={cos_tmem:.6f} (baseline)") + print(f"hd=64 SMEM-P: cos={cos_smem:.6f} (should match TMEM-P)") + print(f"hd=128 SMEM-P: cos={cos_128:.6f}") + print(f"hd=256 SMEM-P: cos={cos_256:.6f}") + + if abs(cos_tmem - cos_smem) < 0.01: + print("\n✅ SMEM-P matches TMEM-P at hd=64 — coordinate mapping is correct!") + else: + print(f"\n⚠️ SMEM-P differs from TMEM-P by {abs(cos_tmem - cos_smem):.6f} — coordinate mapping has issues")