"""Unit tests for DSV4 Router — dense and hash modes.

Test strategy:
Each kernel has a closed-form mathematical spec. The unit test computes
the spec in one expression in FP32 (PyTorch) and compares against the
kernel output. This is not "a PyTorch reference implementation" — it's
the math. Compare against that. No "ref/" file, no second implementation
drift, no two debug streams.

The oracle is the same five lines of math as the kernel spec, written
declaratively. Compare against that.

DO NOT RUN THESE TESTS — Carmine is actively testing Stage C.
Write the tests, commit them, they'll be run later.

Tie-breaking: When two scores are exactly equal, torch.topk and the kernel
may pick different indices. Use the same tie-break rule: lower index wins
on ties. If the test fails on tie-breaking, fix the kernel, not the test.
"""
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import math
|
||
|
|
|
||
|
|
|
||
|
|
def test_fused_activation_topk(N=64, E=256, k=6, seed=42):
    """Test the fused activation + top-k kernel against the math spec.

    Oracle:
        logits = X @ W  (FP32)
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k) with lower-index tie-break
        raw_w  = gather(act, ids)
        topk_w = raw_w / sum(raw_w) * scaling

    The logits are sampled directly here (no GEMM); only the activation,
    selection and normalisation stages are exercised.

    Args:
        N: number of logit rows.
        E: number of logit columns (experts).
        k: number of experts selected per row.
        seed: value passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)
    scaling = 2.5

    logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
    e_bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01

    # Oracle — the math, one expression at a time
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + e_bias
    # torch.topk does not document which index it returns for equal scores.
    # A stable descending sort keeps equal scores in their original (index)
    # order, so taking the first k columns is exactly "lower index wins".
    ranked = torch.sort(score, dim=-1, descending=True, stable=True)
    ids = ranked.indices[..., :k]
    raw_w = act.gather(-1, ids)
    w = raw_w / raw_w.sum(-1, keepdim=True) * scaling

    # Kernel under test (imported lazily so collecting this module does not
    # require the kernel package to be importable):
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk
    out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
    out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
    run_fused_activation_topk(logits, e_bias, scaling, k, out_w, out_ids)

    # Verify: indices must match exactly, weights within FP32 tolerance.
    assert (out_ids == ids).all(), "top-k indices mismatch"
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)
|
||
|
|
|
||
|
|
|
||
|
|
def test_fused_activation_topk_decode_shapes():
    """Run the fused activation+top-k check at each decode-relevant N."""
    decode_token_counts = (1, 4, 16, 64)
    for num_tokens in decode_token_counts:
        # The token count doubles as the RNG seed for that shape.
        test_fused_activation_topk(N=num_tokens, E=256, k=6, seed=num_tokens)
|
||
|
|
|
||
|
|
|
||
|
|
def test_fused_activation_topk_pro_experts():
    """Run the fused activation+top-k check with 384 experts (Pro model)."""
    pro_config = dict(N=64, E=384, k=6, seed=123)
    test_fused_activation_topk(**pro_config)
|
||
|
|
|
||
|
|
|
||
|
|
def test_hash_router(N=128, vocab_size=128000, k=6, num_experts=256, seed=42):
    """Test the hash router against the math spec.

    Oracle:
        topk_ids[n, h] = hash_lut[token_ids[n], h]
        topk_w[n, h]   = 1.0 / k

    Args:
        N: number of token ids routed.
        vocab_size: number of rows in the random lookup table; token ids
            are drawn from ``[0, vocab_size)``.
        k: number of experts per token (columns of the lookup table).
        num_experts: exclusive upper bound for the expert ids in the table.
        seed: value passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)

    # Build a random LUT
    hash_lut = torch.randint(0, num_experts, (vocab_size, k), dtype=torch.int32, device='cuda')
    token_ids = torch.randint(0, vocab_size, (N,), dtype=torch.int32, device='cuda')

    # Oracle — literally just indexing.  The index is widened to int64
    # because older PyTorch releases reject int32 index tensors; the kernel
    # below still receives the original int32 token ids.
    expected_ids = hash_lut[token_ids.long()]  # [N, k]
    expected_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device='cuda')

    # Kernel under test (imported lazily so collecting this module does not
    # require the kernel package to be importable):
    from dsv4.kernels.router import hash_router_dispatch
    out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
    out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
    hash_router_dispatch(token_ids, hash_lut, k, out_w, out_ids)

    assert (out_ids == expected_ids).all(), "hash router IDs mismatch"
    torch.testing.assert_close(out_w, expected_w, atol=1e-7, rtol=1e-7)
|
||
|
|
|
||
|
|
|
||
|
|
def test_hash_router_edge_cases():
    """Exercise the hash router at N=1 and N=max_num_tokens."""
    # 8192 is the value the original test uses for "max_num_tokens".
    for num_tokens in (1, 8192):
        test_hash_router(N=num_tokens, vocab_size=128000, k=6)
|
||
|
|
|
||
|
|
|
||
|
|
def test_topk_select(N=64, E=256, k=6, seed=42):
    """Test standalone top-k selection against a lower-index-wins oracle.

    Oracle:
        The k largest scores per row, in descending order, with the lower
        index winning on ties.

    Args:
        N: number of score rows.
        E: number of score columns.
        k: number of entries selected per row.
        seed: value passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)
    scores = torch.randn(N, E, dtype=torch.float32, device='cuda')

    # Oracle.  torch.topk does not document its tie order, so the expected
    # result is taken from a stable descending sort: equal scores stay in
    # index order, which is exactly the "lower index wins" rule.
    ranked = torch.sort(scores, dim=-1, descending=True, stable=True)
    expected_ids = ranked.indices[..., :k]
    expected_values = ranked.values[..., :k]

    # Kernel under test (imported lazily so collecting this module does not
    # require the ops package to be importable):
    from dsv4.ops.topk import topk_select
    out_values, out_ids = topk_select(scores, k)

    assert (out_ids == expected_ids).all(), "top-k IDs mismatch"
    torch.testing.assert_close(out_values, expected_values, atol=1e-6, rtol=1e-6)
|
||
|
|
|
||
|
|
|
||
|
|
def test_dense_router_decode(N=64, H=4096, E=256, k=6, seed=42):
    """Test the full dense router (GEMM + activation + topk) against the spec.

    Oracle:
        logits = (X.float() @ W.float())
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k) with lower-index tie-break
        w      = act.gather(-1, ids)
        w      = w / w.sum(-1, keepdim=True) * scaling

    Args:
        N: number of rows of the activation matrix X.
        H: hidden size (columns of X, rows of W).
        E: number of experts (columns of W).
        k: number of experts selected per row.
        seed: value passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)
    scaling = 2.5

    # Inputs are BF16; the oracle upcasts them to FP32 before the matmul.
    X = torch.randn(N, H, dtype=torch.bfloat16, device='cuda')
    W = torch.randn(H, E, dtype=torch.bfloat16, device='cuda')
    bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01

    # Oracle — the math, in one expression, in FP32
    logits = (X.float() @ W.float())
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + bias
    # torch.topk does not document its tie order; a stable descending sort
    # keeps equal scores in index order, i.e. lower index wins on ties.
    ranked = torch.sort(score, dim=-1, descending=True, stable=True)
    ids = ranked.indices[..., :k]
    w = act.gather(-1, ids)
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test (imported lazily so collecting this module does not
    # require the layers package to be importable):
    from dsv4.layers.router import Router
    router = Router(H, E, k, scaling, mode='dense', max_num_tokens=N)
    router.load_weights(W_gate=W, e_bias=bias)
    router.finalize_weights()
    out_w, out_ids = router(X)

    assert (out_ids == ids).all(), "router IDs mismatch"
    torch.testing.assert_close(out_w, w, atol=1e-3, rtol=1e-3)
|
||
|
|
|
||
|
|
|
||
|
|
def test_dense_router_decode_shapes():
    """Run the dense-router check at each decode-relevant N."""
    decode_token_counts = (1, 4, 16, 64)
    for num_tokens in decode_token_counts:
        # The token count doubles as the RNG seed for that shape.
        test_dense_router_decode(N=num_tokens, H=4096, E=256, k=6, seed=num_tokens)
|
||
|
|
|
||
|
|
|
||
|
|
def test_hash_router_via_router_class():
    """Test the Router class in hash mode.

    Oracle:
        topk_ids[n, h] = hash_lut[token_ids[n], h]
        topk_w[n, h]   = 1.0 / k
    """
    # Seed the RNG so the random LUT and token ids are reproducible when a
    # failure needs to be replayed.
    torch.manual_seed(42)

    vocab_size = 128000
    k = 6
    num_experts = 256
    N = 64

    hash_lut = torch.randint(0, num_experts, (vocab_size, k), dtype=torch.int32, device='cuda')
    token_ids = torch.randint(0, vocab_size, (N,), dtype=torch.int32, device='cuda')

    # Oracle.  The index is widened to int64 because older PyTorch releases
    # reject int32 index tensors; the router still receives int32 token ids.
    expected_ids = hash_lut[token_ids.long()]
    expected_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device='cuda')

    # Router class (imported lazily so collecting this module does not
    # require the layers package to be importable)
    from dsv4.layers.router import Router
    router = Router(
        hidden_size=4096,  # not used in hash mode
        num_experts=num_experts,
        top_k=k,
        mode='hash',
        vocab_size=vocab_size,
        max_num_tokens=N,
    )
    router.load_weights(hash_lut=hash_lut)
    router.finalize_weights()
    out_w, out_ids = router(hidden_states=None, token_ids=token_ids)

    assert (out_ids == expected_ids).all(), "hash router class IDs mismatch"
    torch.testing.assert_close(out_w, expected_w, atol=1e-7, rtol=1e-7)
|
||
|
|
|
||
|
|
|
||
|
|
def test_softplus_numerical_stability():
    """Verify the numerically stable softplus formula matches the spec.

    For x = -100: softplus(x) ≈ 0,   sqrt(softplus(x)) ≈ 0
    For x = 0:    softplus(x) = log(2) ≈ 0.693, sqrt ≈ 0.832
    For x = 100:  softplus(x) ≈ 100, sqrt(softplus(x)) ≈ 10
    """
    # This tests the Python math, not the kernel. It's a sanity check
    # that the formula max(x,0) + log1p(exp(-|x|)) works correctly.
    x = torch.tensor([-100.0, 0.0, 100.0], dtype=torch.float32)
    expected = torch.tensor([0.0, math.sqrt(math.log(2.0)), 10.0], dtype=torch.float32)

    # Evaluate the stable formula explicitly: exp() only ever sees a
    # non-positive argument, so it cannot overflow for large |x|.
    sp_stable = x.clamp(min=0) + torch.log1p(torch.exp(-x.abs()))
    sp_torch = torch.nn.functional.softplus(x)

    # The hand-written formula and PyTorch's softplus must agree ...
    torch.testing.assert_close(sp_stable, sp_torch, atol=1e-6, rtol=1e-6)
    # ... and both must reproduce the closed-form activation values.
    torch.testing.assert_close(torch.sqrt(sp_stable), expected, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(torch.sqrt(sp_torch), expected, atol=1e-3, rtol=1e-3)
|