Files
nvfp4-megamoe-kernel/tests/unit/test_router.py
biondizzle abfe4485f7 Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00

218 lines
7.7 KiB
Python

"""Unit tests for DSV4 Router — dense and hash modes.
Test strategy:
Each kernel has a closed-form mathematical spec. The unit test computes
the spec in one expression in FP32 (PyTorch) and compares against the
kernel output. This is not "a PyTorch reference implementation" — it's
the math. Compare against that. No "ref/" file, no second implementation
drift, no two debug streams.
The oracle is the same five lines of math as the kernel spec, written
declaratively. Compare against that.
DO NOT RUN THESE TESTS — Carmine is actively testing Stage C.
Write the tests, commit them, they'll be run later.
Tie-breaking: When two scores are exactly equal, the kernel must pick the
lower index. torch.topk does not document which index it returns among
equal values, so an oracle that depends on this rule has to enforce it
explicitly (e.g. with a stable descending sort). If the test fails on
tie-breaking, fix the kernel, not the test.
"""
import torch
import math
def test_fused_activation_topk(N=64, E=256, k=6, seed=42):
    """Test the fused activation + top-k kernel against the math spec.

    The kernel is fed FP32 logits directly, so the GEMM is not part of this
    test. Oracle:
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k), descending, lower index first on ties
        raw_w  = gather(act, ids)            # unbiased activations
        topk_w = raw_w / sum(raw_w) * scaling

    Args:
        N: number of tokens (rows of the logits).
        E: number of experts (columns of the logits).
        k: experts selected per token.
        seed: RNG seed for the random logits and bias.
    """
    torch.manual_seed(seed)
    scaling = 2.5
    logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
    e_bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01

    # Oracle — the math, one expression at a time.
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + e_bias
    # torch.topk does not document which index it returns among equal values,
    # so enforce the spec's lower-index tie-break with a stable descending
    # sort: equal scores keep their original (ascending-index) order.
    ids = torch.sort(score, dim=-1, descending=True, stable=True).indices[..., :k]
    raw_w = act.gather(-1, ids)
    w = raw_w / raw_w.sum(-1, keepdim=True) * scaling

    # Kernel under test:
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk
    out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
    out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
    run_fused_activation_topk(logits, e_bias, scaling, k, out_w, out_ids)

    # Verify. out_ids is int32 and ids is int64; the comparison promotes.
    bad_rows = (out_ids != ids).any(dim=-1)
    assert not bad_rows.any(), (
        f"top-k indices mismatch in {int(bad_rows.sum())}/{N} rows"
    )
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)
def test_fused_activation_topk_decode_shapes():
    """Run the fused activation + top-k test at each decode batch size.

    Each batch size doubles as the RNG seed for its run.
    """
    decode_batch_sizes = (1, 4, 16, 64)
    for n_tokens in decode_batch_sizes:
        test_fused_activation_topk(N=n_tokens, E=256, k=6, seed=n_tokens)
def test_fused_activation_topk_pro_experts():
    """Run the fused activation + top-k test with the Pro model's 384 experts."""
    pro_num_experts = 384
    test_fused_activation_topk(N=64, E=pro_num_experts, k=6, seed=123)
def test_hash_router(N=128, vocab_size=128000, k=6, num_experts=256, seed=42):
    """Check the hash router kernel against its closed-form spec.

    Spec:
        topk_ids[n, h] = hash_lut[token_ids[n], h]
        topk_w[n, h]   = 1.0 / k

    The LUT and the token ids are random, drawn after seeding with ``seed``.
    """
    torch.manual_seed(seed)
    cuda = 'cuda'

    # Random [vocab_size, k] LUT of expert ids, then N random token ids.
    lut = torch.randint(0, num_experts, (vocab_size, k), dtype=torch.int32, device=cuda)
    tokens = torch.randint(0, vocab_size, (N,), dtype=torch.int32, device=cuda)

    # Spec: a row gather from the LUT, plus uniform weights.
    want_ids = lut[tokens]  # [N, k]
    want_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device=cuda)

    # Kernel under test:
    from dsv4.kernels.router import hash_router_dispatch

    got_w, got_ids = (
        torch.empty(N, k, dtype=torch.float32, device=cuda),
        torch.empty(N, k, dtype=torch.int32, device=cuda),
    )
    hash_router_dispatch(tokens, lut, k, got_w, got_ids)

    ids_match = (got_ids == want_ids).all()
    assert ids_match, "hash router IDs mismatch"
    torch.testing.assert_close(got_w, want_w, atol=1e-7, rtol=1e-7)
def test_hash_router_edge_cases():
    """Run the hash router test for a single token and for a large batch (8192)."""
    for num_tokens in (1, 8192):
        test_hash_router(N=num_tokens, vocab_size=128000, k=6)
def test_topk_select(N=64, E=256, k=6, seed=42):
    """Test standalone top-k selection against the top-k spec.

    Oracle:
        order  = stable descending sort of ``scores`` along the last dim
        ids    = order[:, :k]      (lower index first on exactly equal scores)
        values = the corresponding sorted scores

    Args:
        N: number of rows.
        E: number of candidates per row.
        k: entries selected per row.
        seed: RNG seed for the random scores.
    """
    torch.manual_seed(seed)
    scores = torch.randn(N, E, dtype=torch.float32, device='cuda')

    # Oracle. torch.topk makes no documented promise about which index it
    # returns among equal values; a stable descending sort keeps equal scores
    # in ascending-index order, which is exactly the spec's tie-break.
    expected = torch.sort(scores, dim=-1, descending=True, stable=True)
    expected_ids = expected.indices[..., :k]
    expected_values = expected.values[..., :k]

    # Kernel under test:
    from dsv4.ops.topk import topk_select
    out_values, out_ids = topk_select(scores, k)

    # Check the shape first so broadcasting cannot hide a wrong-shaped result.
    assert out_ids.shape == expected_ids.shape, (
        f"top-k IDs shape {tuple(out_ids.shape)} != {tuple(expected_ids.shape)}"
    )
    bad_rows = (out_ids != expected_ids).any(dim=-1)
    assert not bad_rows.any(), (
        f"top-k IDs mismatch in {int(bad_rows.sum())}/{N} rows"
    )
    torch.testing.assert_close(out_values, expected_values, atol=1e-6, rtol=1e-6)
def test_dense_router_decode(N=64, H=4096, E=256, k=6, seed=42):
    """Test the full dense router (GEMM + activation + topk) against the spec.

    Oracle (FP32):
        logits = X.float() @ W.float()
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k), descending, lower index first on ties
        w      = act.gather(-1, ids)
        w      = w / w.sum(-1, keepdim=True) * scaling

    Args:
        N: number of tokens; also passed to the router as max_num_tokens.
        H: hidden size.
        E: number of experts.
        k: experts selected per token.
        seed: RNG seed for X, W and the bias.
    """
    torch.manual_seed(seed)
    scaling = 2.5
    X = torch.randn(N, H, dtype=torch.bfloat16, device='cuda')
    # Scale W by 1/sqrt(H) so the logits are O(1), the same regime as the
    # fused-activation test. With unit-variance W the logits have std ~sqrt(H)
    # (64 at H=4096), so the gaps between the top activations dwarf the
    # 0.01-scale bias. A kernel that mishandled the bias (dropped it from the
    # selection, or renormalized biased scores) would then rarely fail either
    # assertion. The oracle reads back the stored bf16 W, so the scaling need
    # not be exact.
    W = torch.randn(H, E, dtype=torch.bfloat16, device='cuda') * H ** -0.5
    bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01

    # Oracle — the math, in one expression, in FP32.
    logits = (X.float() @ W.float())
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + bias
    # torch.topk documents no tie order; a stable descending sort enforces the
    # spec's lower-index-first tie-break.
    ids = torch.sort(score, dim=-1, descending=True, stable=True).indices[..., :k]
    w = act.gather(-1, ids)
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test:
    # NOTE(review): W is laid out [H, E] here (the oracle computes X @ W);
    # confirm Router.load_weights expects W_gate in this layout rather than
    # the nn.Linear-style [E, H] layout.
    from dsv4.layers.router import Router
    router = Router(H, E, k, scaling, mode='dense', max_num_tokens=N)
    router.load_weights(W_gate=W, e_bias=bias)
    router.finalize_weights()
    out_w, out_ids = router(X)

    assert out_ids.shape == ids.shape, (
        f"router IDs shape {tuple(out_ids.shape)} != {tuple(ids.shape)}"
    )
    bad_rows = (out_ids != ids).any(dim=-1)
    assert not bad_rows.any(), (
        f"router IDs mismatch in {int(bad_rows.sum())}/{N} rows"
    )
    torch.testing.assert_close(out_w, w, atol=1e-3, rtol=1e-3)
def test_dense_router_decode_shapes():
    """Run the dense router test at each decode batch size.

    Each batch size doubles as the RNG seed for its run.
    """
    decode_batch_sizes = (1, 4, 16, 64)
    for n_tokens in decode_batch_sizes:
        test_dense_router_decode(N=n_tokens, H=4096, E=256, k=6, seed=n_tokens)
def test_hash_router_via_router_class(N=64, vocab_size=128000, k=6,
                                      num_experts=256, seed=42):
    """Test the Router class in hash mode against the hash-router spec.

    Oracle:
        topk_ids[n, h] = hash_lut[token_ids[n], h]
        topk_w[n, h]   = 1.0 / k

    Args:
        N: number of tokens; also passed to the router as max_num_tokens.
        vocab_size: number of LUT rows.
        k: experts per token.
        num_experts: upper bound (exclusive) for random expert ids.
        seed: RNG seed for the LUT and the token ids.
    """
    # Seeded so a failure reproduces with the same LUT and token ids.
    torch.manual_seed(seed)
    hash_lut = torch.randint(0, num_experts, (vocab_size, k), dtype=torch.int32, device='cuda')
    token_ids = torch.randint(0, vocab_size, (N,), dtype=torch.int32, device='cuda')

    # Oracle
    expected_ids = hash_lut[token_ids]
    expected_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device='cuda')

    # Router class
    from dsv4.layers.router import Router
    router = Router(
        hidden_size=4096,  # not used in hash mode
        num_experts=num_experts,
        top_k=k,
        mode='hash',
        vocab_size=vocab_size,
        max_num_tokens=N,
    )
    router.load_weights(hash_lut=hash_lut)
    router.finalize_weights()
    out_w, out_ids = router(hidden_states=None, token_ids=token_ids)

    bad_rows = (out_ids != expected_ids).any(dim=-1)
    assert not bad_rows.any(), (
        f"hash router class IDs mismatch in {int(bad_rows.sum())}/{N} rows"
    )
    torch.testing.assert_close(out_w, expected_w, atol=1e-7, rtol=1e-7)
def test_softplus_numerical_stability():
    """Verify the numerically stable softplus formula used by the kernel spec.

    The formula is softplus(x) = max(x, 0) + log1p(exp(-|x|)), evaluated here
    in FP32 on CPU. This test does not run the kernel. Expected values of
    sqrt(softplus(x)):
        x = -100: softplus(x) ≈ 0,               sqrt ≈ 0
        x =    0: softplus(x) = log(2) ≈ 0.6931, sqrt ≈ 0.8326
        x =  100: softplus(x) ≈ 100,             sqrt ≈ 10
    The formula is also compared against torch.nn.functional.softplus over
    [-100, 100]. The naive log(1 + exp(x)) overflows to inf in FP32 at x = 100
    (exp(100) ≈ 2.7e43 exceeds the FP32 max ≈ 3.4e38), which is why the
    stable form is needed.
    """
    def stable_softplus(t):
        # exp only ever sees a non-positive argument, so it cannot overflow.
        return torch.clamp(t, min=0.0) + torch.log1p(torch.exp(-t.abs()))

    x = torch.tensor([-100.0, 0.0, 100.0], dtype=torch.float32)
    expected = torch.tensor([0.0, math.sqrt(math.log(2.0)), 10.0], dtype=torch.float32)
    torch.testing.assert_close(torch.sqrt(stable_softplus(x)), expected, atol=1e-6, rtol=1e-6)

    # The explicit formula must agree with PyTorch's softplus everywhere in range.
    sweep = torch.linspace(-100.0, 100.0, steps=2001, dtype=torch.float32)
    torch.testing.assert_close(
        torch.sqrt(stable_softplus(sweep)),
        torch.sqrt(torch.nn.functional.softplus(sweep)),
    )

    # Rationale check: the naive form really does overflow at x = 100.
    naive = torch.log(1.0 + torch.exp(x))
    assert torch.isinf(naive[-1]), "naive softplus unexpectedly finite at x=100"