test: nvcc compilation test for FMHA SM100 kernel

This commit is contained in:
2026-05-28 05:05:31 +00:00
parent 230c350c77
commit fac7275f2b

View File

@@ -1,119 +1,85 @@
"""
Test: Raw CUDA FMHA decode kernel (fmha_sm100.cuh).
Test: Compile FMHA SM100 kernel with nvcc directly.
Phase 1: Verify kernel compiles and produces correct output at hd=64.
Uses torch.utils.cpp_extension to JIT-compile the CUDA kernel.
Step 1: Try to compile the .cuh to check for C++ errors.
Step 2: If that works, try torch.utils.cpp_extension JIT.
"""
import torch
import math
import os
import subprocess
import sys
import os
def compile_and_test():
def get_repo_root():
"""Find repo root from this file's location."""
d = os.path.dirname(os.path.abspath(__file__))
# Go up until we find dsv4/ directory
while d != '/':
if os.path.exists(os.path.join(d, 'dsv4')):
return d
d = os.path.dirname(d)
return None
REPO = get_repo_root()
print(f"Repo root: {REPO}")
CUTLASS = "/root/dsv4-nvfp4-workspace/cutlass"
# Step 1: Try nvcc compile (just syntax check)
print("\n" + "=" * 60)
print("Step 1: nvcc syntax check")
print("=" * 60)
nvcc_cmd = [
"nvcc",
"--std=c++17",
"-gencode=arch=compute_100a,code=sm_100a",
f"-I{CUTLASS}/include",
f"-I{REPO}",
"-DCUTE_ARCH_TCGEN05_MMA_ENABLED",
"-DCUTE_ARCH_TCGEN05_TMEM_ENABLED",
"-DCUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED",
"-c",
f"{REPO}/dsv4/kernels/attention/fmha_sm100.cuh",
"-o", "/tmp/fmha_sm100_test.o",
"--ptxas-options=-v",
]
print(f"nvcc command: {' '.join(nvcc_cmd)}")
result = subprocess.run(nvcc_cmd, capture_output=True, text=True, timeout=120)
print(f"Exit code: {result.returncode}")
if result.stdout:
print(f"STDOUT:\n{result.stdout[-1000:]}")
if result.stderr:
print(f"STDERR:\n{result.stderr[-2000:]}")
if result.returncode != 0:
print("\n❌ nvcc compilation FAILED — fix errors above before proceeding")
sys.exit(1)
print("\n✅ nvcc compilation PASSED!")
# Step 2: JIT compile with torch
print("\n" + "=" * 60)
print("Step 2: torch.utils.cpp_extension JIT")
print("=" * 60)
try:
from torch.utils.cpp_extension import load
# Kernel source paths
kernel_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
src = os.path.join(kernel_dir, "dsv4/kernels/attention/fmha_sm100.cpp")
print(f"Compiling: {src}")
print(f"Kernel dir: {kernel_dir}")
# CUDA 13.0 + SM100 flags
nvcc_flags = [
"-gencode=arch=compute_100a,code=sm_100a",
"-DCUTE_ARCH_TCGEN05_MMA_ENABLED",
"-DCUTE_ARCH_TCGEN05_TMEM_ENABLED",
"-DCUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED",
"--ptxas-options=-v",
f"-I{kernel_dir}",
]
# CUTLASS include paths
cutlass_dir = "/root/dsv4-nvfp4-workspace/cutlass"
if os.path.exists(cutlass_dir):
nvcc_flags.extend([
f"-I{cutlass_dir}/include",
f"-I{cutlass_dir}/examples/common",
])
try:
module = load(
name="fmha_sm100",
sources=[src],
extra_cuda_cflags=nvcc_flags,
verbose=True,
)
print("\n✅ Kernel compiled successfully!")
return True
except Exception as e:
print(f"\n❌ Compilation failed: {e}")
return False
def test_correctness():
"""Test FMHA output against PyTorch reference."""
from torch.utils.cpp_extension import load
kernel_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
src = os.path.join(kernel_dir, "dsv4/kernels/attention/fmha_sm100.cpp")
cutlass_dir = "/root/dsv4-nvfp4-workspace/cutlass"
module = load(
name="fmha_sm100",
sources=[src],
sources=[f"{REPO}/dsv4/kernels/attention/fmha_sm100.cpp"],
extra_cuda_cflags=[
"-gencode=arch=compute_100a,code=sm_100a",
f"-I{kernel_dir}",
f"-I{cutlass_dir}/include",
f"-I{REPO}",
f"-I{CUTLASS}/include",
],
verbose=True,
)
print("\n✅ JIT compilation PASSED!")
print(f"Module: {module}")
print(f"fmha_decode: {module.fmha_decode}")
except Exception as e:
print(f"\n❌ JIT compilation FAILED: {e}")
sys.exit(1)
# Test: hd=64, s_k=128, T=1, B=1, H=1
B, H, T, D = 1, 1, 1, 64
s_k = 128
scale = 1.0 / math.sqrt(D)
torch.manual_seed(42)
q = torch.randn(B, H, T, D, dtype=torch.bfloat16, device='cuda')
k = torch.randn(B, s_k, D, dtype=torch.bfloat16, device='cuda')
v = torch.randn(B, D, s_k, dtype=torch.bfloat16, device='cuda')
# PyTorch reference
q_f = q.float()
k_f = k.float()
v_f = v.float()
# S = Q @ K^T * scale
s = torch.matmul(q_f, k_f.transpose(-2, -1)) * scale
# Softmax
p = torch.softmax(s, dim=-1)
# O = P @ V^T (V is stored as (D, s_k), so V^T = V for matmul)
o_ref = torch.matmul(p, v_f.transpose(-2, -1))
o_ref_bf16 = o_ref.bfloat16()
# Kernel
o, lse = module.fmha_decode(q, k, v, scale, 0, 0, False, None)
# Compare
cos = torch.nn.functional.cosine_similarity(
o.float().flatten().unsqueeze(0),
o_ref_bf16.float().flatten().unsqueeze(0)
).item()
print(f"hd={D}, s_k={s_k}: cos {cos:.6f} {'PASS' if cos > 0.999 else 'FAIL'}")
print(f" out[:4] = {o[0,0,0,:4].float().tolist()}")
print(f" ref[:4] = {o_ref_bf16[0,0,0,:4].float().tolist()}")
return cos > 0.999
if __name__ == "__main__":
print("=" * 60)
print("FMHA SM100 Raw CUDA — Compilation Test")
print("=" * 60)
if compile_and_test():
print("\n" + "=" * 60)
print("FMHA SM100 Raw CUDA — Correctness Test")
print("=" * 60)
test_correctness()
print("\n✅ ALL COMPILATION TESTS PASSED!")