Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel is tested against its mathematical spec in FP32. No reference implementation, no two debug streams. The oracle IS the math.
218 lines
7.7 KiB
Python
218 lines
7.7 KiB
Python
"""Unit tests for DSV4 Router — dense and hash modes.
|
|
|
|
Test strategy:
Each kernel has a closed-form mathematical spec. The unit test computes
the spec directly in FP32 (PyTorch) and compares it against the
kernel output. This is not "a PyTorch reference implementation" — it's
the math. Compare against that: no "ref/" file, no second implementation
to drift, no two debug streams.

The oracle is the same five lines of math as the kernel spec, written
declaratively. Compare against that.

DO NOT RUN THESE TESTS — Carmine is actively testing Stage C.
Write the tests, commit them, they'll be run later.

Tie-breaking: When two scores are exactly equal, torch.topk and the kernel
may pick different indices. Use the same tie-break rule: lower index wins
on ties. If the test fails on tie-breaking, fix the kernel, not the test.
|
|
"""
|
|
|
|
import torch
|
|
import math
|
|
|
|
|
|
def test_fused_activation_topk(N=64, E=256, k=6, seed=42):
    """Check the fused activation + top-k kernel against the math spec.

    The input logits are random FP32 values generated on the CUDA device.

    Oracle (all FP32):
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k) with lower-index tie-break
        raw_w  = gather(act, ids)
        topk_w = raw_w / sum(raw_w) * scaling

    Args:
        N: number of rows in the random logits tensor.
        E: number of columns (experts) in the logits and bias.
        k: number of entries selected per row.
        seed: value passed to ``torch.manual_seed``.

    Raises:
        AssertionError: if the kernel's indices or weights differ from the
            oracle.
    """
    torch.manual_seed(seed)
    scaling = 2.5

    logits = torch.randn(N, E, dtype=torch.float32, device='cuda')
    e_bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01

    # Oracle — the math, one expression at a time.
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + e_bias
    # torch.topk does not promise which index it returns for equal scores.
    # A stable descending sort keeps equal elements in their original order,
    # so the lower index always comes first — the tie-break rule of the spec.
    ids = score.sort(dim=-1, descending=True, stable=True).indices[:, :k]
    # Output weights come from the unbiased activation, not the biased score.
    raw_w = act.gather(-1, ids)
    w = raw_w / raw_w.sum(-1, keepdim=True) * scaling

    # Kernel under test.
    from dsv4.kernels.router._activation_topk import run_fused_activation_topk
    out_w = torch.empty(N, k, dtype=torch.float32, device='cuda')
    out_ids = torch.empty(N, k, dtype=torch.int32, device='cuda')
    run_fused_activation_topk(logits, e_bias, scaling, k, out_w, out_ids)

    # Verify. The int32/int64 comparison relies on torch type promotion.
    assert (out_ids == ids).all(), "top-k indices mismatch"
    torch.testing.assert_close(out_w, w, atol=1e-4, rtol=1e-3)
|
|
|
|
|
|
def test_fused_activation_topk_decode_shapes():
    """Run the fused activation + top-k check at decode-relevant N values.

    Each run uses E=256, k=6 and seeds the RNG with its own N.
    """
    decode_batch_sizes = (1, 4, 16, 64)
    for num_tokens in decode_batch_sizes:
        test_fused_activation_topk(N=num_tokens, E=256, k=6, seed=num_tokens)
|
|
|
|
|
|
def test_fused_activation_topk_pro_experts():
    """Run the fused activation + top-k check with 384 experts (Pro model)."""
    pro_config = {"N": 64, "E": 384, "k": 6, "seed": 123}
    test_fused_activation_topk(**pro_config)
|
|
|
|
|
|
def test_hash_router(N=128, vocab_size=128000, k=6, num_experts=256, seed=42):
    """Check the hash router kernel against its lookup-table spec.

    Oracle:
        topk_ids[n, h] = hash_lut[token_ids[n], h]
        topk_w[n, h]   = 1.0 / k

    Args:
        N: number of random token ids to route.
        vocab_size: number of rows in the random lookup table.
        k: number of columns in the lookup table (entries per token).
        num_experts: exclusive upper bound of the random LUT entries.
        seed: value passed to ``torch.manual_seed``.
    """
    torch.manual_seed(seed)
    int32_on_gpu = {"dtype": torch.int32, "device": 'cuda'}
    fp32_on_gpu = {"dtype": torch.float32, "device": 'cuda'}

    # Random LUT first, then random tokens (order fixes the RNG stream).
    lut = torch.randint(0, num_experts, (vocab_size, k), **int32_on_gpu)
    tokens = torch.randint(0, vocab_size, (N,), **int32_on_gpu)

    # Oracle: plain row indexing plus a constant weight.
    want_ids = lut[tokens]  # [N, k]
    want_w = torch.full((N, k), 1.0 / k, **fp32_on_gpu)

    # Kernel under test writes into preallocated output buffers.
    from dsv4.kernels.router import hash_router_dispatch
    got_w = torch.empty(N, k, **fp32_on_gpu)
    got_ids = torch.empty(N, k, **int32_on_gpu)
    hash_router_dispatch(tokens, lut, k, got_w, got_ids)

    ids_agree = (got_ids == want_ids).all()
    assert ids_agree, "hash router IDs mismatch"
    torch.testing.assert_close(got_w, want_w, atol=1e-7, rtol=1e-7)
|
|
|
|
|
|
def test_hash_router_edge_cases():
    """Run the hash router check at N=1 and at N=8192 (max_num_tokens)."""
    for num_tokens in (1, 8192):
        test_hash_router(N=num_tokens, vocab_size=128000, k=6)
|
|
|
|
|
|
def test_topk_select(N=64, E=256, k=6, seed=42):
    """Check standalone top-k selection against a deterministic oracle.

    Oracle:
        The first k entries of each row after a stable descending sort,
        so the lower index wins on ties.

    Args:
        N: number of rows in the random score tensor.
        E: number of columns in the random score tensor.
        k: number of entries selected per row.
        seed: value passed to ``torch.manual_seed``.

    Raises:
        AssertionError: if the selected indices or values differ from the
            oracle.
    """
    torch.manual_seed(seed)
    scores = torch.randn(N, E, dtype=torch.float32, device='cuda')

    # Oracle. torch.topk does not promise which index it returns for equal
    # scores; a stable descending sort keeps equal elements in input order,
    # which is exactly the "lower index wins" rule.
    ranked = scores.sort(dim=-1, descending=True, stable=True)
    expected_ids = ranked.indices[:, :k]
    expected_values = ranked.values[:, :k]

    # Kernel under test.
    from dsv4.ops.topk import topk_select
    out_values, out_ids = topk_select(scores, k)

    assert (out_ids == expected_ids).all(), "top-k IDs mismatch"
    torch.testing.assert_close(out_values, expected_values, atol=1e-6, rtol=1e-6)
|
|
|
|
|
|
def test_dense_router_decode(N=64, H=4096, E=256, k=6, seed=42):
    """Check the full dense router (GEMM + activation + top-k) against the spec.

    Oracle (computed in FP32 from the BF16 inputs):
        logits = X.float() @ W.float()
        act    = sqrt(softplus(logits))
        score  = act + bias
        ids    = argtopk(score, k) with lower-index tie-break
        w      = act.gather(-1, ids)
        w      = w / w.sum(-1, keepdim=True) * scaling

    Args:
        N: number of rows of the random input X.
        H: inner (hidden) dimension shared by X and W.
        E: number of columns of W and entries of the bias.
        k: number of entries selected per row.
        seed: value passed to ``torch.manual_seed``.

    Raises:
        AssertionError: if the router's indices or weights differ from the
            oracle.
    """
    torch.manual_seed(seed)
    scaling = 2.5

    X = torch.randn(N, H, dtype=torch.bfloat16, device='cuda')
    W = torch.randn(H, E, dtype=torch.bfloat16, device='cuda')
    bias = torch.randn(E, dtype=torch.float32, device='cuda') * 0.01

    # Oracle — the math, one expression at a time, in FP32.
    logits = (X.float() @ W.float())
    act = torch.sqrt(torch.nn.functional.softplus(logits))
    score = act + bias
    # torch.topk does not promise which index it returns for equal scores.
    # A stable descending sort keeps equal elements in input order, so the
    # lower index wins on ties.
    ids = score.sort(dim=-1, descending=True, stable=True).indices[:, :k]
    # Output weights come from the unbiased activation, not the biased score.
    w = act.gather(-1, ids)
    w = w / w.sum(-1, keepdim=True) * scaling

    # Kernel under test.
    # NOTE(review): assumes Router accepts W_gate laid out as [H, E] — confirm
    # against Router.load_weights.
    from dsv4.layers.router import Router
    router = Router(H, E, k, scaling, mode='dense', max_num_tokens=N)
    router.load_weights(W_gate=W, e_bias=bias)
    router.finalize_weights()
    out_w, out_ids = router(X)

    assert (out_ids == ids).all(), "router IDs mismatch"
    torch.testing.assert_close(out_w, w, atol=1e-3, rtol=1e-3)
|
|
|
|
|
|
def test_dense_router_decode_shapes():
    """Run the dense router check at decode-relevant N values.

    Each run uses H=4096, E=256, k=6 and seeds the RNG with its own N.
    """
    decode_batch_sizes = (1, 4, 16, 64)
    for num_tokens in decode_batch_sizes:
        test_dense_router_decode(N=num_tokens, H=4096, E=256, k=6, seed=num_tokens)
|
|
|
|
|
|
def test_hash_router_via_router_class(seed=42):
    """Check the Router class in hash mode against the lookup-table spec.

    Oracle:
        topk_ids[n, h] = hash_lut[token_ids[n], h]
        topk_w[n, h]   = 1.0 / k

    Args:
        seed: value passed to ``torch.manual_seed`` so the random LUT and
            token ids are reproducible when a failure needs to be replayed.

    Raises:
        AssertionError: if the router's indices or weights differ from the
            oracle.
    """
    vocab_size = 128000
    k = 6
    num_experts = 256
    N = 64

    # Seed like every other randomized test in this module; without it a
    # failing LUT/token draw cannot be reproduced.
    torch.manual_seed(seed)

    hash_lut = torch.randint(0, num_experts, (vocab_size, k), dtype=torch.int32, device='cuda')
    token_ids = torch.randint(0, vocab_size, (N,), dtype=torch.int32, device='cuda')

    # Oracle: plain row indexing plus a constant weight.
    expected_ids = hash_lut[token_ids]
    expected_w = torch.full((N, k), 1.0 / k, dtype=torch.float32, device='cuda')

    # Router class
    from dsv4.layers.router import Router
    router = Router(
        hidden_size=4096,  # not used in hash mode
        num_experts=num_experts,
        top_k=k,
        mode='hash',
        vocab_size=vocab_size,
        max_num_tokens=N,
    )
    router.load_weights(hash_lut=hash_lut)
    router.finalize_weights()
    out_w, out_ids = router(hidden_states=None, token_ids=token_ids)

    assert (out_ids == expected_ids).all(), "hash router class IDs mismatch"
    torch.testing.assert_close(out_w, expected_w, atol=1e-7, rtol=1e-7)
|
|
|
|
|
|
def test_softplus_numerical_stability():
    """Sanity-check the numerically stable softplus formula at extreme inputs.

    Evaluates both the explicit stable form ``max(x, 0) + log1p(exp(-|x|))``
    and ``torch.nn.functional.softplus`` on a CPU tensor, then compares the
    square root of each against closed-form values:

        x = -100: softplus(x) ≈ 0,              sqrt ≈ 0
        x =    0: softplus(x) = log(2) ≈ 0.693, sqrt ≈ 0.832
        x =  100: softplus(x) ≈ 100,            sqrt ≈ 10

    This tests the Python math, not the kernel.

    Raises:
        AssertionError: if either form is non-finite or deviates from the
            expected values beyond atol=1e-3 / rtol=1e-3.
    """
    x = torch.tensor([-100.0, 0.0, 100.0], dtype=torch.float32)
    expected = torch.tensor([0.0, math.sqrt(math.log(2.0)), 10.0], dtype=torch.float32)

    # The explicit stable formula: exp() only ever sees a non-positive
    # argument, so it cannot overflow for large |x|.
    stable_sp = x.clamp(min=0) + torch.log1p(torch.exp(-x.abs()))
    assert torch.isfinite(stable_sp).all(), "stable softplus produced non-finite values"
    torch.testing.assert_close(torch.sqrt(stable_sp), expected, atol=1e-3, rtol=1e-3)

    # The built-in softplus, which the other oracles in this module rely on.
    builtin_sp = torch.nn.functional.softplus(x)
    torch.testing.assert_close(torch.sqrt(builtin_sp), expected, atol=1e-3, rtol=1e-3)
|