test: nvcc compilation test for FMHA SM100 kernel
This commit is contained in:
@@ -1,119 +1,85 @@
|
||||
"""
|
||||
Test: Raw CUDA FMHA decode kernel (fmha_sm100.cuh).
|
||||
Test: Compile FMHA SM100 kernel with nvcc directly.
|
||||
|
||||
Phase 1: Verify kernel compiles and produces correct output at hd=64.
|
||||
Uses torch.utils.cpp_extension to JIT-compile the CUDA kernel.
|
||||
Step 1: Try to compile the .cuh to check for C++ errors.
|
||||
Step 2: If that works, try torch.utils.cpp_extension JIT.
|
||||
"""
|
||||
import torch
|
||||
import math
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
|
||||
def compile_and_test():
|
||||
def get_repo_root():
|
||||
"""Find repo root from this file's location."""
|
||||
d = os.path.dirname(os.path.abspath(__file__))
|
||||
# Go up until we find dsv4/ directory
|
||||
while d != '/':
|
||||
if os.path.exists(os.path.join(d, 'dsv4')):
|
||||
return d
|
||||
d = os.path.dirname(d)
|
||||
return None
|
||||
|
||||
REPO = get_repo_root()
|
||||
print(f"Repo root: {REPO}")
|
||||
|
||||
CUTLASS = "/root/dsv4-nvfp4-workspace/cutlass"
|
||||
|
||||
# Step 1: Try nvcc compile (just syntax check)
|
||||
print("\n" + "=" * 60)
|
||||
print("Step 1: nvcc syntax check")
|
||||
print("=" * 60)
|
||||
|
||||
nvcc_cmd = [
|
||||
"nvcc",
|
||||
"--std=c++17",
|
||||
"-gencode=arch=compute_100a,code=sm_100a",
|
||||
f"-I{CUTLASS}/include",
|
||||
f"-I{REPO}",
|
||||
"-DCUTE_ARCH_TCGEN05_MMA_ENABLED",
|
||||
"-DCUTE_ARCH_TCGEN05_TMEM_ENABLED",
|
||||
"-DCUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED",
|
||||
"-c",
|
||||
f"{REPO}/dsv4/kernels/attention/fmha_sm100.cuh",
|
||||
"-o", "/tmp/fmha_sm100_test.o",
|
||||
"--ptxas-options=-v",
|
||||
]
|
||||
|
||||
print(f"nvcc command: {' '.join(nvcc_cmd)}")
|
||||
result = subprocess.run(nvcc_cmd, capture_output=True, text=True, timeout=120)
|
||||
print(f"Exit code: {result.returncode}")
|
||||
if result.stdout:
|
||||
print(f"STDOUT:\n{result.stdout[-1000:]}")
|
||||
if result.stderr:
|
||||
print(f"STDERR:\n{result.stderr[-2000:]}")
|
||||
|
||||
if result.returncode != 0:
|
||||
print("\n❌ nvcc compilation FAILED — fix errors above before proceeding")
|
||||
sys.exit(1)
|
||||
|
||||
print("\n✅ nvcc compilation PASSED!")
|
||||
|
||||
# Step 2: JIT compile with torch
|
||||
print("\n" + "=" * 60)
|
||||
print("Step 2: torch.utils.cpp_extension JIT")
|
||||
print("=" * 60)
|
||||
|
||||
try:
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
# Kernel source paths
|
||||
kernel_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
src = os.path.join(kernel_dir, "dsv4/kernels/attention/fmha_sm100.cpp")
|
||||
|
||||
print(f"Compiling: {src}")
|
||||
print(f"Kernel dir: {kernel_dir}")
|
||||
|
||||
# CUDA 13.0 + SM100 flags
|
||||
nvcc_flags = [
|
||||
"-gencode=arch=compute_100a,code=sm_100a",
|
||||
"-DCUTE_ARCH_TCGEN05_MMA_ENABLED",
|
||||
"-DCUTE_ARCH_TCGEN05_TMEM_ENABLED",
|
||||
"-DCUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED",
|
||||
"--ptxas-options=-v",
|
||||
f"-I{kernel_dir}",
|
||||
]
|
||||
|
||||
# CUTLASS include paths
|
||||
cutlass_dir = "/root/dsv4-nvfp4-workspace/cutlass"
|
||||
if os.path.exists(cutlass_dir):
|
||||
nvcc_flags.extend([
|
||||
f"-I{cutlass_dir}/include",
|
||||
f"-I{cutlass_dir}/examples/common",
|
||||
])
|
||||
|
||||
try:
|
||||
module = load(
|
||||
name="fmha_sm100",
|
||||
sources=[src],
|
||||
extra_cuda_cflags=nvcc_flags,
|
||||
verbose=True,
|
||||
)
|
||||
print("\n✅ Kernel compiled successfully!")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"\n❌ Compilation failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def test_correctness():
|
||||
"""Test FMHA output against PyTorch reference."""
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
kernel_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
src = os.path.join(kernel_dir, "dsv4/kernels/attention/fmha_sm100.cpp")
|
||||
cutlass_dir = "/root/dsv4-nvfp4-workspace/cutlass"
|
||||
|
||||
module = load(
|
||||
name="fmha_sm100",
|
||||
sources=[src],
|
||||
sources=[f"{REPO}/dsv4/kernels/attention/fmha_sm100.cpp"],
|
||||
extra_cuda_cflags=[
|
||||
"-gencode=arch=compute_100a,code=sm_100a",
|
||||
f"-I{kernel_dir}",
|
||||
f"-I{cutlass_dir}/include",
|
||||
f"-I{REPO}",
|
||||
f"-I{CUTLASS}/include",
|
||||
],
|
||||
verbose=True,
|
||||
)
|
||||
print("\n✅ JIT compilation PASSED!")
|
||||
print(f"Module: {module}")
|
||||
print(f"fmha_decode: {module.fmha_decode}")
|
||||
except Exception as e:
|
||||
print(f"\n❌ JIT compilation FAILED: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Test: hd=64, s_k=128, T=1, B=1, H=1
|
||||
B, H, T, D = 1, 1, 1, 64
|
||||
s_k = 128
|
||||
scale = 1.0 / math.sqrt(D)
|
||||
|
||||
torch.manual_seed(42)
|
||||
q = torch.randn(B, H, T, D, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(B, s_k, D, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(B, D, s_k, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# PyTorch reference
|
||||
q_f = q.float()
|
||||
k_f = k.float()
|
||||
v_f = v.float()
|
||||
# S = Q @ K^T * scale
|
||||
s = torch.matmul(q_f, k_f.transpose(-2, -1)) * scale
|
||||
# Softmax
|
||||
p = torch.softmax(s, dim=-1)
|
||||
# O = P @ V^T (V is stored as (D, s_k), so V^T = V for matmul)
|
||||
o_ref = torch.matmul(p, v_f.transpose(-2, -1))
|
||||
o_ref_bf16 = o_ref.bfloat16()
|
||||
|
||||
# Kernel
|
||||
o, lse = module.fmha_decode(q, k, v, scale, 0, 0, False, None)
|
||||
|
||||
# Compare
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
o.float().flatten().unsqueeze(0),
|
||||
o_ref_bf16.float().flatten().unsqueeze(0)
|
||||
).item()
|
||||
|
||||
print(f"hd={D}, s_k={s_k}: cos {cos:.6f} {'PASS' if cos > 0.999 else 'FAIL'}")
|
||||
print(f" out[:4] = {o[0,0,0,:4].float().tolist()}")
|
||||
print(f" ref[:4] = {o_ref_bf16[0,0,0,:4].float().tolist()}")
|
||||
|
||||
return cos > 0.999
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("FMHA SM100 Raw CUDA — Compilation Test")
|
||||
print("=" * 60)
|
||||
|
||||
if compile_and_test():
|
||||
print("\n" + "=" * 60)
|
||||
print("FMHA SM100 Raw CUDA — Correctness Test")
|
||||
print("=" * 60)
|
||||
test_correctness()
|
||||
print("\n✅ ALL COMPILATION TESTS PASSED!")
|
||||
|
||||
Reference in New Issue
Block a user