"""Test the CUTLASS reference Blackwell FMHA on the B200.

Does it actually work multi-tile?

The reference FMHA kernel is compiled and launched for several key/value
lengths, and its output is compared (cosine similarity) against an attention
result computed with plain PyTorch ops in float32.
"""
import math
import sys

# The FMHA example package is imported from a CUTLASS checkout at this path.
# NOTE(review): hard-coded absolute path -- the test only runs on a machine
# laid out like this.
sys.path.insert(0, '/root/cutlass/examples/python/CuTeDSL')

import torch
import cutlass
import cutlass.cute as cute
import cuda.bindings.driver as cuda

# `cute` below is the module name looked up by the import system (presumably
# resolved through the sys.path entry added above), not the `cutlass.cute`
# alias bound a few lines earlier.
from cute.blackwell.kernel.attention.fmha.fmha import FMHA, FusedMask, FusedMaskScale, FMHA_OperandMajorMode

from cutlass.utils import LayoutEnum
# Not referenced below; kept so this module still imports the same packages
# it did when this import lived inside the test loop.
import cutlass.torch as ct  # noqa: F401

HEAD_DIM = 64
# Number of query rows used for every case.
Q_LEN = 128
# Key/value lengths exercised by the test.
KV_LENS = (128, 256, 512)
# Divisor used to report how many "tiles" a key/value length spans.
KV_TILE = 128
# Minimum cosine similarity between kernel output and reference to pass.
COS_THRESHOLD = 0.99


def _attention_reference(q, k, v):
    """Return softmax(q @ k.T / sqrt(head_dim)) @ v, computed in float32.

    Args:
        q: tensor indexed as ``q[:, :, 0]`` to obtain a (rows, head_dim) matrix.
        k: tensor indexed as ``k[:, :, 0]`` to obtain a (rows, head_dim) matrix.
        v: 2-D tensor with one row per key.

    Returns:
        A float32 tensor with one row per query row.
    """
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    scale = 1.0 / math.sqrt(qf.shape[-1])
    probs = torch.softmax(qf @ kf.T * scale, dim=-1)
    return probs @ v.float()


def _build_kernel(m, n):
    """Construct the reference FMHA kernel object and set its dims to (m, n)."""
    # All four operands use ROW_MAJOR (the original author annotated this as
    # "K major" for Q).
    row_major = LayoutEnum.ROW_MAJOR
    kernel = FMHA(
        q_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(row_major),
        k_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(row_major),
        v_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(row_major),
        o_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(row_major),
        q_head_dim=HEAD_DIM,
        kv_head_dim=HEAD_DIM,
        num_q_heads=1,
        num_kv_heads=1,
        q_dtype=cutlass.BFloat16,
        k_dtype=cutlass.BFloat16,
        v_dtype=cutlass.BFloat16,
        o_dtype=cutlass.BFloat16,
        acc_dtype=cutlass.Float32,
        epilogue_dtype=cutlass.Float32,
        use_2cta_instrs=False,
    )
    kernel.set_dims(m, n)
    return kernel


def test_reference_fmha():
    """Run the reference FMHA for each length in KV_LENS and check its output.

    For every key/value length a fresh set of bfloat16 inputs is generated
    (seed reset to 42 each time), the kernel is compiled and launched on the
    current CUDA stream, and the output is compared with the PyTorch
    reference via cosine similarity.

    A progress line is printed for every length, and all lengths are
    attempted even if an earlier one fails.

    Raises:
        AssertionError: after all lengths have been reported, if any length
            raised during compile/launch or produced a cosine similarity
            that is not >= COS_THRESHOLD (including NaN).
    """
    failures = []

    for n in KV_LENS:
        torch.manual_seed(42)
        m = Q_LEN

        q = torch.randn(m, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
        k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
        v = torch.randn(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
        c = torch.zeros(m, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')

        ref = _attention_reference(q, k, v)

        # V is created 2-D; append a trailing singleton dim so it has the
        # same rank as Q/K/C (the original comment calls this the batch dim).
        v_tensor = v.unsqueeze(-1)

        kernel = _build_kernel(m, n)
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        print(f'n={n}: Compiling reference FMHA...', flush=True)
        try:
            compiled = cute.compile(kernel, q, k, v_tensor, c, stream)
            compiled(q, k, v_tensor, c, stream)
            torch.cuda.synchronize()
        except Exception as e:
            # Broad on purpose: report the error for this length, keep going
            # with the remaining lengths, and fail the test at the end.
            print(f'Reference FMHA n={n}: FAILED with error: {e}')
            failures.append(f'n={n}: {type(e).__name__}: {e}')
            continue

        out = c[:, :, 0].float()
        cos = torch.nn.functional.cosine_similarity(
            out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
        ).item()
        n_tiles = n // KV_TILE
        # A single flag drives both the label and the diagnostics so that a
        # NaN similarity is treated as a failure everywhere.
        passed = cos >= COS_THRESHOLD
        print(f'Reference FMHA n={n} ({n_tiles} tiles): cos {cos:.6f} {"PASS" if passed else "FAIL"}')
        if not passed:
            print(f' out[0,:4]={out[0,:4].tolist()}')
            print(f' ref[0,:4]={ref[0,:4].tolist()}')
            failures.append(f'n={n}: cos {cos:.6f} is not >= {COS_THRESHOLD}')

    if failures:
        raise AssertionError('Reference FMHA failed: ' + '; '.join(failures))


if __name__ == '__main__':
    test_reference_fmha()