# tests/unit/test_dense_router.py
import torch
from dsv4.layers.router import Router


def test_dense_router_matches_spec(N=64, H=4096, E=256, k=6):
    """Compare the dense-mode ``Router`` against the routing math evaluated in FP32.

    The oracle computes, for random inputs ``X`` of shape ``(N, H)`` and gate
    weights ``W`` of shape ``(H, E)``:

    * ``act = sqrt(softplus(X @ W))`` in FP32,
    * expert selection via top-``k`` of ``act + bias`` along the last dim,
    * routing weights gathered from ``act`` (the bias only influences which
      experts are selected, not their weights), normalized to sum to one per
      row and then multiplied by ``scaling``.

    The router's ids must match the oracle exactly; its weights must match
    within ``atol=1e-4`` / ``rtol=1e-3``.

    Args:
        N: Number of rows in the input batch.
        H: Size of the input feature dimension.
        E: Number of columns of the gate matrix (one score per expert).
        k: Number of experts selected per row.

    Requires a CUDA device: all tensors are created with ``device='cuda'``.
    """
    # Draw inputs from a locally seeded generator. The exact-id assertion
    # below is sensitive to near-ties among top-k scores, so the inputs must be
    # reproducible; a local generator avoids reseeding the global RNG for
    # other tests.
    gen = torch.Generator(device='cuda').manual_seed(0)

    X = torch.randn(N, H, dtype=torch.bfloat16, device='cuda', generator=gen)
    W = torch.randn(H, E, dtype=torch.bfloat16, device='cuda', generator=gen)
    bias = torch.randn(E, dtype=torch.float32, device='cuda', generator=gen) * 0.01
    scaling = 2.5

    # Oracle: directly compute the spec, in one expression, in FP32.
    # This is not "a PyTorch reference implementation" — it's the math.
    logits = (X.float() @ W.float())
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + bias
    ids = score.topk(k, dim=-1).indices
    w = act.gather(-1, ids)
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test:
    router = Router(H, E, k, scaling, mode='dense')
    # In-place writes into module state must not be recorded by autograd:
    # copy_ on a leaf tensor that requires grad raises outside no_grad().
    with torch.no_grad():
        router.W_gate.copy_(W)
        router.e_bias.copy_(bias)
    # NOTE(review): layer_idx=5 is passed through as in the original test; its
    # effect on dense-mode routing is not visible from this file.
    out_w, out_ids = router(X, layer_idx=5)

    # `==` broadcasts, so pin the shape first; otherwise a wrongly-shaped but
    # broadcast-compatible result could pass the element-wise check.
    assert out_ids.shape == ids.shape
    assert (out_ids == ids).all()  # ids must be exact match
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)