Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls
Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer
Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store
Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path
Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing
Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)
Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00
|
|
|
"""DSV4 Router kernels — dispatch and CUDA kernel wrappers.
|
|
|
|
|
|
|
|
|
|
Exports:
|
2026-05-21 21:58:31 +00:00
|
|
|
dense_router_dispatch: GEMM + fused activation + top-k (all N)
|
|
|
|
|
hash_router_dispatch: Hash routing via precomputed LUT gather
|
Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls
Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer
Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store
Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path
Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing
Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)
Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from dsv4.kernels.router.dense_router_decode import dense_router_dispatch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def hash_router_dispatch(
|
|
|
|
|
token_ids, # [N] int32
|
|
|
|
|
hash_lut, # [vocab_size, k] int32
|
|
|
|
|
top_k, # k=6
|
|
|
|
|
out_weights, # [N, k] float32, pre-allocated
|
|
|
|
|
out_ids, # [N, k] int32, pre-allocated
|
|
|
|
|
):
|
|
|
|
|
"""Hash router dispatch: gather expert IDs from precomputed LUT.
|
|
|
|
|
|
|
|
|
|
Wraps the hash_router CUDA kernel (dsv4/kernels/cuda/hash_router.cu).
|
|
|
|
|
One kernel launch, no intermediate buffers, no CPU-GPU sync.
|
|
|
|
|
"""
|
|
|
|
|
from dsv4.kernels.cuda._hash_router import run_hash_router
|
|
|
|
|
return run_hash_router(token_ids, hash_lut, top_k, out_weights, out_ids)
|