Files
nvfp4-megamoe-kernel/run_router_tests.py
biondizzle 66a89859ed Layer dispatch: config, schedule, attention/FFN sub-blocks, TransformerLayer
DSV4Config: frozen dataclass with .flash() / .pro() classmethods.
All architectural constants (dims, heads, MoE params, mHC) in one place.

LayerSchedule: pure-data per-layer-index -> (attn_type, ffn_type, router_mode).
  Flash: SWA, SWA, CSA, HCA, CSA, HCA, ... (43 layers)
  Pro:   HCA, HCA, CSA, HCA, CSA, HCA, ... (61 layers)
  Both:  first 3 MoE layers = hash routing, rest = dense
  validate_schedule() enforces correctness at construction.

AttentionSubBlock: CSA / HCA / SWA variants.
  - Low-rank Q projection (q_down -> q_up)
  - KV down-projection (varies by attn type: 4h/2h/1h)
  - CSA: indexer_q_up + indexer_head_weights
  - Grouped output projection (wo_a + wo_b)
  - Kernel calls are imported (they raise NotImplementedError until the kernel lands)
  - No PyTorch fallback paths

FFNSubBlock: MoE + shared expert.
  - Router (hash/dense) mode from LayerSpec
  - Nvfp4MoE + Nvfp4SharedExpert

TransformerLayer: composition of mHC + norm + attention + FFN.
  - Two mHC wrappers (attn + ffn sub-blocks)
  - Two RMSNorm (one per sub-block)
  - Pure orchestration, no learned params on the layer itself

Tests: schedule construction + validation for both variants.
No forward tests yet (depends on FMHA kernel + KV cache).
2026-05-21 23:11:09 +00:00

98 lines
4.2 KiB
Python

#!/usr/bin/env python3
"""Smoke tests for the router CUDA kernels.

Each section below JIT-compiles one ``.cu`` source with
``torch.utils.cpp_extension.load`` and compares the kernel output against a
reference computed with plain PyTorch ops on the GPU.
"""
import sys, os
# The kernel source paths passed to load() further down are relative, so the
# working directory must be the kernel workspace root before anything is built.
os.chdir('/root/dsv4-nvfp4-workspace/kernel')
# Puts the workspace root first on the import path; nothing visible in this
# script imports a local package, so this is presumably for the extension
# build or for interactive use — TODO confirm.
sys.path.insert(0, '.')
import torch
from torch.utils.cpp_extension import load
def pf(ok):
    """Return the label 'PASS' when *ok* is truthy, otherwise 'FAIL'."""
    if ok:
        return 'PASS'
    return 'FAIL'
# Report the runtime environment before building anything.
# Availability is checked first: get_device_name() and mem_get_info() raise on
# a machine without a usable CUDA device, which previously prevented the
# "CUDA: False" line from ever being printed.
cuda_available = torch.cuda.is_available()
if not cuda_available:
    print(f'CUDA: {cuda_available}', flush=True)
    raise SystemExit('No CUDA device available; the router kernel tests need a GPU.')
print(f'CUDA: {cuda_available}, Device: {torch.cuda.get_device_name(0)}', flush=True)
# mem_get_info() returns (free_bytes, total_bytes); report free memory in MB.
print(f'GPU mem free: {torch.cuda.mem_get_info()[0]//1024//1024} MB', flush=True)
# 1. Hash Router
# Builds hash_router.cu and, for several token counts N, checks that the
# kernel returns the LUT rows selected by the token ids together with uniform
# 1/k weights.
print('\n=== Hash Router ===', flush=True)
try:
    hr = load(
        name='hr2',
        sources=['dsv4/kernels/cuda/hash_router.cu'],
        extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'],
        verbose=False,
    )
    failed_cases = []
    for N in [1, 4, 64, 128, 512]:
        vocab, k, E = 128000, 6, 256
        lut = torch.randint(0, E, (vocab, k), dtype=torch.int32, device='cuda')
        tids = torch.randint(0, vocab, (N,), dtype=torch.int32, device='cuda')
        ow = torch.empty(N, k, dtype=torch.float32, device='cuda')
        oi = torch.empty(N, k, dtype=torch.int32, device='cuda')
        hr.hash_router(tids, lut, k, ow, oi)
        torch.cuda.synchronize()
        # Reference: a plain row gather from the LUT, with every weight 1/k.
        exp_ids = lut[tids]
        exp_w = torch.full((N, k), 1.0/k, dtype=torch.float32, device='cuda')
        ids_ok = (oi == exp_ids).all().item()
        w_ok = torch.allclose(ow, exp_w, atol=1e-7, rtol=1e-7)
        ok = ids_ok and w_ok
        print(f' N={N:4d}: IDs={ids_ok} W={w_ok} {pf(ok)}', flush=True)
        if not ok:
            failed_cases.append(f'N={N}')
        # Drop the references so the tensors can be freed before the next case.
        del lut, tids, ow, oi, exp_ids, exp_w
    # Only claim success when every case passed; the summary used to print
    # "ALL PASS" unconditionally, even after a FAIL line.
    if failed_cases:
        failed_summary = ', '.join(failed_cases)
        print(f'Hash Router: FAILED ({failed_summary})', flush=True)
    else:
        print('Hash Router: ALL PASS', flush=True)
except Exception:
    # Best-effort: a build or launch failure here must not stop later sections.
    import traceback
    traceback.print_exc()
    print('Hash Router: ERROR (see traceback above)', flush=True)
torch.cuda.empty_cache()
# 2. Top-k Select
# Builds topk_select.cu and compares its values and indices with
# torch.Tensor.topk along the last dimension for several (N, E) shapes.
print('\n=== Top-k Select ===', flush=True)
try:
    tk = load(
        name='tk2',
        sources=['dsv4/kernels/cuda/topk_select.cu'],
        extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'],
        verbose=False,
    )
    failed_cases = []
    for N, E in [(1,256), (4,256), (64,256), (64,384), (128,256), (512,256)]:
        k = 6
        scores = torch.randn(N, E, dtype=torch.float32, device='cuda')
        ov, oidx = tk.topk_select(scores, k)
        torch.cuda.synchronize()
        # Reference result from PyTorch's own top-k.
        exp = scores.topk(k, dim=-1)
        ids_ok = (oidx == exp.indices).all().item()
        vals_ok = torch.allclose(ov, exp.values, atol=1e-6, rtol=1e-6)
        ok = ids_ok and vals_ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        if not ok:
            failed_cases.append(f'N={N} E={E}')
        # Drop the references so the tensors can be freed before the next case.
        del scores, ov, oidx, exp
    # Only claim success when every case passed; the summary used to print
    # "ALL PASS" unconditionally, even after a FAIL line.
    if failed_cases:
        failed_summary = ', '.join(failed_cases)
        print(f'Top-k Select: FAILED ({failed_summary})', flush=True)
    else:
        print('Top-k Select: ALL PASS', flush=True)
except Exception:
    # Best-effort: a build or launch failure here must not stop later sections.
    import traceback
    traceback.print_exc()
    print('Top-k Select: ERROR (see traceback above)', flush=True)
torch.cuda.empty_cache()
# 3. Activation + Top-k
# Builds activation_topk.cu and compares the fused kernel against a PyTorch
# oracle: activate the logits, add the bias for selection only, take the top-k
# indices, then renormalise the un-biased activations of the selected experts
# and multiply by `scaling`.
print('\n=== Activation + Top-k ===', flush=True)
try:
    atk = load(
        name='atk2',
        sources=['dsv4/kernels/cuda/activation_topk.cu'],
        extra_cuda_cflags=['-O3', '--generate-code=arch=compute_100a,code=[sm_100a]'],
        verbose=False,
    )
    failed_cases = []
    for N, E in [(1,256), (4,256), (64,256), (64,384)]:
        k = 6
        logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
        bias = torch.randn(E, dtype=torch.float32, device='cuda')
        scaling = 2.5
        out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
        out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
        atk.activation_topk(logits, bias, k, scaling, out_w, out_ids)
        torch.cuda.synchronize()
        # Oracle
        # NOTE(review): the usual overflow-safe softplus is
        # log1p(exp(-|x|)) + max(x, 0). The expression below instead adds
        # max(x, 0) to log1p(exp(x)), which exceeds softplus(x) by x for
        # positive logits. Left unchanged because the kernel source is not
        # visible here — confirm which activation activation_topk.cu uses.
        sp = torch.log1p(logits.exp()) + torch.clamp(logits, min=0)
        act = sp.sqrt()
        # The bias influences which experts are selected, not their weights.
        score = act + bias
        exp_topk = score.topk(k, dim=-1)
        exp_ids = exp_topk.indices
        exp_w = torch.gather(act, 1, exp_ids)
        exp_w = exp_w / exp_w.sum(dim=-1, keepdim=True) * scaling
        ids_ok = (out_ids == exp_ids).all().item()
        vals_ok = torch.allclose(out_w, exp_w, atol=1e-5, rtol=1e-5)
        ok = ids_ok and vals_ok
        print(f' N={N:4d} E={E}: IDs={ids_ok} V={vals_ok} {pf(ok)}', flush=True)
        if not ids_ok:
            print(f' ID mismatches: {(out_ids != exp_ids).sum().item()}/{out_ids.numel()}', flush=True)
        if not vals_ok:
            print(f' Max diff: {(out_w - exp_w).abs().max().item():.2e}', flush=True)
        if not ok:
            failed_cases.append(f'N={N} E={E}')
        # Drop the references so the tensors can be freed before the next case.
        del logits, bias, out_w, out_ids, sp, act, score, exp_topk, exp_ids, exp_w
    # Only claim success when every case passed; the summary used to print
    # "ALL PASS" unconditionally, even after a FAIL line.
    if failed_cases:
        failed_summary = ', '.join(failed_cases)
        print(f'Activation+Topk: FAILED ({failed_summary})', flush=True)
    else:
        print('Activation+Topk: ALL PASS', flush=True)
except Exception:
    # Best-effort: report the failure and still reach the final DONE marker.
    import traceback
    traceback.print_exc()
    print('Activation+Topk: ERROR (see traceback above)', flush=True)
print('\n=== DONE ===', flush=True)