Files
nvfp4-megamoe-kernel/tests/unit/test_b2_minimal.py

74 lines
2.5 KiB
Python

#!/usr/bin/env python3
"""B2 minimal debug: test if the FP8 indexer kernel even completes.
Start with n_comp=1 (trivial case) and work up.
Also test: does the kernel return at all? Or does it hang?
"""
import sys
import math
import torch
import torch.nn.functional as F
import time
def quantize_fp8_e4m3(x_fp32):
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
scale = amax / 448.0
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
return fp8.view(torch.uint8), scale.squeeze(-1)
def _wait_for_event(event, timeout_s, poll_s=0.001):
    """Poll a CUDA ``event`` until it completes or ``timeout_s`` seconds pass.

    Returns True once ``event.query()`` reports completion. Returns False if
    the deadline passes first. Unlike ``torch.cuda.synchronize()``, this
    cannot block forever on a hung kernel. Any exception raised by
    ``event.query()`` propagates to the caller.
    """
    deadline = time.monotonic() + timeout_s
    while not event.query():
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
    return True


def _gpu_responsive():
    """Return True if a trivial CUDA op can still be launched and read back."""
    try:
        # .item() copies the result to the host, so this waits for the op
        # to actually finish on the device.
        return (torch.ones(1, device="cuda") + 1).item() == 2.0
    except RuntimeError as err:
        print(f" GPU probe failed: {err}", flush=True)
        return False


def _run_case(mod, n_comp, n_ih, ihd, top_k, timeout_s):
    """Build inputs for ``n_comp`` keys, launch the kernel once, and report.

    Prints an OK or FAILED line and returns True only on OK.

    If the kernel has not completed after ``timeout_s`` seconds, it prints
    HANG and terminates the process at once with exit status 2. A hung
    kernel cannot be cancelled, and normal interpreter shutdown may block
    on it.
    """
    import os  # only needed for the hard exit on a hang

    # Re-seed per case so each size's inputs are reproducible on their own.
    torch.manual_seed(42)
    q_idx = torch.randn(n_ih, ihd, dtype=torch.bfloat16).cuda() * 0.5
    k_fp32 = torch.randn(n_comp, ihd, dtype=torch.float32) * 0.5
    w_h = torch.randn(n_ih, dtype=torch.bfloat16).cuda() * 0.3
    k_fp8, k_scale = quantize_fp8_e4m3(k_fp32)
    k_fp8 = k_fp8.cuda()
    k_scale = k_scale.cuda()
    tk = min(top_k, n_comp)
    # Pre-fill with -1 rather than torch.empty. Slots the kernel never
    # writes then show up as invalid instead of as leftover memory.
    topk_indices = torch.full((tk,), -1, dtype=torch.int32, device="cuda")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    try:
        start.record()
        mod.indexer_fp8_score_topk(
            q_idx, k_fp8, k_scale, w_h, topk_indices,
            n_ih, ihd, tk)
        end.record()
        # The launch is asynchronous. Wait with a deadline so a hang is
        # reported instead of blocking forever in synchronize().
        if not _wait_for_event(end, timeout_s):
            print(f" HANG: kernel not finished after {timeout_s:.0f}s",
                  flush=True)
            os._exit(2)
        ms = start.elapsed_time(end)
    except Exception as e:  # debug-script boundary: report any failure
        print(f" FAILED: {e}", flush=True)
        # A kernel fault typically leaves a sticky CUDA error, so this
        # synchronize may raise again. Don't let it hide the report above.
        try:
            torch.cuda.synchronize()
        except RuntimeError as sync_err:
            print(f" CUDA error after failure: {sync_err}", flush=True)
        return False

    indices = topk_indices.cpu().tolist()
    # NOTE(review): assumes the kernel writes row indices into k, i.e.
    # values in [0, n_comp). Confirm against indexer_fp8_score_topk.cu.
    bad = [i for i in indices if not 0 <= i < n_comp]
    if bad:
        print(f" FAILED: out-of-range indices {bad} (n_comp={n_comp})",
              flush=True)
        return False
    print(f" OK: {ms:.2f}ms indices={indices}", flush=True)
    return True


def main():
    """Run the FP8 indexer kernel once for each n_comp in a power-of-two sweep.

    The sweep stops at the first size that fails. For n_comp >= 128, a
    trivial CUDA op is also run afterwards to confirm the GPU still
    completes work.

    Exit status:
        0 if every attempted size passed.
        1 on a failure.
        2 on a hang (the exit happens inside the sweep).
    """
    N_IH = 64
    IHD = 128
    TOP_K = 1
    # Seconds to wait for a single launch before declaring it hung.
    HANG_TIMEOUT_S = 30.0

    from dsv4.kernels.cuda.loader import get_cuda_module
    mod = get_cuda_module("indexer_fp8_score_topk", ["indexer_fp8_score_topk.cu"],
                          extra_cuda_cflags=[
                              "-gencode=arch=compute_100a,code=sm_100a",
                              "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
                          ])

    all_ok = True
    # Test with increasing n_comp.
    for n_comp in (1, 2, 4, 8, 16, 32, 64, 128, 256, 512):
        print(f"\n--- n_comp={n_comp} ---", flush=True)
        if not _run_case(mod, n_comp, N_IH, IHD, TOP_K, HANG_TIMEOUT_S):
            all_ok = False
            break
        if n_comp >= 128:
            # Larger launches: confirm the device still accepts and
            # finishes new work, rather than assuming it does.
            responsive = _gpu_responsive()
            print(f" GPU responsive: {'yes' if responsive else 'no'}",
                  flush=True)
            if not responsive:
                all_ok = False
                break
    print("\nDone.")
    sys.exit(0 if all_ok else 1)
# Script entry point; importing this module does not run the sweep.
if __name__ == "__main__":
    main()