Step 1: Hash router (hash_router.cu)
- One thread per token, gathering from a [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- The 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared-memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: the lower index wins on equal scores
- Reusable by the CSA indexer

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + top-k + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x, 0) + log1p(exp(-|x|))
- Per-thread heap with the unbiased activation co-stored
- Shared-memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined; TMA/MMA boilerplate still needed)
- Dispatch: N <= 64 uses the fused decode path, N > 64 uses the prefill path

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for the GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel is tested against its mathematical spec in FP32. There is no reference implementation and no pair of debug streams to compare — the oracle IS the math.
90 lines
2.5 KiB
Python
90 lines
2.5 KiB
Python
"""torch.library.custom_op wrappers and dispatch for the Router kernels.
|
|
|
|
Mirrors the pattern in dsv4/ops/custom_ops.py:
|
|
- Routers are registered into an integer-keyed table.
|
|
- The custom_op takes the integer ID and tensor args only.
|
|
- Dynamo can't trace through the kernel; the op is opaque.
|
|
"""
|
|
|
|
import torch
|
|
from dsv4.kernels.router import (
|
|
dense_router_dispatch, # picks decode vs prefill internally
|
|
hash_router_dispatch,
|
|
)
|
|
|
|
# Integer-keyed table of Router objects; the custom ops in this module receive
# the integer key instead of the router itself.
_router_registry: dict[int, object] = {}
# ID that the next call to register_router() will hand out. It only ever
# increases; nothing in this module removes registry entries.
_next_router_id: int = 0
|
|
|
|
|
|
def register_router(router) -> int:
    """Store *router* in the module registry and return its newly assigned ID.

    IDs are handed out sequentially from the module counter; the registry
    holds a strong reference to *router* for as long as the entry exists.
    """
    global _next_router_id
    # Claim the current counter value and advance it in one step.
    router_id, _next_router_id = _next_router_id, _next_router_id + 1
    _router_registry[router_id] = router
    return router_id
|
|
|
|
|
|
def get_router(rid: int):
    """Return the router previously registered under *rid*.

    Raises:
        KeyError: if no router was registered with that ID.
    """
    router = _router_registry[rid]
    return router
|
|
|
|
|
|
def warmup_router_compilation(router, *, num_tokens: int = 1) -> None:
    """Trigger eager JIT compilation for one of the router's kernel paths.

    Runs a single throwaway forward on an all-zeros dummy input with
    ``num_tokens`` rows, so the kernel for that input shape is compiled
    before real traffic arrives. The forward's result is discarded.

    Args:
        router: Object exposing ``mode``, ``device`` and the private
            ``_run_dense_impl`` / ``_run_hash_impl`` entry points (plus
            ``hidden_size`` when ``mode == "dense"``).
        num_tokens: Number of dummy tokens to run. The default of 1 only
            exercises the small-N path. NOTE(review): the dense dispatcher is
            described as switching to a separate prefill path for larger N
            (N > 64); pass a larger value to warm that path too — confirm
            the threshold against the dispatcher.

    Raises:
        ValueError: if ``num_tokens`` is less than 1.
    """
    if num_tokens < 1:
        raise ValueError(f"num_tokens must be >= 1, got {num_tokens}")

    if router.mode == "dense":
        # Dummy activations of shape (num_tokens, hidden_size) in BF16.
        dummy = torch.zeros(
            num_tokens, router.hidden_size,
            dtype=torch.bfloat16, device=router.device,
        )
        router._run_dense_impl(dummy)
    else:
        # Every non-"dense" mode takes the hash path: a 1-D int32 vector of
        # token IDs (all token 0 here).
        dummy = torch.zeros(num_tokens, dtype=torch.int32, device=router.device)
        router._run_hash_impl(dummy)
|
|
|
|
|
|
# ----- Dense router custom op -----
|
|
@torch.library.custom_op("dsv4::dense_router", mutates_args=())
def dense_router_op(
    hidden_states: torch.Tensor,
    router_id: int,
    num_experts: int,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Opaque op: run the dense router registered under ``router_id``.

    Returns the pair of tensors produced by the router's ``_run_dense_impl``
    unchanged. ``top_k`` is read only by the fake implementation, to size
    its outputs; ``num_experts`` is not read anywhere in this module.
    """
    # NOTE(review): custom_op outputs must not alias inputs. If
    # _run_dense_impl hands back preallocated buffers, confirm that is safe
    # under torch.compile / CUDA graphs.
    return get_router(router_id)._run_dense_impl(hidden_states)
|
|
|
|
|
|
@dense_router_op.register_fake
def _(hidden_states, router_id, num_experts, top_k):
    """Fake implementation of ``dsv4::dense_router`` used during tracing.

    Allocates uninitialized outputs with the declared shapes and dtypes,
    without touching any router: ``(N, top_k)`` float32 and ``(N, top_k)``
    int32, where ``N = hidden_states.shape[0]``.
    """
    out_shape = (hidden_states.shape[0], top_k)
    device = hidden_states.device
    weights = torch.empty(out_shape, dtype=torch.float32, device=device)
    indices = torch.empty(out_shape, dtype=torch.int32, device=device)
    return weights, indices
|
|
|
|
|
|
# ----- Hash router custom op -----
|
|
@torch.library.custom_op("dsv4::hash_router", mutates_args=())
def hash_router_op(
    token_ids: torch.Tensor,
    router_id: int,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Opaque op: run the hash router registered under ``router_id``.

    Returns the pair of tensors produced by the router's ``_run_hash_impl``
    unchanged. ``top_k`` is read only by the fake implementation, to size
    its outputs.
    """
    # NOTE(review): custom_op outputs must not alias inputs. If
    # _run_hash_impl hands back preallocated buffers, confirm that is safe
    # under torch.compile / CUDA graphs.
    return get_router(router_id)._run_hash_impl(token_ids)
|
|
|
|
|
|
@hash_router_op.register_fake
def _(token_ids, router_id, top_k):
    """Fake implementation of ``dsv4::hash_router`` used during tracing.

    Allocates uninitialized outputs with the declared shapes and dtypes,
    without touching any router: ``(N, top_k)`` float32 and ``(N, top_k)``
    int32, where ``N = token_ids.shape[0]``.
    """
    out_shape = (token_ids.shape[0], top_k)
    device = token_ids.device
    weights = torch.empty(out_shape, dtype=torch.float32, device=device)
    indices = torch.empty(out_shape, dtype=torch.int32, device=device)
    return weights, indices
|