diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py
index d1d356b9..d84c773e 100644
--- a/tests/unit/test_fmha_v3_stage_c_full.py
+++ b/tests/unit/test_fmha_v3_stage_c_full.py
@@ -327,35 +327,37 @@ class FmhaV3StageC:
 def test():
     torch.manual_seed(42)
     for n in [128]:
-        m, hd = 128, HEAD_DIM
-        q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
-        k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
-        v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
-        v_kernel = v.unsqueeze(-1)
-        c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
-        # Reference: softmax(Q @ K^T / sqrt(d)) @ V
-        qf = q[:,:,0].float(); kf = k[:,:,0].float()
-        scale = 1.0 / math.sqrt(hd)
-        attn = qf @ kf.T * scale
-        attn = torch.softmax(attn, dim=-1)
-        ref = attn @ v.float()
-        mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
-        mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
-        mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
-        mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
-        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
-        kernel = FmhaV3StageC()
-        print(f'n={n}: Compiling...', flush=True)
-        compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
-        print(f'n={n}: tmem_offsets: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols}', flush=True)
-        print(f'n={n}: Running...', flush=True)
-        compiled(mQ, mK, mV, mC, stream)
-        torch.cuda.synchronize()
-        out = c[:,:,0].float()
-        cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
-        print(f'FMHA Stage-C n={n}: cosine {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
-        if cos < 0.99:
-            print(f' out[0,:4]={out[0,:4].tolist()} ref[0,:4]={ref[0,:4].tolist()}')
+        for seed in [42, 123, 999]:
+            torch.manual_seed(seed)
+            m, hd = 128, HEAD_DIM
+            q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
+            k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
+            v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
+            v_kernel = v.unsqueeze(-1)
+            c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
+            qf = q[:,:,0].float(); kf = k[:,:,0].float()
+            scale = 1.0 / math.sqrt(hd)
+            attn = qf @ kf.T * scale
+            attn = torch.softmax(attn, dim=-1)
+            ref = attn @ v.float()
+            mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
+            mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
+            mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
+            mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
+            stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+            kernel = FmhaV3StageC()
+            if seed == 42:
+                print(f'seed={seed}: Compiling...', flush=True)
+            compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
+            if seed == 42:
+                print(f'tmem_offsets: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols}', flush=True)
+            compiled(mQ, mK, mV, mC, stream)
+            torch.cuda.synchronize()
+            out = c[:,:,0].float()
+            cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
+            print(f'FMHA Stage-C n={n} seed={seed}: cosine {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
+            if cos < 0.99:
+                print(f' out[0,:4]={out[0,:4].tolist()} ref[0,:4]={ref[0,:4].tolist()}')
 
 if __name__ == '__main__':
     test()