"""Test the CUTLASS reference Blackwell FMHA on the B200.

Does it actually work multi-tile?
"""
import math
import sys
import traceback

# The CuTe DSL examples are not an installed package; make them importable
# before any of the third-party imports below are resolved.
sys.path.insert(0, '/root/cutlass/examples/python/CuTeDSL')

import torch
import cutlass
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from cute.blackwell.kernel.attention.fmha.fmha import (
    BlackwellFusedMultiHeadAttentionForward,
    FusedMask,
    FusedMaskScale,
    FMHA_OperandMajorMode,
)
from cutlass.utils import LayoutEnum


def test_reference():
    """Run the reference FMHA kernel for several KV lengths and check it.

    For each KV length ``n`` a fresh set of random bfloat16 Q/K/V tensors is
    generated (seeded, so runs are reproducible), a float32 softmax-attention
    reference is computed with PyTorch, and the kernel output is compared to
    that reference via cosine similarity over the flattened outputs.

    One PASS/FAIL line is printed per ``n``. On FAIL the first four values of
    row 0 of both outputs are printed as well. Any exception raised while
    running or checking the kernel is reported with its traceback and the
    loop moves on to the next ``n``; nothing is raised to the caller.
    """
    HEAD_DIM = 64
    COS_THRESHOLD = 0.99
    # Used only to report a tile count in the output line.
    # NOTE(review): assumes the kernel tiles the KV dimension by 128 -- confirm.
    KV_TILE = 128

    for n in [128, 256, 512]:
        torch.manual_seed(42)
        m = 128
        batch = 1

        q = torch.randn(batch, 1, m, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        k = torch.randn(batch, 1, n, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        v = torch.randn(batch, 1, n, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        c = torch.zeros(batch, 1, m, HEAD_DIM, dtype=torch.bfloat16, device='cuda')

        # Reference: PyTorch softmax attention, computed in float32 on the
        # single (batch 0, head 0) slice.
        qf = q[0, 0].float()
        kf = k[0, 0].float()
        vf = v[0, 0].float()
        scale = 1.0 / math.sqrt(HEAD_DIM)
        attn = qf @ kf.T * scale
        attn = torch.softmax(attn, dim=-1)
        ref = attn @ vf

        # Hand the kernel the stream PyTorch is currently using.
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        kernel = BlackwellFusedMultiHeadAttentionForward(
            q_major_mode=FMHA_OperandMajorMode.K,
            k_major_mode=FMHA_OperandMajorMode.K,
            v_major_mode=FMHA_OperandMajorMode.MN,
            o_major_mode=FMHA_OperandMajorMode.K,
            q_head_dim=HEAD_DIM,
            kv_head_dim=HEAD_DIM,
            num_q_heads=1,
            num_kv_heads=1,
            q_dtype=cutlass.BFloat16,
            k_dtype=cutlass.BFloat16,
            v_dtype=cutlass.BFloat16,
            o_dtype=cutlass.BFloat16,
            acc_dtype=cutlass.Float32,
            epilogue_dtype=cutlass.Float32,
            use_2cta_instrs=False,
        )

        print(f'n={n}: Compiling reference FMHA...', flush=True)
        try:
            kernel.run(q, k, v, c, stream)
            torch.cuda.synchronize()

            out = c[0, 0].float()
            cos = torch.nn.functional.cosine_similarity(
                out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
            ).item()

            n_tiles = n // KV_TILE
            # Decide once. A NaN similarity compares False against the
            # threshold in both directions, so deriving the label and the
            # diagnostic branch from two separate comparisons would print
            # "FAIL" yet skip the dump for exactly the case that needs it.
            passed = cos >= COS_THRESHOLD
            status = "PASS" if passed else "FAIL"
            print(f'Reference FMHA n={n} ({n_tiles} tiles): cos {cos:.6f} {status}')
            if not passed:
                print(f' out[0,:4]={out[0,:4].tolist()}')
                print(f' ref[0,:4]={ref[0,:4].tolist()}')
        except Exception:
            # Best-effort sweep: report the failure for this n and keep going
            # so the remaining sizes are still exercised.
            print(f'Reference FMHA n={n}: FAILED')
            traceback.print_exc()


if __name__ == '__main__':
    test_reference()