nvfp4-megamoe-kernel/dsv4/kernels/cuda/_hash_router.py

Router: full kernel stack (hash, topk, activation+topk, dense decode/prefill)

Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu (general top-k primitive)
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer

Step 3: activation_topk.cu (fused sqrt(softplus) + bias + topk + renorm)
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py (CuTeDSL fused GEMM kernel, skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path

Step 5: dense_router_prefill.py (prefill path)
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN: Carmine is testing Stage C)

Test strategy: each kernel is tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00
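The Step 3 router math in the commit message can be written out as a small plain-Python oracle, in the spirit of the "test against the math" strategy. This is only a sketch under stated assumptions: the function names `stable_softplus` and `activation_topk_reference` are hypothetical, and the commit message does not spell out how bias and renormalization interact, so the sketch assumes the bias influences selection only and the output weights are the unbiased activations divided by their sum over the selected experts.

```python
import math


def stable_softplus(x: float) -> float:
    """softplus(x) = log(1 + exp(x)), computed as max(x, 0) + log1p(exp(-|x|))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def activation_topk_reference(logits, bias, k):
    """Hypothetical oracle for one token: returns (expert_ids, weights).

    Assumptions (not confirmed by the source):
      * selection uses the biased score sqrt(softplus(x)) + bias
      * weights use the unbiased activation, renormalized over the top-k
    """
    acts = [math.sqrt(stable_softplus(x)) for x in logits]
    scores = [a + b for a, b in zip(acts, bias)]
    # Descending biased score; on equal scores the lower index wins.
    order = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]
    total = sum(acts[i] for i in order)
    weights = [acts[i] / total for i in order]
    return order, weights
```

Sorting on the key `(-score, index)` reproduces the "lower index wins" tie-break from Step 2 without a heap, which is adequate for an oracle but not representative of the kernel's per-thread heap and shared-memory merge.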
"""Python wrapper for the hash_router CUDA kernel.

Lazy-loads the hash_router extension (same pattern as dsv4/ops/topk.py).
"""
import os

import torch

# Cached extension module; populated on the first call to _get_kernel_module().
_kernel_module = None


def _get_kernel_module():
    """Lazy-load the hash_router CUDA extension."""
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module
    # Deferred import: the JIT build machinery is only needed on first use.
    from torch.utils.cpp_extension import load

    # hash_router.cu lives next to this file.
    kernel_dir = os.path.dirname(os.path.abspath(__file__))
    _kernel_module = load(
        name="hash_router",
        sources=[os.path.join(kernel_dir, "hash_router.cu")],
        # Code is generated for sm_100a only.
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module


def run_hash_router(
    token_ids: torch.Tensor,    # [N] int32
    hash_lut: torch.Tensor,     # [vocab_size, k] int32
    top_k: int,
    out_weights: torch.Tensor,  # [N, k] float32, pre-allocated
    out_ids: torch.Tensor,      # [N, k] int32, pre-allocated
):
    """Run the hash router kernel: gather expert IDs from LUT, write 1/k weights."""
    mod = _get_kernel_module()
    return mod.hash_router(token_ids, hash_lut, top_k, out_weights, out_ids)
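
For checking the kernel against its spec, the behavior described in the docstring (gather expert IDs from the LUT, write uniform 1/k weights) can be modeled without CUDA. The following is a minimal list-based sketch, not the kernel's implementation: `hash_router_reference` is a hypothetical name, it returns new lists rather than filling pre-allocated tensors, and it assumes each LUT row holds at least `top_k` expert IDs.

```python
def hash_router_reference(token_ids, hash_lut, top_k):
    """Hypothetical pure-Python model of the hash router.

    token_ids: sequence of N token IDs
    hash_lut:  vocab_size rows, each a sequence of expert IDs
    Returns (out_weights, out_ids), both shaped [N][top_k].
    """
    uniform = 1.0 / top_k
    # One row lookup per token, mirroring the one-thread-per-token gather.
    out_ids = [list(hash_lut[t][:top_k]) for t in token_ids]
    out_weights = [[uniform] * top_k for _ in token_ids]
    return out_weights, out_ids
```

Comparing this against the tensors written by `run_hash_router` (after moving them to the CPU and converting with `.tolist()`) gives an exact check on the IDs, since the gather involves no floating-point arithmetic.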