Step 1: Hash router (hash_router.cu)
- One thread per token, gathering from a [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- The ~3 MB LUT fits in L2 across repeated decode calls

Step 2: General top-k primitive (topk_select.cu)
- Per-thread register min-heap (k=6, unrolled at compile time)
- Shared-memory merge: thread 0 merges the 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by the CSA indexer

Step 3: Fused sqrt(softplus) + bias + top-k + renorm (activation_topk.cu)
- Single kernel covering all 6 steps of the router math, with no intermediate buffers
- Numerically stable softplus: max(x, 0) + log1p(exp(-|x|))
- Per-thread heap that co-stores the unbiased activation with each entry
- Shared-memory merge → sort descending → renormalize → store

Step 4: CuTeDSL fused GEMM kernel, skeleton only (dense_router_decode.py)
- BF16 GEMM with tcgen05.mma and an FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined; still needs TMA/MMA boilerplate)
- Dispatch: N <= 64 uses the fused decode kernel, N > 64 uses the prefill path

Step 5: Prefill path (dense_router_prefill.py)
- torch.nn.functional.linear for the GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode selection (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN: Carmine is testing Stage C)

Test strategy: each kernel is tested against its mathematical spec in FP32. There is no reference implementation, so there is no second code path to debug. The oracle IS the math.
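Step 1's routing rule is simple enough to state as an executable spec, in the spirit of the test strategy above. The sketch below is a hedged, pure-Python illustration, not the kernel: `hash_router_reference` and `lut_bytes` are hypothetical names, and the int32 LUT entry width is an assumption (the commit only gives the ~3 MB total).

```python
def hash_router_reference(token_ids, lut, k):
    """Hypothetical spec for the hash router: lut has shape [vocab_size][k]."""
    # Gather: each token reads its own row of k expert ids from the LUT.
    indices = [list(lut[t]) for t in token_ids]
    # Uniform routing weights. Python floats are double precision, while the
    # kernel writes FP32, so comparisons against it need a tolerance.
    weights = [[1.0 / k] * k for _ in token_ids]
    return indices, weights


def lut_bytes(vocab_size, k, bytes_per_entry=4):
    # ASSUMPTION: one int32 expert id per LUT entry (4 bytes).
    return vocab_size * k * bytes_per_entry
```

For instance, `lut_bytes(129_280, 6)` is 3,102,720 bytes (about 3.1 MB), which matches the 3 MB figure if the vocabulary is around 129K tokens; that vocabulary size is an assumption used only for illustration.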
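The test strategy treats the math itself as the oracle. As a hedged illustration, the sketch below writes Step 3's post-GEMM router math for a single token in plain Python: stable softplus, square root, bias added for selection, top-k with the lower-index tie-break, then renormalization. Based on the note that the heap co-stores the unbiased activation, it assumes the bias only affects which experts are selected and that the output weights are the unbiased activations renormalized to sum to 1; any additional scaling in the real router is not modeled. `softplus` and `dense_router_reference` are illustrative names.

```python
import math


def softplus(x):
    # Numerically stable softplus, as in Step 3: max(x, 0) + log1p(exp(-|x|)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dense_router_reference(logits, bias, k=6):
    """Hypothetical per-token spec of the post-GEMM router math.

    logits, bias: length-E lists. Returns (expert_indices, weights).
    """
    # Activation: sqrt(softplus(logit)) per expert (unbiased).
    act = [math.sqrt(softplus(x)) for x in logits]
    # Selection score: activation plus per-expert bias.
    scores = [a + b for a, b in zip(act, bias)]
    # Top-k by descending score; the lower expert index wins on ties.
    idx = sorted(range(len(scores)), key=lambda e: (-scores[e], e))[:k]
    # ASSUMPTION: weights come from the unbiased activations, renormalized to
    # sum to 1 (assumes their sum is nonzero).
    selected = [act[e] for e in idx]
    total = sum(selected)
    weights = [a / total for a in selected]
    return idx, weights
```

The `max(x, 0) + log1p(exp(-|x|))` form matters at the extremes: `softplus(1000.0)` returns 1000.0, while the naive `math.log(1 + math.exp(1000.0))` raises `OverflowError`.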
```python
"""Python wrapper for the topk_select CUDA kernel.

Lazy-loads the topk_select extension (same pattern as dsv4/ops/topk.py).
This is the general top-k primitive reused by the router and the CSA indexer.
"""

import os

import torch

_kernel_module = None


def _get_kernel_module():
    """Lazy-load (JIT-compile on first use) the topk_select CUDA extension."""
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module

    from torch.utils.cpp_extension import load

    kernel_dir = os.path.join(os.path.dirname(__file__), "kernels", "cuda")
    _kernel_module = load(
        name="topk_select",
        sources=[os.path.join(kernel_dir, "topk_select.cu")],
        # Build for sm_100a only.
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module


def topk_select(
    scores: torch.Tensor,  # [num_rows, E] float32, row-major contiguous
    k: int,  # number to select (currently only k=6 supported)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the top-k values and indices from each row of scores.

    Returns (values, indices), where:
        values:  [num_rows, k] float32, top-k scores in descending order
        indices: [num_rows, k] int32, top-k indices (0-based; lower index wins on ties)

    One block per row, 64 threads per block, per-thread register min-heap
    with shared-memory merge. O(E * log k) work per row.
    """
    # Enforce the documented contract before calling into the kernel.
    if k != 6:
        raise ValueError(f"topk_select currently supports only k=6, got k={k}")
    if scores.dim() != 2 or scores.dtype != torch.float32 or not scores.is_contiguous():
        raise ValueError("scores must be a contiguous 2-D float32 tensor [num_rows, E]")
    mod = _get_kernel_module()
    return mod.topk_select(scores, k)
```
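The docstring's algorithm (one block per row, 64 threads, per-thread min-heaps, then a single merge) can be modeled on the CPU to see why the merge preserves the lower-index tie-break. The sketch below is a hypothetical pure-Python model, not the CUDA kernel; `simulated_topk_select` and its `num_threads` parameter are illustrative names, and it assumes thread t scans the strided elements t, t + 64, t + 128, ... of the row, a common layout that the source does not state.

```python
import heapq


def simulated_topk_select(row, k, num_threads=64):
    """Hypothetical CPU model of topk_select for one row (not the CUDA kernel).

    Returns (values, indices): values in descending order, lower index first
    on equal scores.
    """

    def push(heap, score, idx):
        # Key (score, -idx): the heap root is the current worst candidate,
        # i.e. the lowest score and, among equal scores, the higher index.
        key = (score, -idx)
        if len(heap) < k:
            heapq.heappush(heap, key)
        elif key > heap[0]:
            # Evict the worst candidate and insert the new one in one step.
            heapq.heapreplace(heap, key)

    # Phase 1 (per-thread heaps): thread t scans elements t, t + num_threads, ...
    partial_heaps = []
    for t in range(num_threads):
        heap = []
        for i in range(t, len(row), num_threads):
            push(heap, row[i], i)
        partial_heaps.append(heap)

    # Phase 2 (merge): a single "thread 0" folds every partial heap into one heap.
    final = []
    for heap in partial_heaps:
        for score, neg_idx in heap:
            push(final, score, -neg_idx)

    # Descending by score; reverse order of -idx puts lower indices first on ties.
    ordered = sorted(final, reverse=True)
    values = [score for score, _ in ordered]
    indices = [-neg_idx for _, neg_idx in ordered]
    return values, indices
```

Because every candidate is ranked by the total order (score, -index), the top-k of the union of the per-thread top-k sets equals the top-k of the whole row. The result is therefore independent of how elements are distributed across threads, which is what keeps the tie-break deterministic.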