#!/usr/bin/env python3
"""B2 minimal debug: test if the FP8 indexer kernel even completes.
|
|
|
|
Start with n_comp=1 (trivial case) and work up.
|
|
Also test: does the kernel return at all? Or does it hang?
|
|
"""
|
|
import sys
|
|
import math
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import time
|
|
|
|
|
|
def quantize_fp8_e4m3(x_fp32):
|
|
amax = x_fp32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
|
|
scale = amax / 448.0
|
|
fp8 = (x_fp32 / scale).clamp(-448, 448).to(torch.float8_e4m3fn)
|
|
return fp8.view(torch.uint8), scale.squeeze(-1)
|
|
|
|
|
|
def main():
    """Run the FP8 indexer score/top-k kernel for increasing ``n_comp``.

    Builds the ``indexer_fp8_score_topk`` CUDA module, then for each size
    in a doubling sequence (1 .. 512) generates seeded random inputs,
    launches the kernel between two CUDA timing events, synchronizes, and
    prints the elapsed time and the returned indices. The sweep stops at
    the first size whose launch or synchronization raises.

    The process always terminates through ``sys.exit``: status 0 when every
    size completed, status 1 when a size failed.
    """
    N_IH = 64
    IHD = 128
    TOP_K = 1

    # Imported here rather than at module level, as in the original script.
    from dsv4.kernels.cuda.loader import get_cuda_module

    mod = get_cuda_module(
        "indexer_fp8_score_topk",
        ["indexer_fp8_score_topk.cu"],
        extra_cuda_cflags=[
            "-gencode=arch=compute_100a,code=sm_100a",
            "-O3", "--use_fast_math", "--expt-relaxed-constexpr",
        ],
    )

    failed = False

    # Test with increasing n_comp
    for n_comp in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]:
        print(f"\n--- n_comp={n_comp} ---", flush=True)

        # Re-seed every iteration so each size sees reproducible inputs.
        # The generation order below is kept fixed for the same reason.
        torch.manual_seed(42)
        q_idx = torch.randn(N_IH, IHD, dtype=torch.bfloat16).cuda() * 0.5
        k_fp32 = torch.randn(n_comp, IHD, dtype=torch.float32) * 0.5
        w_h = torch.randn(N_IH, dtype=torch.bfloat16).cuda() * 0.3
        k_fp8, k_scale = quantize_fp8_e4m3(k_fp32)
        k_fp8 = k_fp8.cuda()
        k_scale = k_scale.cuda()

        # Never request more indices than there are candidate rows.
        tk = min(TOP_K, n_comp)
        topk_indices = torch.empty(tk, dtype=torch.int32, device='cuda')

        # CUDA events bracket the launch so the kernel time can be reported.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        try:
            mod.indexer_fp8_score_topk(
                q_idx, k_fp8, k_scale, w_h, topk_indices,
                N_IH, IHD, tk)
            end.record()
            # Block until the kernel finishes; asynchronous launch errors
            # surface here, and a hung kernel would stall here.
            torch.cuda.synchronize()
            ms = start.elapsed_time(end)
            print(f" OK: {ms:.2f}ms indices={topk_indices.cpu().tolist()}", flush=True)
        except Exception as e:
            # Report first: synchronizing after a CUDA error may raise
            # again, which would otherwise hide the original message.
            print(f" FAILED: {e}", flush=True)
            failed = True
            try:
                torch.cuda.synchronize()
            except Exception as sync_err:
                print(f" (synchronize after failure also raised: {sync_err})", flush=True)
            break

        if n_comp >= 128:
            # Reaching this point means the synchronize above returned.
            print(" GPU responsive: yes", flush=True)

    print("\nDone.")
    # Exit explicitly, and make a reported failure visible in the status.
    sys.exit(1 if failed else 0)
|
|
|
|
|
|
# Run the debug sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|