"""
Test: Compile FMHA SM100 kernel with nvcc directly.

Step 1: Compile the .cuh with nvcc to check for C++ errors.
Step 2: If that works, try torch.utils.cpp_extension JIT.

Run as a script (``python tests/unit/test_fmha_sm100.py``). The process exits
with status 1 as soon as a step fails and 0 when both steps pass. Importing
the module has no side effects beyond locating the repo root.
"""

import os
import subprocess
import sys
import tempfile


def get_repo_root(start=None, marker="dsv4"):
    """Find the repo root by walking up until a ``marker`` entry is found.

    Args:
        start: Directory to start from. Defaults to the directory that
            contains this file.
        marker: Name of the entry whose presence identifies the repo root.

    Returns:
        The absolute path of the first directory (starting at ``start`` and
        moving upwards) that contains ``marker``, or ``None`` if none does.
    """
    if start is None:
        start = os.path.dirname(os.path.abspath(__file__))
    d = os.path.abspath(start)
    while True:
        if os.path.exists(os.path.join(d, marker)):
            return d
        parent = os.path.dirname(d)
        # dirname() of a filesystem root returns the root itself; using that
        # fixed point as the stop condition works for any root, not just '/'.
        if parent == d:
            return None
        d = parent


REPO = get_repo_root()

# Location of the CUTLASS checkout; override with the CUTLASS_DIR env var.
CUTLASS = os.environ.get("CUTLASS_DIR", "/root/dsv4-nvfp4-workspace/cutlass")


def build_nvcc_cmd(repo, cutlass, out_path):
    """Return the nvcc argument list used for the step-1 compile check.

    Args:
        repo: Repo root; added as an include dir and used to locate the kernel.
        cutlass: CUTLASS checkout; its ``include`` dir is added to the search path.
        out_path: Where nvcc writes the object file.
    """
    return [
        "nvcc",
        "--std=c++17",
        "-gencode=arch=compute_100a,code=sm_100a",
        f"-I{cutlass}/include",
        f"-I{repo}",
        "-DCUTE_ARCH_TCGEN05_MMA_ENABLED",
        "-DCUTE_ARCH_TCGEN05_TMEM_ENABLED",
        "-DCUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED",
        # nvcc picks the input language from the file suffix and does not
        # recognise ".cuh"; state explicitly that the input is CUDA source.
        "-x", "cu",
        "-c",
        f"{repo}/dsv4/kernels/attention/fmha_sm100.cuh",
        "-o", out_path,
        "--ptxas-options=-v",
    ]


def _banner(title):
    """Print a section header framed by two 60-character rules."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _run_nvcc_check(repo, cutlass):
    """Run step 1 (nvcc compile of the .cuh); return True on success."""
    _banner("Step 1: nvcc syntax check")

    out_path = os.path.join(tempfile.gettempdir(), "fmha_sm100_test.o")
    nvcc_cmd = build_nvcc_cmd(repo, cutlass, out_path)
    print(f"nvcc command: {' '.join(nvcc_cmd)}")

    try:
        result = subprocess.run(
            nvcc_cmd, capture_output=True, text=True, timeout=120
        )
    except FileNotFoundError:
        print("\n❌ nvcc not found on PATH — cannot run the compile check")
        return False
    except subprocess.TimeoutExpired:
        print("\n❌ nvcc timed out after 120 seconds")
        return False

    print(f"Exit code: {result.returncode}")
    # Only the tail of each stream is shown to keep the report readable.
    if result.stdout:
        print(f"STDOUT:\n{result.stdout[-1000:]}")
    if result.stderr:
        print(f"STDERR:\n{result.stderr[-2000:]}")

    if result.returncode != 0:
        print("\n❌ nvcc compilation FAILED — fix errors above before proceeding")
        return False
    return True


def _run_jit_check(repo, cutlass):
    """Run step 2 (torch JIT build of the .cpp extension); return True on success."""
    _banner("Step 2: torch.utils.cpp_extension JIT")

    # Deliberately broad: this is the script's top-level boundary and any
    # failure (missing torch, build error, missing symbol) fails the step.
    try:
        from torch.utils.cpp_extension import load

        module = load(
            name="fmha_sm100",
            sources=[f"{repo}/dsv4/kernels/attention/fmha_sm100.cpp"],
            extra_cuda_cflags=[
                "-gencode=arch=compute_100a,code=sm_100a",
                f"-I{repo}",
                f"-I{cutlass}/include",
            ],
            verbose=True,
        )
        print("\n✅ JIT compilation PASSED!")
        print(f"Module: {module}")
        print(f"fmha_decode: {module.fmha_decode}")
    except Exception as e:
        print(f"\n❌ JIT compilation FAILED: {e}")
        return False
    return True


def main():
    """Run both compilation steps in order; return the process exit code."""
    print(f"Repo root: {REPO}")
    if REPO is None:
        print("\n❌ Could not locate repo root (no 'dsv4' directory above this file)")
        return 1

    if not _run_nvcc_check(REPO, CUTLASS):
        return 1
    print("\n✅ nvcc compilation PASSED!")

    if not _run_jit_check(REPO, CUTLASS):
        return 1

    print("\n✅ ALL COMPILATION TESTS PASSED!")
    return 0


if __name__ == "__main__":
    sys.exit(main())