# tests/unit/test_dense_router.py
import unittest

import torch

from dsv4.layers.router import Router


def test_dense_router_matches_spec(N=64, H=4096, E=256, k=6):
    """Check the dense Router against the routing spec, computed directly in FP32.

    Spec (as written in the oracle below):
      act    = sqrt(softplus(X @ W))
      ids    = top-k indices of (act + bias), in topk's sorted order
      w      = act[ids] normalized to sum 1 per row, then multiplied by scaling

    The bias influences only *which* experts are selected; the returned weights
    are gathered from ``act`` without the bias.

    Args:
        N: number of tokens (rows of X).
        H: hidden size.
        E: number of experts.
        k: experts selected per token.
    """
    if not torch.cuda.is_available():
        # pytest reports unittest.SkipTest as a skip rather than an error.
        raise unittest.SkipTest("dense router test requires CUDA")

    # Fixed seed so a failure (e.g. a near-tie at the k-th score) is reproducible.
    torch.manual_seed(0)

    X = torch.randn(N, H, dtype=torch.bfloat16, device='cuda')
    W = torch.randn(H, E, dtype=torch.bfloat16, device='cuda')
    bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01
    scaling = 2.5

    # Oracle: directly compute the spec, in one expression, in FP32.
    # This is not "a PyTorch reference implementation" — it's the math.
    logits = X.float() @ W.float()
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + bias
    ids = score.topk(k, dim=-1).indices
    w = act.gather(-1, ids)
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test:
    router = Router(H, E, k, scaling, mode='dense')
    # NOTE(review): assumes Router places W_gate/e_bias on CUDA itself; if it is an
    # nn.Module built on CPU, a `.to('cuda')` may be needed here — confirm.
    # no_grad: in-place writes to leaf tensors that require grad would raise.
    with torch.no_grad():
        router.W_gate.copy_(W)
        router.e_bias.copy_(bias)
    out_w, out_ids = router(X, layer_idx=5)

    # ids must be an exact match, including shape and order; the cast tolerates
    # a kernel that returns int32 indices.
    torch.testing.assert_close(out_ids.long(), ids, rtol=0, atol=0)
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)