"""Unit tests for DSV4 Router — dense and hash modes.

Test strategy:
    Each kernel has a closed-form mathematical spec. The unit test computes
    the spec in one expression in FP32 (PyTorch) and compares against the
    kernel output. This is not "a PyTorch reference implementation" — it's
    the math. Compare against that.

    No "ref/" file, no second implementation drift, no two debug streams.
    The oracle is the same five lines of math as the kernel spec, written
    declaratively. Compare against that.

DO NOT RUN THESE TESTS — Carmine is actively testing Stage C.
Write the tests, commit them, they'll be run later.

Tie-breaking:
    When two scores are exactly equal, torch.topk and the kernel may pick
    different indices. Use the same tie-break rule: lower index wins on
    ties. The oracle enforces that rule explicitly with a stable sort (see
    ``oracle_topk_ids``) rather than relying on torch.topk's unspecified
    tie order. If the test fails on tie-breaking, fix the kernel, not the
    test.
"""

import math
import unittest

import torch

# Every kernel under test is a GPU kernel; all test tensors live here.
_DEVICE = "cuda"


def _require_cuda() -> None:
    """Skip the calling test when no CUDA device is available.

    ``unittest.SkipTest`` is used so the module keeps depending only on the
    standard library and torch; test runners report it as a skip instead of
    an error.
    """
    if not torch.cuda.is_available():
        raise unittest.SkipTest("CUDA device required for DSV4 router kernels")


def _assert_same_ids(actual: torch.Tensor, expected: torch.Tensor, message: str) -> None:
    """Assert two index tensors have the same shape and the same entries.

    The shape is checked first because ``==`` broadcasts: a wrongly shaped
    kernel output could otherwise compare equal elementwise by accident.
    """
    assert actual.shape == expected.shape, (
        f"{message}: shape {tuple(actual.shape)} != {tuple(expected.shape)}"
    )
    assert (actual == expected).all(), message


def oracle_topk_ids(score: torch.Tensor, k: int) -> torch.Tensor:
    """Return the indices of the ``k`` largest entries along the last dim.

    Results are ordered by descending score; among exactly equal scores the
    lower index comes first. A stable ascending sort of the negated scores
    gives that ordering by construction.
    """
    order = torch.sort(-score, dim=-1, stable=True).indices
    return order[..., :k]


def oracle_dense_routing(
    logits: torch.Tensor, bias: torch.Tensor, k: int, scaling: float
) -> tuple:
    """Evaluate the dense-router spec on precomputed logits.

    Spec:
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k)  with lower-index tie-break
        raw_w  = gather(act, ids)
        topk_w = raw_w / sum(raw_w) * scaling

    Returns:
        ``(topk_w, ids)``, each with ``k`` entries along the last dim.
    """
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    ids = oracle_topk_ids(act + bias, k)
    # Weights come from the un-biased activation; the bias only steers selection.
    raw_w = act.gather(-1, ids)
    return raw_w / raw_w.sum(-1, keepdim=True) * scaling, ids


def test_fused_activation_topk(N=64, E=256, k=6, seed=42):
    """Test the fused activation + top-k kernel against the math spec.

    Oracle:
        logits = X @ W (FP32)
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k)  with lower-index tie-break
        raw_w  = gather(act, ids)
        topk_w = raw_w / sum(raw_w) * scaling
    """
    _require_cuda()
    torch.manual_seed(seed)
    scaling = 2.5

    logits = torch.randn(N, E, dtype=torch.float32, device=_DEVICE)
    e_bias = torch.randn(E, dtype=torch.float32, device=_DEVICE) * 0.01

    # Oracle — the math, one expression at a time.
    w, ids = oracle_dense_routing(logits, e_bias, k, scaling)

    # Kernel under test (imported lazily so collection works without dsv4).
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk

    out_w = torch.empty(N, k, dtype=torch.float32, device=_DEVICE)
    out_ids = torch.empty(N, k, dtype=torch.int32, device=_DEVICE)
    run_fused_activation_topk(logits, e_bias, scaling, k, out_w, out_ids)

    # Verify
    _assert_same_ids(out_ids, ids, "top-k indices mismatch")
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)


def test_fused_activation_topk_decode_shapes():
    """Test the activation+topk kernel at decode-relevant N values."""
    for num_tokens in (1, 4, 16, 64):
        test_fused_activation_topk(N=num_tokens, E=256, k=6, seed=num_tokens)


def test_fused_activation_topk_pro_experts():
    """Test with 384 experts (Pro model)."""
    test_fused_activation_topk(N=64, E=384, k=6, seed=123)


def test_hash_router(N=128, vocab_size=128000, k=6, num_experts=256, seed=42):
    """Test the hash router against the math spec.

    Oracle:
        topk_ids[n, h] = hash_lut[token_ids[n], h]
        topk_w[n, h]   = 1.0 / k
    """
    _require_cuda()
    torch.manual_seed(seed)

    # Build a random LUT
    hash_lut = torch.randint(
        0, num_experts, (vocab_size, k), dtype=torch.int32, device=_DEVICE
    )
    token_ids = torch.randint(0, vocab_size, (N,), dtype=torch.int32, device=_DEVICE)

    # Oracle — literally just indexing. The index is widened to int64 because
    # not every torch version accepts int32 index tensors.
    expected_ids = hash_lut[token_ids.long()]  # [N, k]
    expected_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device=_DEVICE)

    # Kernel under test
    from dsv4.kernels.router import hash_router_dispatch

    out_w = torch.empty(N, k, dtype=torch.float32, device=_DEVICE)
    out_ids = torch.empty(N, k, dtype=torch.int32, device=_DEVICE)
    hash_router_dispatch(token_ids, hash_lut, k, out_w, out_ids)

    _assert_same_ids(out_ids, expected_ids, "hash router IDs mismatch")
    torch.testing.assert_close(out_w, expected_w, atol=1e-7, rtol=1e-7)


def test_hash_router_edge_cases():
    """Test hash router with N=1 and N=max_num_tokens."""
    for num_tokens in (1, 8192):
        test_hash_router(N=num_tokens, vocab_size=128000, k=6)


def test_topk_select(N=64, E=256, k=6, seed=42):
    """Test standalone top-k selection against the top-k spec.

    Oracle:
        indices = argtopk(score, k)  with lower-index tie-break
        values  = gather(score, indices)
    """
    _require_cuda()
    torch.manual_seed(seed)

    scores = torch.randn(N, E, dtype=torch.float32, device=_DEVICE)

    # Oracle
    expected_ids = oracle_topk_ids(scores, k)
    expected_values = scores.gather(-1, expected_ids)

    # Kernel under test
    from dsv4.ops.topk import topk_select

    out_values, out_ids = topk_select(scores, k)

    _assert_same_ids(out_ids, expected_ids, "top-k IDs mismatch")
    torch.testing.assert_close(out_values, expected_values, atol=1e-6, rtol=1e-6)


def test_dense_router_decode(N=64, H=4096, E=256, k=6, seed=42):
    """Test the full dense router (GEMM + activation + topk) against the spec.

    Oracle:
        logits = (X.float() @ W.float())
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k)  with lower-index tie-break
        w      = act.gather(-1, ids)
        w      = w / w.sum(-1, keepdim=True) * scaling
    """
    _require_cuda()
    torch.manual_seed(seed)
    scaling = 2.5

    X = torch.randn(N, H, dtype=torch.bfloat16, device=_DEVICE)
    W = torch.randn(H, E, dtype=torch.bfloat16, device=_DEVICE)
    bias = torch.randn(E, dtype=torch.float32, device=_DEVICE) * 0.01

    # Oracle — the math, in FP32
    logits = X.float() @ W.float()
    w, ids = oracle_dense_routing(logits, bias, k, scaling)

    # Kernel under test
    from dsv4.layers.router import Router

    router = Router(H, E, k, scaling, mode='dense', max_num_tokens=N)
    router.load_weights(W_gate=W, e_bias=bias)
    router.finalize_weights()

    out_w, out_ids = router(X)

    _assert_same_ids(out_ids, ids, "router IDs mismatch")
    torch.testing.assert_close(out_w, w, atol=1e-3, rtol=1e-3)


def test_dense_router_decode_shapes():
    """Test dense router at decode-relevant N values."""
    for num_tokens in (1, 4, 16, 64):
        test_dense_router_decode(N=num_tokens, H=4096, E=256, k=6, seed=num_tokens)


def test_hash_router_via_router_class(seed=42):
    """Test the Router class in hash mode.

    The RNG is seeded so the random LUT and token IDs are reproducible from
    run to run, matching the other randomized tests in this module.
    """
    _require_cuda()
    torch.manual_seed(seed)

    vocab_size = 128000
    k = 6
    num_experts = 256
    N = 64

    hash_lut = torch.randint(
        0, num_experts, (vocab_size, k), dtype=torch.int32, device=_DEVICE
    )
    token_ids = torch.randint(0, vocab_size, (N,), dtype=torch.int32, device=_DEVICE)

    # Oracle (int64 index for compatibility with older torch versions)
    expected_ids = hash_lut[token_ids.long()]
    expected_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device=_DEVICE)

    # Router class
    from dsv4.layers.router import Router

    router = Router(
        hidden_size=4096,  # not used in hash mode
        num_experts=num_experts,
        top_k=k,
        mode='hash',
        vocab_size=vocab_size,
        max_num_tokens=N,
    )
    router.load_weights(hash_lut=hash_lut)
    router.finalize_weights()

    out_w, out_ids = router(hidden_states=None, token_ids=token_ids)

    _assert_same_ids(out_ids, expected_ids, "hash router class IDs mismatch")
    torch.testing.assert_close(out_w, expected_w, atol=1e-7, rtol=1e-7)


def test_softplus_numerical_stability():
    """Verify the numerically stable softplus matches the spec.

    For x = -100: softplus(x) ≈ 0,                sqrt(softplus(x)) ≈ 0
    For x = 0:    softplus(x) = log(2) ≈ 0.693,   sqrt ≈ 0.832
    For x = 100:  softplus(x) ≈ 100,              sqrt(softplus(x)) ≈ 10
    """
    # This tests the Python math, not the kernel. It's a sanity check
    # that the formula max(x,0) + log1p(exp(-|x|)) works correctly.
    x = torch.tensor([-100.0, 0.0, 100.0], dtype=torch.float32)

    sp = torch.nn.functional.softplus(x)
    # Evaluate the stable formula itself and confirm it agrees with softplus.
    stable_sp = torch.clamp(x, min=0.0) + torch.log1p(torch.exp(-x.abs()))
    torch.testing.assert_close(stable_sp, sp, atol=1e-6, rtol=1e-6)

    act = torch.sqrt(sp)
    expected = torch.tensor(
        [0.0, math.sqrt(math.log(2.0)), 10.0], dtype=torch.float32
    )
    torch.testing.assert_close(act, expected, atol=1e-3, rtol=1e-3)