#!/usr/bin/env python3
"""Correctness smoke tests for the DSv4 MoE routing CUDA kernels.

JIT-builds three extensions with ``torch.utils.cpp_extension.load`` and
compares each against a PyTorch reference on random CUDA inputs:

1. ``hash_router``     -- expert ids looked up from a table, uniform weights.
2. ``topk_select``     -- per-row top-k over expert scores.
3. ``activation_topk`` -- sqrt(softplus) activation, bias-shifted top-k
   selection, weights renormalised over the chosen experts and scaled.

Run as a script. The exit status is 1 if any case fails or raises, else 0.
"""
import os
import sys
import traceback

import torch
from torch.utils.cpp_extension import load

# Directory containing ``dsv4/kernels/cuda``; override via DSV4_KERNEL_DIR.
KERNEL_DIR = os.environ.get('DSV4_KERNEL_DIR', '/root/dsv4-nvfp4-workspace/kernel')
# Every kernel is compiled for sm_100a only.
CUDA_FLAGS = ['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]']


def pf(ok):
    """Return 'PASS' for a truthy *ok*, otherwise 'FAIL'."""
    return 'PASS' if ok else 'FAIL'


def build(name, source):
    """JIT-compile the single CUDA file *source* into extension module *name*."""
    return load(name=name, sources=[source], extra_cuda_cflags=CUDA_FLAGS, verbose=False)


def sqrt_softplus(x):
    """Return sqrt(softplus(x)) computed without overflow.

    Uses softplus(x) = log1p(exp(-|x|)) + max(x, 0), so ``exp`` is only ever
    applied to non-positive values.
    """
    return (torch.log1p(torch.exp(-x.abs())) + torch.clamp(x, min=0)).sqrt()


def activation_topk_reference(logits, bias, k, scaling):
    """Reference implementation of ``activation_topk``.

    Experts are *selected* by ``sqrt_softplus(logits) + bias``. They are
    *weighted* by the unbiased activation, renormalised over the k selected
    experts and multiplied by *scaling*.

    Returns:
        (weights, ids): both shaped like ``logits[..., :k]``; ids are int64.
    """
    act = sqrt_softplus(logits)
    ids = (act + bias).topk(k, dim=-1).indices
    picked = torch.gather(act, -1, ids)
    return picked / picked.sum(dim=-1, keepdim=True) * scaling, ids


def check_hash_router():
    """Check that hash_router ids equal ``lut[token_ids]`` and weights are 1/k."""
    hr = build('hr2', 'dsv4/kernels/cuda/hash_router.cu')
    vocab, k, E = 128000, 6, 256
    results = []
    for N in [1, 4, 64, 128, 512]:
        lut = torch.randint(0, E, (vocab, k), dtype=torch.int32, device='cuda')
        tids = torch.randint(0, vocab, (N,), dtype=torch.int32, device='cuda')
        ow = torch.empty(N, k, dtype=torch.float32, device='cuda')
        oi = torch.empty(N, k, dtype=torch.int32, device='cuda')
        # The kernel writes its outputs into the preallocated ow / oi.
        hr.hash_router(tids, lut, k, ow, oi)
        torch.cuda.synchronize()
        exp_ids = lut[tids]
        exp_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device='cuda')
        ids_ok = (oi == exp_ids).all().item()
        w_ok = torch.allclose(ow, exp_w, atol=1e-7, rtol=1e-7)
        ok = ids_ok and w_ok
        print(f' N={N:4d}: IDs={ids_ok} W={w_ok} {pf(ok)}', flush=True)
        results.append(ok)
    return all(results)


def check_topk_select():
    """Check topk_select values and indices against ``Tensor.topk``."""
    tk = build('tk2', 'dsv4/kernels/cuda/topk_select.cu')
    k = 6
    results = []
    for N, E in [(1, 256), (4, 256), (64, 256), (64, 384), (128, 256), (512, 256)]:
        scores = torch.randn(N, E, dtype=torch.float32, device='cuda')
        ov, oidx = tk.topk_select(scores, k)
        torch.cuda.synchronize()
        exp = scores.topk(k, dim=-1)
        # Exact index match assumes no ties, which is safe for random floats.
        ids_ok = (oidx == exp.indices).all().item()
        vals_ok = torch.allclose(ov, exp.values, atol=1e-6, rtol=1e-6)
        ok = ids_ok and vals_ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        results.append(ok)
    return all(results)


def check_activation_topk():
    """Check activation_topk against ``activation_topk_reference``."""
    atk = build('atk2', 'dsv4/kernels/cuda/activation_topk.cu')
    k = 6
    scaling = 2.5
    results = []
    for N, E in [(1, 256), (4, 256), (64, 256), (64, 384)]:
        logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
        bias = torch.randn(E, dtype=torch.float32, device='cuda')
        out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
        out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
        atk.activation_topk(logits, bias, k, scaling, out_w, out_ids)
        torch.cuda.synchronize()
        exp_w, exp_ids = activation_topk_reference(logits, bias, k, scaling)
        ids_ok = (out_ids == exp_ids).all().item()
        vals_ok = torch.allclose(out_w, exp_w, atol=1e-5, rtol=1e-5)
        ok = ids_ok and vals_ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        if not ids_ok:
            print(f' ID mismatches: {(out_ids != exp_ids).sum().item()}/{out_ids.numel()}', flush=True)
        if not vals_ok:
            print(f' Max diff: {(out_w - exp_w).abs().max().item():.2e}', flush=True)
        results.append(ok)
    return all(results)


def run_section(title, summary, check):
    """Run one check under a heading. Any exception counts as a failure.

    Returns True only if *check* returned truthy without raising.
    """
    print(f'\n=== {title} ===', flush=True)
    try:
        ok = bool(check())
    except Exception:
        # Report the error and keep going, so the other kernels still get tested.
        traceback.print_exc()
        ok = False
    finally:
        torch.cuda.empty_cache()
    print(f'{summary}: {"ALL PASS" if ok else "FAILED"}', flush=True)
    return ok


def main():
    """Run all kernel checks and return a process exit status (0 = all passed)."""
    os.chdir(KERNEL_DIR)
    sys.path.insert(0, '.')
    if not torch.cuda.is_available():
        # get_device_name / mem_get_info would raise without a CUDA device.
        print('CUDA: False -- no device, nothing to test', flush=True)
        return 1
    print(f'CUDA: {torch.cuda.is_available()}, Device: {torch.cuda.get_device_name(0)}', flush=True)
    print(f'GPU mem free: {torch.cuda.mem_get_info()[0]//1024//1024} MB', flush=True)
    results = [
        run_section('Hash Router', 'Hash Router', check_hash_router),
        run_section('Top-k Select', 'Top-k Select', check_topk_select),
        run_section('Activation + Top-k', 'Activation+Topk', check_activation_topk),
    ]
    print('\n=== DONE ===', flush=True)
    return 0 if all(results) else 1


if __name__ == '__main__':
    sys.exit(main())