nvfp4-megamoe-kernel/tests/archive/test_dense_router.py

28 lines, 1.0 KiB, Python

Router: full kernel stack (hash, topk, activation+topk, dense decode/prefill)

Step 1: Hash router (hash_router.cu)
- One thread per token, gathering from a [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- The 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu (general top-k primitive)
- Per-thread register min-heap (k=6, unrolled at compile time)
- Shared-memory merge: thread 0 merges the 64 partial heaps
- Tie-breaking: the lower index wins on equal scores
- Reusable by the CSA indexer

Step 3: activation_topk.cu (fused sqrt(softplus) + bias + top-k + renormalization)
- Single kernel covering all 6 steps of the router math, with no intermediate buffers
- Numerically stable softplus: max(x, 0) + log1p(exp(-|x|))
- Per-thread heap that co-stores the unbiased activation with each score
- Shared-memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py (CuTeDSL fused GEMM kernel, skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined; still needs TMA/MMA boilerplate)
- Dispatch: N <= 64 uses the fused decode kernel, N > 64 uses the prefill path

Step 5: dense_router_prefill.py (prefill path)
- torch.nn.functional.linear for the GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN: Carmine is testing Stage C)

Test strategy: each kernel is tested against its mathematical spec in FP32. No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00
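A minimal pure-Python sketch of the Step 1 hash routing, assuming the LUT is indexed by token id and each row holds that token's k expert ids. The function name `hash_route` and the toy table are hypothetical; the real kernel runs one CUDA thread per token and writes FP32 tensors, which a plain loop only imitates.

```python
from typing import List, Sequence, Tuple


def hash_route(
    token_ids: Sequence[int],
    lut: Sequence[Sequence[int]],
) -> Tuple[List[List[int]], List[List[float]]]:
    """Hypothetical sketch: route each token via a [vocab_size, k] lookup table."""
    ids: List[List[int]] = []
    weights: List[List[float]] = []
    for tok in token_ids:          # the kernel uses one thread per token
        row = list(lut[tok])       # gather this token's k expert ids
        k = len(row)
        ids.append(row)
        weights.append([1.0 / k] * k)  # uniform 1/k weights, no learned scores
    return ids, weights


# Toy table: vocab_size = 3, k = 2.
lut = [[0, 5], [2, 7], [1, 4]]
ids, weights = hash_route([2, 0, 2], lut)
print(ids)      # [[1, 4], [0, 5], [1, 4]]
print(weights)  # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
```

Because routing depends only on the token id, there is no GEMM or score computation at all, which is why repeated decode calls are dominated by reads from the LUT.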
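The Step 2 selection scheme (per-thread min-heaps, a single merge, lower index winning ties) can be sketched in plain Python with `heapq`. This is an illustration, not the CUDA kernel: `topk_select`, `push_candidate`, and `num_threads` are hypothetical names, "threads" are simulated as strided slices of the score row, and keying the heap on `(score, -index)` is one possible way to encode the stated tie-breaking rule.

```python
import heapq
from typing import List, Sequence, Tuple

Key = Tuple[float, int]  # (score, -index): a larger key is a better candidate


def push_candidate(heap: List[Key], key: Key, k: int) -> None:
    """Keep the k best keys in a min-heap whose root is the weakest survivor."""
    if len(heap) < k:
        heapq.heappush(heap, key)
    elif key > heap[0]:  # strictly better than the weakest kept candidate
        heapq.heapreplace(heap, key)


def topk_select(scores: Sequence[float], k: int, num_threads: int = 64) -> List[int]:
    # Phase 1: each simulated thread scans a strided slice into its own heap.
    partials: List[List[Key]] = []
    for t in range(num_threads):
        heap: List[Key] = []
        for i in range(t, len(scores), num_threads):
            push_candidate(heap, (scores[i], -i), k)
        partials.append(heap)

    # Phase 2: "thread 0" merges all partial heaps into one.
    merged: List[Key] = []
    for heap in partials:
        for key in heap:
            push_candidate(merged, key, k)

    # Best first: descending score, ascending index on ties.
    return [-neg_i for _, neg_i in sorted(merged, reverse=True)]


scores = [0.1, 0.9, 0.5, 0.9, 0.3, 0.5, 0.7]
print(topk_select(scores, k=4, num_threads=2))  # [1, 3, 6, 2]
```

With the `(score, -index)` key, the heap root is the weakest kept candidate under the tie rule, so a newcomer with an equal score but a higher index never displaces it. In the example, indices 2 and 5 both score 0.5 and index 2 is kept.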
# tests/unit/test_dense_router.py
import torch

from dsv4.layers.router import Router


def test_dense_router_matches_spec(N=64, H=4096, E=256, k=6):
    torch.manual_seed(0)  # deterministic inputs, so any failure is reproducible
    X = torch.randn(N, H, dtype=torch.bfloat16, device='cuda')
    W = torch.randn(H, E, dtype=torch.bfloat16, device='cuda')
    bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01
    scaling = 2.5

    # Oracle: compute the spec directly, in FP32.
    # This is not "a PyTorch reference implementation"; it's the math.
    logits = X.float() @ W.float()                            # [N, E] gate logits
    act = torch.sqrt(torch.nn.functional.softplus(logits))    # unbiased activation
    score = act + bias                                        # bias affects selection only
    ids = score.topk(k, dim=-1).indices                       # [N, k], descending score
    w = act.gather(-1, ids)                                   # weights use the unbiased activation
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test:
    router = Router(H, E, k, scaling, mode='dense')
    with torch.no_grad():  # in-place copy into parameters that may require grad
        router.W_gate.copy_(W)
        router.e_bias.copy_(bias)
    out_w, out_ids = router(X, layer_idx=5)

    assert (out_ids == ids).all()  # ids must match exactly
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)
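The oracle above relies on `torch.nn.functional.softplus`. For a single row, the Step 3 math, including the numerically stable softplus from the commit message, can be sketched in plain Python. `stable_softplus` and `route_row` are hypothetical helpers, not part of `dsv4`; what they illustrate is that the bias only influences which experts are selected, while the returned weights are renormalized from the unbiased activation and then multiplied by `scaling`.

```python
import math
from typing import List, Sequence, Tuple


def stable_softplus(x: float) -> float:
    # log(1 + e^x) rewritten as max(x, 0) + log1p(exp(-|x|)):
    # exp() only sees a non-positive argument, so it cannot overflow.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def route_row(
    logits: Sequence[float],
    bias: Sequence[float],
    k: int,
    scaling: float,
) -> Tuple[List[int], List[float]]:
    """Hypothetical single-row sketch of activation + bias + top-k + renorm."""
    act = [math.sqrt(stable_softplus(z)) for z in logits]  # unbiased activation
    score = [a + b for a, b in zip(act, bias)]              # biased score, selection only
    # Descending score; the lower index wins on equal scores.
    ids = sorted(range(len(score)), key=lambda i: (-score[i], i))[:k]
    total = sum(act[i] for i in ids)
    weights = [act[i] / total * scaling for i in ids]       # renormalize, then scale
    return ids, weights


logits = [2.0, -1.0, 0.5, 3.0]
bias = [0.0, 5.0, 0.0, -10.0]
ids, weights = route_row(logits, bias, k=2, scaling=2.5)
print(ids)                      # [1, 0]
print(stable_softplus(1000.0))  # 1000.0; math.log1p(math.exp(1000.0)) raises OverflowError
```

In this example expert 3 has the largest activation, but its bias of -10 excludes it; expert 1 is selected only because of its +5 bias, yet its weight is computed from its small unbiased activation, so it receives less than expert 0.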