Sweep test: n=128,256,384,512,1024

This commit is contained in:
2026-05-22 21:31:15 +00:00
parent beaf60db5c
commit a751b3baf7

View File

@@ -0,0 +1,45 @@
"""Test stage C at n=384, 512, 1024 to check if pipeline cycling works."""
import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline, cutlass.torch as ct, cuda.bindings.driver as cuda
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import sys
sys.path.insert(0, '.')
from test_fmha_v3_stage_c import FmhaV3StageCMulti
# Sweep driver: for each K/V length n, run the FmhaV3StageCMulti kernel and
# compare its output against an fp32 PyTorch reference using cosine similarity.
# One report line is printed per size; once every size has been tried, the
# script exits with a non-zero status if any size failed or raised.

HEAD_DIM = 64
# Number of rows in Q and in the output buffer; identical for every case.
Q_ROWS = 128
# Divisor used for the "tiles" count in the report line.
# NOTE(review): presumably the kernel's K/V tile length -- confirm against
# FmhaV3StageCMulti.
KV_TILE = 128
# Minimum cosine similarity against the reference for a case to count as PASS.
COS_PASS = 0.99
# At or below this similarity the ratio diagnostics are skipped.
COS_NOISE = 0.01

# Sizes that printed FAIL or ERROR; decides the exit status after the sweep.
failed_sizes = []

for n in [128, 256, 384, 512, 1024]:
    # Reseed on every iteration so each size gets reproducible inputs
    # regardless of which sizes ran before it.
    torch.manual_seed(42)
    q = torch.randn(Q_ROWS, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
    # Q, K and the output carry a trailing axis of size 1; V gets the same
    # axis appended before it is handed to the kernel.
    v_kernel = v.unsqueeze(-1)
    c = torch.zeros(Q_ROWS, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')

    # fp32 reference: softmax(Q @ K^T / sqrt(HEAD_DIM)) @ V.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    scale = 1.0 / math.sqrt(HEAD_DIM)
    ref = torch.softmax(qf @ kf.T * scale, dim=-1) @ v.float()

    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = FmhaV3StageCMulti(s_k=n)
    print(f'n={n}: Compiling...', flush=True)

    # Deliberately broad: a compile or launch failure at one size must not
    # stop the remaining sizes from being tried.
    try:
        compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
        compiled(mQ, mK, mV, mC, stream)
        torch.cuda.synchronize()
        out = c[:, :, 0].float()
        cos = torch.nn.functional.cosine_similarity(
            out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
        ).item()
        # A NaN similarity compares False here and is therefore reported as FAIL.
        passed = cos >= COS_PASS
        status = "PASS" if passed else "FAIL"
        print(f'FMHA n={n} ({n // KV_TILE} tiles): cos {cos:.6f} {status}')
        if not passed:
            failed_sizes.append(n)
        if COS_NOISE < cos < COS_PASS:
            # Partially correlated output: dump the first four values of row 0
            # and their mean out/ref ratio.
            # NOTE(review): presumably meant to expose a constant scale error.
            out_head = out[0, :4]
            ref_head = ref[0, :4]
            ratio = (out_head / ref_head).mean().item()
            print(f' out[0,:4]={out_head.tolist()}')
            print(f' ref[0,:4]={ref_head.tolist()}')
            print(f' ratio out/ref: {ratio:.4f}')
    except Exception as e:
        print(f'FMHA n={n}: ERROR: {e}')
        failed_sizes.append(n)

# Surface failures through the exit status (message goes to stderr, status 1)
# so wrappers and CI can detect a bad sweep; stdout is unchanged.
if failed_sizes:
    raise SystemExit(f'FMHA sweep: FAIL/ERROR at n={failed_sizes}')