From a751b3baf7f879ff55a04e973c057cf018c5e885 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:31:15 +0000 Subject: [PATCH] Sweep test: n=128,256,384,512,1024 --- tests/unit/test_stage_c_sweep.py | 45 ++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 tests/unit/test_stage_c_sweep.py diff --git a/tests/unit/test_stage_c_sweep.py b/tests/unit/test_stage_c_sweep.py new file mode 100644 index 00000000..a2455435 --- /dev/null +++ b/tests/unit/test_stage_c_sweep.py @@ -0,0 +1,45 @@ +"""Test stage C at n=384, 512, 1024 to check if pipeline cycling works.""" +import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline, cutlass.torch as ct, cuda.bindings.driver as cuda +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass import Float32, BFloat16, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset +import sys +sys.path.insert(0, '.') +from test_fmha_v3_stage_c import FmhaV3StageCMulti + +HEAD_DIM = 64 +for n in [128, 256, 384, 512, 1024]: + torch.manual_seed(42) + q = torch.randn(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda') + v_kernel = v.unsqueeze(-1) + c = torch.zeros(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') + + qf = q[:,:,0].float(); kf = k[:,:,0].float() + scale = 1.0/math.sqrt(HEAD_DIM) + ref = torch.softmax(qf @ kf.T * scale, dim=-1) @ v.float() + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + kernel = FmhaV3StageCMulti(s_k=n) + print(f'n={n}: Compiling...', flush=True) + try: + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + out = c[:,:,0].float() + cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item() + print(f'FMHA n={n} ({n//128} tiles): cos {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}') + if cos < 0.99 and cos > 0.01: + ratio = (out[0,:4] / ref[0,:4]).mean().item() + print(f' out[0,:4]={out[0,:4].tolist()}') + print(f' ref[0,:4]={ref[0,:4].tolist()}') + print(f' ratio out/ref: {ratio:.4f}') + except Exception as e: + print(f'FMHA n={n}: ERROR: {e}')