From 4fc264e0340a30552ceb7eefdfbfa79865c28500 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 21:12:49 +0000 Subject: [PATCH] Test reference FMHA with proper API --- tests/test_reference_fmha.py | 80 ++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 45 deletions(-) diff --git a/tests/test_reference_fmha.py b/tests/test_reference_fmha.py index 0d2ff910..c4f3da3c 100644 --- a/tests/test_reference_fmha.py +++ b/tests/test_reference_fmha.py @@ -7,49 +7,43 @@ import torch import math import cutlass import cutlass.cute as cute +import cutlass.torch as ct import cuda.bindings.driver as cuda -from cute.blackwell.kernel.attention.fmha.fmha import FMHA, FusedMask, FusedMaskScale, FMHA_OperandMajorMode +from cute.blackwell.kernel.attention.fmha.fmha import ( + BlackwellFusedMultiHeadAttentionForward, + FusedMask, FusedMaskScale, FMHA_OperandMajorMode +) +from cutlass.utils import LayoutEnum -def test_reference_fmha(): +def test_reference(): HEAD_DIM = 64 for n in [128, 256, 512]: torch.manual_seed(42) m = 128 - - q = torch.randn(m, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') - k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') - v = torch.randn(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda') - c = torch.zeros(m, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda') - - # Reference - qf = q[:, :, 0].float() - kf = k[:, :, 0].float() + batch = 1 + + q = torch.randn(batch, 1, m, HEAD_DIM, dtype=torch.bfloat16, device='cuda') + k = torch.randn(batch, 1, n, HEAD_DIM, dtype=torch.bfloat16, device='cuda') + v = torch.randn(batch, 1, n, HEAD_DIM, dtype=torch.bfloat16, device='cuda') + c = torch.zeros(batch, 1, m, HEAD_DIM, dtype=torch.bfloat16, device='cuda') + + # Reference: PyTorch softmax attention + qf = q[0, 0].float() + kf = k[0, 0].float() + vf = v[0, 0].float() scale = 1.0 / math.sqrt(HEAD_DIM) attn = qf @ kf.T * scale attn = torch.softmax(attn, dim=-1) - ref = attn @ v.float() - - # CUTLASS reference FMHA - from cutlass.utils import LayoutEnum - import cutlass.torch as ct - - q_tensor = q - k_tensor = k - v_tensor = v.unsqueeze(-1) # Add batch dim - c_tensor = c - - q_maj = LayoutEnum.ROW_MAJOR # K major - k_maj = LayoutEnum.ROW_MAJOR - v_maj = LayoutEnum.ROW_MAJOR - o_maj = LayoutEnum.ROW_MAJOR - - # Build the FMHA kernel - kernel = FMHA( - q_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(q_maj), - k_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(k_maj), - v_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(v_maj), - o_major_mode=FMHA_OperandMajorMode.from_LayoutEnum(o_maj), + ref = attn @ vf + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + kernel = BlackwellFusedMultiHeadAttentionForward( + q_major_mode=FMHA_OperandMajorMode.K, + k_major_mode=FMHA_OperandMajorMode.K, + v_major_mode=FMHA_OperandMajorMode.MN, + o_major_mode=FMHA_OperandMajorMode.K, q_head_dim=HEAD_DIM, kv_head_dim=HEAD_DIM, num_q_heads=1, @@ -62,19 +56,13 @@ def test_reference_fmha(): epilogue_dtype=cutlass.Float32, use_2cta_instrs=False, ) - - # Set dimensions - kernel.set_dims(m, n) - - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - + print(f'n={n}: Compiling reference FMHA...', flush=True) try: - compiled = cute.compile(kernel, q_tensor, k_tensor, v_tensor, c_tensor, stream) - compiled(q_tensor, k_tensor, v_tensor, c_tensor, stream) + result = kernel.run(q, k, v, c, stream) torch.cuda.synchronize() - - out = c[:, :, 0].float() + + out = c[0, 0].float() cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) ).item() @@ -84,7 +72,9 @@ def test_reference_fmha(): print(f' out[0,:4]={out[0,:4].tolist()}') print(f' ref[0,:4]={ref[0,:4].tolist()}') except Exception as e: - print(f'Reference FMHA n={n}: FAILED with error: {e}') + import traceback + print(f'Reference FMHA n={n}: FAILED') + traceback.print_exc() if __name__ == '__main__': - test_reference_fmha() + test_reference()