nvfp4-megamoe-kernel/dsv4/ops/topk_select.py
biondizzle abfe4485f7 Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One thread per token, gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls
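A minimal plain-Python sketch of the hash-router spec, usable as an FP32 oracle. It assumes the LUT is indexed by token id and holds `k` expert ids per row. The function name `hash_route` is hypothetical and is not the kernel's interface.

```python
from typing import List, Tuple


def hash_route(token_ids: List[int],
               lut: List[List[int]]) -> Tuple[List[List[int]], List[List[float]]]:
    """Return (expert_ids, weights) per token: a LUT row gather plus uniform 1/k weights."""
    expert_ids: List[List[int]] = []
    weights: List[List[float]] = []
    for tok in token_ids:              # the CUDA kernel assigns one thread per token
        row = lut[tok]                 # gather the k expert ids for this token
        k = len(row)
        expert_ids.append(list(row))
        weights.append([1.0 / k] * k)  # uniform weights, independent of the hidden state
    return expert_ids, weights
```

For example, with `lut = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]`, `hash_route([2, 0], lut)` gathers rows 2 and 0 and gives every selected expert a weight of 1/3.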

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer
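The two-level selection in Step 2 can be modeled in plain Python as below. This is a sketch of the data flow, not the kernel. It assumes each of the 64 threads owns a strided slice of the row (columns t, t+64, ...), and `topk_select_sim` is a hypothetical name. Ordering entries by the key `(score, -index)` makes a larger key strictly better, so the min-heap root is always the entry to evict, and equal scores resolve to the lower index.

```python
import heapq
from typing import List, Tuple

NUM_THREADS = 64  # threads per block, as in the topk_select wrapper docstring


def _push(heap: List[Tuple[float, int]], key: Tuple[float, int], k: int) -> None:
    """Keep the k best keys; the min-heap root is the worst key currently kept."""
    if len(heap) < k:
        heapq.heappush(heap, key)
    elif key > heap[0]:
        heapq.heapreplace(heap, key)  # evict the worst, insert the better candidate


def topk_select_sim(row: List[float], k: int) -> Tuple[List[float], List[int]]:
    # Phase 1: each "thread" builds a local top-k heap over its strided columns.
    partial = []
    for t in range(NUM_THREADS):
        heap: List[Tuple[float, int]] = []
        for i in range(t, len(row), NUM_THREADS):  # strided ownership (assumption)
            _push(heap, (row[i], -i), k)
        partial.append(heap)
    # Phase 2: "thread 0" merges all partial heaps into one k-entry heap.
    merged: List[Tuple[float, int]] = []
    for heap in partial:
        for key in heap:
            _push(merged, key, k)
    best = sorted(merged, reverse=True)  # descending score, ascending index on ties
    return [s for s, _ in best], [-neg_i for _, neg_i in best]
```

Because the key order is strict and every element of the global top-k is also in its owner's local top-k, merging only the 64 * k partial results yields the exact answer.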

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store
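The fused kernel's math fits in a few lines as an FP32 spec. This sketch assumes the bias shifts only the selection score, and that the returned weights are the unbiased activations of the chosen experts divided by their sum, with no extra scaling factor. `activation_topk_spec` is a hypothetical name.

```python
import math
from typing import List, Tuple


def softplus(x: float) -> float:
    # Numerically stable form from the commit message: max(x, 0) + log1p(exp(-|x|)).
    # exp(-|x|) lies in (0, 1], so it never overflows.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def activation_topk_spec(logits: List[float], bias: List[float],
                         k: int) -> Tuple[List[int], List[float]]:
    act = [math.sqrt(softplus(z)) for z in logits]  # unbiased activation, always > 0
    biased = [a + b for a, b in zip(act, bias)]     # bias used for selection only (assumption)
    # Descending biased score; the lower expert index wins on ties.
    chosen = sorted(range(len(logits)), key=lambda e: (-biased[e], e))[:k]
    total = sum(act[e] for e in chosen)             # renormalize the unbiased activations
    return chosen, [act[e] / total for e in chosen]
```

A large bias can pull an expert into the top-k even when its logit is low, but its weight still comes from its unbiased activation.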

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path
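A sketch of the Step 4 dispatch rule. `select_dense_router_path` and `DECODE_MAX_TOKENS` are hypothetical names; only the threshold of 64 tokens comes from the commit message.

```python
DECODE_MAX_TOKENS = 64  # threshold from the commit message; the constant name is illustrative


def select_dense_router_path(num_tokens: int) -> str:
    """Return which dense-router path handles a batch of num_tokens tokens."""
    return "decode" if num_tokens <= DECODE_MAX_TOKENS else "prefill"
```

A plausible rationale, not stated in the commit: at small N the router GEMM is dominated by reading the weight matrix, so fusing the epilogue avoids a round trip of the logits through global memory, while large N gains more from a standalone GEMM.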

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing
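The prefill GEMM is `torch.nn.functional.linear(x, weight)` without a bias, so the router weight follows F.linear's `(out_features, in_features)` layout, `[E, hidden]`, and `logits[n][e]` is the dot product of `x[n]` with `weight[e]`. The plain-Python reference below spells this out; `router_logits_reference` is a hypothetical helper, and each row of its output would then be passed to activation_topk.

```python
from typing import List


def router_logits_reference(x: List[List[float]],
                            weight: List[List[float]]) -> List[List[float]]:
    """FP32 reference for F.linear(x, weight) with no bias.

    x: [N, hidden] token activations; weight: [E, hidden] router weight.
    Returns logits of shape [N, E].
    """
    return [[sum(xh * wh for xh, wh in zip(row, w_e)) for w_e in weight] for row in x]
```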

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00

"""Python wrapper for the topk_select CUDA kernel.

Lazy-loads the topk_select extension (same pattern as dsv4/ops/topk.py).
This is the general top-k primitive reused by the router and the CSA indexer.
"""

import os

import torch

_kernel_module = None


def _get_kernel_module():
    """Lazy-load the topk_select CUDA extension."""
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module
    from torch.utils.cpp_extension import load

    kernel_dir = os.path.join(os.path.dirname(__file__), "kernels", "cuda")
    _kernel_module = load(
        name="topk_select",
        sources=[os.path.join(kernel_dir, "topk_select.cu")],
        # Compile SASS for sm_100a only.
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module


def topk_select(
    scores: torch.Tensor,  # [num_rows, E] float32, row-major contiguous
    k: int,  # number to select (currently only k=6 supported)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the top-k values and indices from each row of scores.

    Returns (values, indices) where:
        values:  [num_rows, k] float32, top-k scores in descending order
        indices: [num_rows, k] int32, top-k column indices (0-based; lower index wins on ties)

    One block per row, 64 threads per block, per-thread register min-heap
    with shared-memory merge. O(E log k) work per row.
    """
    mod = _get_kernel_module()
    return mod.topk_select(scores, k)