From 58ca480fd118d37ba648b2920e5263d6421c09e1 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 21 May 2026 17:49:26 +0000 Subject: [PATCH] Stage C: add validation harness with real softmax reference (C1) --- tests/unit/test_fmha_v3_softmax.py | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/unit/test_fmha_v3_softmax.py b/tests/unit/test_fmha_v3_softmax.py index 166dc904..ddd84e8e 100644 --- a/tests/unit/test_fmha_v3_softmax.py +++ b/tests/unit/test_fmha_v3_softmax.py @@ -381,3 +381,41 @@ class FmhaV3Softmax: c_pipe.producer_tail() tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) + +def test(): + """C1 validation harness: real softmax reference.""" + import math + torch.manual_seed(42) + for n in [128, 256, 384]: + m, hd = 128, HEAD_DIM + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda') + v_kernel = v.unsqueeze(-1) + c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + # Real softmax reference + qf = q[:,:,0].float() + kf = k[:,:,0].float() + attn = qf @ kf.T / math.sqrt(hd) + ref = torch.softmax(attn, dim=-1) @ v.float() + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + kernel = FmhaV3Softmax() + print(f'n={n}: Compiling...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + print(f'n={n}: tmem: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols}', flush=True) + print(f'n={n}: Running...', flush=True) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + max_err = (out - ref).abs().max().item() + print(f'FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {"PASS" if cos >= 0.999 else "FAIL"}', flush=True) + if cos < 0.999: + print(f' out[0,:4]={out[0,:4].tolist()} ref[0,:4]={ref[0,:4].tolist()}') + +if __name__ == '__main__': + test()