nvfp4-megamoe-kernel/dsv4/ops/router.py
biondizzle abfe4485f7 Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls
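
The hash-router gather can be sketched in plain Python. This is a minimal illustration, not the hash_router.cu kernel: `hash_route`, `lut_bytes` and the toy LUT are hypothetical names and values, and the int32 entry size and the 131072-entry vocabulary used in the size check are assumptions chosen only because they reproduce a 3 MiB table at k=6.

```python
def hash_route(token_ids, lut, k):
    """Return (weights, indices), one row of k entries per token."""
    indices = [list(lut[t]) for t in token_ids]   # gather one LUT row per token
    weights = [[1.0 / k] * k for _ in token_ids]  # uniform 1/k weights
    return weights, indices


def lut_bytes(vocab_size, k, itemsize=4):
    """Size of a [vocab_size, k] table; itemsize=4 assumes int32 entries."""
    return vocab_size * k * itemsize


# Toy LUT: 4-token vocabulary, k = 2 experts per token (made-up values).
toy_lut = [[0, 3], [1, 2], [2, 0], [3, 1]]
weights, indices = hash_route([2, 0, 2], toy_lut, k=2)
```

Because the output depends only on the token ID, repeated decode calls read the same table rows, which is why keeping the LUT resident in L2 matters.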

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer
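
The two-phase selection above (per-thread heaps, then a single merge) can be modelled on the CPU with `heapq`. This is a sketch under assumptions: the strided assignment of scores to threads and the function names are hypothetical, and the real kernel keeps its heap in registers rather than in a Python list. Storing `(score, -index)` makes the lower index win when scores are equal.

```python
import heapq


def partial_topk(scores, start, stride, k):
    """Simulate one thread: keep the k best of scores[start::stride]."""
    heap = []  # min-heap of (score, -index); heap[0] is the weakest kept item
    for i in range(start, len(scores), stride):
        item = (scores[i], -i)
        if len(heap) < k:
            heapq.heappush(heap, item)
        elif item > heap[0]:
            heapq.heapreplace(heap, item)  # evict the weakest item
    return heap


def topk_select(scores, k=6, num_threads=64):
    """Merge all partial heaps and return [(index, score)] in descending order."""
    candidates = []
    for t in range(num_threads):
        candidates.extend(partial_topk(scores, t, num_threads, k))
    best = heapq.nlargest(k, candidates)  # ties: larger -index, i.e. lower index
    return [(-neg_i, s) for s, neg_i in best]
```

For example, `topk_select([1.0, 3.0, 3.0, 2.0, 3.0], k=2)` keeps indices 1 and 2 and drops index 4, even though all three share the score 3.0.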

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store
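
A scalar FP32-style reference for this step might look like the sketch below. It assumes that selection uses the biased score while the stored weights are the unbiased activations renormalized to sum to 1, and that the output is ordered by biased score; the commit message implies this but does not spell it out. `route_token` is a hypothetical name.

```python
import math


def softplus(x):
    """Stable softplus: max(x, 0) + log1p(exp(-|x|)); exp never overflows."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def route_token(logits, bias, k=6):
    """Return (weights, expert_indices) for one token."""
    act = [math.sqrt(softplus(x)) for x in logits]  # unbiased activation
    biased = [a + b for a, b in zip(act, bias)]     # used for selection only
    # Descending biased score; lower index wins ties.
    order = sorted(range(len(logits)), key=lambda i: (-biased[i], i))[:k]
    total = sum(act[i] for i in order)
    weights = [act[i] / total for i in order]       # renormalize unbiased values
    return weights, order
```

The naive form `log(1 + exp(x))` overflows for large positive `x`; the rewritten form only ever exponentiates a non-positive number.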

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path
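
The dispatch rule amounts to a single threshold test. A minimal sketch, with the hypothetical names `DECODE_MAX_TOKENS` and `select_dense_path` (the actual logic lives in `dense_router_dispatch`, whose body is not shown here):

```python
DECODE_MAX_TOKENS = 64  # threshold taken from the commit message


def select_dense_path(num_tokens: int) -> str:
    """Return which dense-router path would handle a batch of this size."""
    return "decode" if num_tokens <= DECODE_MAX_TOKENS else "prefill"
```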

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00

90 lines
2.5 KiB
Python

"""torch.library.custom_op wrappers and dispatch for the Router kernels.
Mirrors the pattern in dsv4/ops/custom_ops.py:
- Routers are registered into an integer-keyed table.
- The custom_op takes the integer ID and tensor args only.
- Dynamo can't trace through the kernel; the op is opaque.
"""
import torch

from dsv4.kernels.router import (
    dense_router_dispatch,  # picks decode vs prefill internally
    hash_router_dispatch,
)

_next_router_id = 0
_router_registry: dict[int, object] = {}


def register_router(router) -> int:
    """Store `router` in the registry and return its integer ID."""
    global _next_router_id
    rid = _next_router_id
    _next_router_id += 1
    _router_registry[rid] = router
    return rid


def get_router(rid: int):
    """Look up a previously registered router by its integer ID."""
    return _router_registry[rid]


def warmup_router_compilation(router) -> None:
    """Trigger eager JIT compilation for the router's kernel path.

    Runs a single-token dummy forward so the kernel is compiled before the
    first real call. For a dense router this exercises only the decode path;
    the prefill path is not warmed up here. Caller already has the buffers
    allocated.
    """
    if router.mode == "dense":
        # Dummy forward at small N triggers decode-path compile.
        dummy = torch.zeros(
            1, router.hidden_size,
            dtype=torch.bfloat16, device=router.device,
        )
        router._run_dense_impl(dummy)
    else:
        # Hash mode takes token IDs, so a single int32 ID is enough.
        dummy = torch.zeros(1, dtype=torch.int32, device=router.device)
        router._run_hash_impl(dummy)


# ----- Dense router custom op -----

@torch.library.custom_op("dsv4::dense_router", mutates_args=())
def dense_router_op(
    hidden_states: torch.Tensor,
    router_id: int,
    num_experts: int,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # num_experts and top_k are not used here; the fake impl below needs
    # top_k to compute the output shapes.
    router = get_router(router_id)
    return router._run_dense_impl(hidden_states)


@dense_router_op.register_fake
def _(hidden_states, router_id, num_experts, top_k):
    # Shape/dtype-only stand-in used during tracing: two [N, top_k] tensors,
    # FP32 and int32.
    N = hidden_states.shape[0]
    device = hidden_states.device
    return (
        torch.empty(N, top_k, dtype=torch.float32, device=device),
        torch.empty(N, top_k, dtype=torch.int32, device=device),
    )


# ----- Hash router custom op -----

@torch.library.custom_op("dsv4::hash_router", mutates_args=())
def hash_router_op(
    token_ids: torch.Tensor,
    router_id: int,
    top_k: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    # top_k is not used here; the fake impl below needs it for output shapes.
    router = get_router(router_id)
    return router._run_hash_impl(token_ids)


@hash_router_op.register_fake
def _(token_ids, router_id, top_k):
    # Same output layout as the dense router: two [N, top_k] tensors,
    # FP32 and int32.
    N = token_ids.shape[0]
    device = token_ids.device
    return (
        torch.empty(N, top_k, dtype=torch.float32, device=device),
        torch.empty(N, top_k, dtype=torch.int32, device=device),
    )