Files
nvfp4-megamoe-kernel/dsv4/kernels/cuda/hash_router.cu
biondizzle fb243a4133 Router: full kernel stack — hash, topk, activation+topk, dense decode/prefill
Step 1: Hash router (hash_router.cu)
- One block per token (one thread per expert slot), gather from [vocab_size, k] LUT
- Uniform 1/k weights, FP32 output
- 3 MB LUT fits in L2 for repeated decode calls

Step 2: topk_select.cu — general top-k primitive
- Per-thread register min-heap (k=6, compile-time unrolled)
- Shared memory merge: thread 0 merges 64 partial heaps
- Tie-breaking: lower index wins on equal scores
- Reusable by CSA indexer

Step 3: activation_topk.cu — fused sqrt(softplus) + bias + topk + renorm
- Single kernel: all 6 steps of the router math, no intermediate buffers
- Numerically stable softplus: max(x,0) + log1p(exp(-|x|))
- Per-thread heap with unbiased activation co-stored
- Shared memory merge → sort descending → renormalize → store

Step 4: dense_router_decode.py — CuTeDSL fused GEMM kernel (skeleton)
- BF16 GEMM with tcgen05.mma, FP32 accumulator
- Custom epilogue: activation + bias + top-k (structure defined, needs TMA/MMA boilerplate)
- Dispatch: N<=64 uses fused decode, N>64 uses prefill path

Step 5: dense_router_prefill.py — prefill path
- torch.nn.functional.linear for GEMM (DeepGEMM integration deferred)
- Calls activation_topk for fused post-GEMM processing

Step 6: Router class + ops/router.py + test_router.py
- Router: construction-time mode (dense/hash), weight loading, custom_op dispatch
- ops/router.py: torch.library.custom_op wrappers, integer-keyed registry
- test_router.py: spec oracle tests (DO NOT RUN — Carmine is testing Stage C)

Test strategy: each kernel tested against its mathematical spec in FP32.
No reference implementation, no two debug streams. The oracle IS the math.
2026-05-21 21:54:05 +00:00

114 lines
4.5 KiB
Plaintext

/**
* Hash router kernel for DeepSeek-V4.
*
* Deterministic per-token-ID expert assignment via precomputed lookup table.
* Used by the first 3 MoE layers (§2.1) — no gate GEMM, no hidden-state input.
*
* One block per token, one thread per expert slot. Each thread gathers one
* expert ID from the LUT and writes the uniform 1/k weight. Bandwidth-bound:
* the 3 MB LUT fits in L2 for repeated decode calls.
*
* Launch: grid(N), block(k) — k threads per token cooperate on the gather.
* - Thread h in each block handles the h-th expert slot for that token.
* - No shared memory needed.
*
* Tie-breaking: not applicable (LUT is deterministic). If a token's LUT row
* repeats an expert ID across different h slots, the duplicate is passed
* through unchanged — that is a property of the checkpoint, not this kernel.
*
* Bounds: vocab_size <= 256K, k <= 16, num_experts <= 384. All fit int32.
*/
#include <cstdint>
#include <limits>
#include <tuple>

#include <cuda.h>
#include <cuda_runtime.h>

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
/**
 * Gather the k expert IDs of each token from the hash LUT and write uniform
 * 1/k routing weights.
 *
 * Expected launch: gridDim.x == N (block b handles token row b) and
 * blockDim.x >= k (thread h handles expert slot h; threads with h >= k exit).
 * No shared memory, no inter-thread communication.
 *
 * For token row n with id tid = token_ids[n]:
 *   out_ids[n, h]     = hash_lut[tid, h]
 *   out_weights[n, h] = 1.0f / k
 * If tid is outside [0, vocab_size) the LUT is NOT read and the slot is
 * filled with expert 0 and weight 0.0f. Nothing is logged or reported to the
 * host, so such tokens are only detectable by inspecting the outputs.
 *
 * Caller preconditions (not checked here): token_ids is contiguous; every row
 * of hash_lut / out_weights / out_ids is addressed as base + row * stride + h,
 * i.e. unit inner stride, and has at least k columns.
 */
__global__ void hash_router_kernel(
    // Inputs
    const int32_t* __restrict__ token_ids,   // [N] — token indices into LUT
    const int32_t* __restrict__ hash_lut,    // [vocab_size, >= k] — expert IDs
    int64_t lut_stride,                      // row stride of hash_lut, in elements
    int32_t k,                               // experts per token (6)
    int32_t vocab_size,                      // LUT row count
    // Outputs
    float* __restrict__ out_weights,         // [N, >= k] — 1/k uniform
    int64_t out_weights_stride,              // row stride, in elements
    int32_t* __restrict__ out_ids,           // [N, >= k] — expert IDs
    int64_t out_ids_stride                   // row stride, in elements
) {
    // blockIdx.x is always < gridDim.x, so the token count is bounded solely by
    // the grid size chosen on the host. Only the slot index needs a guard,
    // because blockDim.x may be larger than k.
    const int32_t h = static_cast<int32_t>(threadIdx.x);
    if (h >= k) return;

    const int64_t n = static_cast<int64_t>(blockIdx.x);
    int32_t* ids_row = out_ids + n * out_ids_stride;
    float* weights_row = out_weights + n * out_weights_stride;

    const int32_t tid = token_ids[n];

    // Out-of-vocabulary ID: never index the LUT with it (that would be an
    // out-of-bounds read). Emit expert 0 with zero weight instead.
    if (tid < 0 || tid >= vocab_size) {
        ids_row[h] = 0;
        weights_row[h] = 0.0f;
        return;
    }

    ids_row[h] = hash_lut[static_cast<int64_t>(tid) * lut_stride + h];

    // Uniform FP32 weight; the k slots already sum to 1, so no
    // renormalization is done here. Any routed scaling factor is expected to
    // be applied by the caller (presumably the Router layer — confirm).
    weights_row[h] = 1.0f / static_cast<float>(k);
}
/**
 * Validate arguments and launch hash_router_kernel on the current PyTorch
 * CUDA stream of token_ids' device.
 *
 * Args:
 *   token_ids   [N] int32 CUDA tensor. A non-contiguous input is copied to a
 *               contiguous temporary, because the kernel reads token_ids[n].
 *   hash_lut    [vocab_size, C] int32 CUDA tensor with C >= k and unit stride
 *               along dim 1; only the first k columns are read. The row stride
 *               may be arbitrary.
 *   k           experts per token, 1 <= k <= 1024 (one thread per slot; 1024
 *               is the CUDA per-block thread limit).
 *   out_weights [>= N, >= k] float32 CUDA tensor with unit stride along dim 1;
 *               rows 0..N-1, columns 0..k-1 are overwritten.
 *   out_ids     [>= N, >= k] int32 CUDA tensor with unit stride along dim 1;
 *               the same region is overwritten.
 *
 * Returns the (out_weights, out_ids) tensors that were passed in. The launch
 * is asynchronous and ordered after prior work on the current stream.
 *
 * Throws c10::Error (RuntimeError in Python) on any invalid argument, or if
 * the kernel launch fails.
 */
std::tuple<torch::Tensor, torch::Tensor> hash_router_cuda(
    torch::Tensor token_ids,    // [N] int32
    torch::Tensor hash_lut,     // [vocab_size, >= k] int32
    int64_t k,
    torch::Tensor out_weights,  // [N, k] float32, pre-allocated
    torch::Tensor out_ids       // [N, k] int32, pre-allocated
) {
    TORCH_CHECK(token_ids.scalar_type() == torch::kInt32, "token_ids must be int32");
    TORCH_CHECK(hash_lut.scalar_type() == torch::kInt32, "hash_lut must be int32");
    TORCH_CHECK(out_weights.scalar_type() == torch::kFloat32, "out_weights must be float32");
    TORCH_CHECK(out_ids.scalar_type() == torch::kInt32, "out_ids must be int32");
    TORCH_CHECK(token_ids.dim() == 1, "token_ids must be 1-D [N], got ", token_ids.dim(), "-D");
    TORCH_CHECK(hash_lut.dim() == 2, "hash_lut must be 2-D [vocab_size, k], got ", hash_lut.dim(), "-D");

    const int64_t N = token_ids.size(0);
    const int64_t vocab_size = hash_lut.size(0);

    TORCH_CHECK(out_weights.size(0) >= N, "out_weights too small");
    TORCH_CHECK(out_ids.size(0) >= N, "out_ids too small");
    if (N == 0) return std::make_tuple(out_weights, out_ids);

    // Shape / layout: the kernel addresses every row as base + row * stride(0) + h
    // for h < k, so each 2-D tensor needs >= k columns and unit inner stride.
    TORCH_CHECK(out_weights.dim() == 2, "out_weights must be 2-D [N, k]");
    TORCH_CHECK(out_ids.dim() == 2, "out_ids must be 2-D [N, k]");
    TORCH_CHECK(k >= 1 && k <= 1024, "k must be in [1, 1024], got ", k);
    TORCH_CHECK(hash_lut.size(1) >= k,
                "hash_lut has ", hash_lut.size(1), " columns, need k = ", k);
    TORCH_CHECK(out_weights.size(1) >= k,
                "out_weights has ", out_weights.size(1), " columns, need k = ", k);
    TORCH_CHECK(out_ids.size(1) >= k,
                "out_ids has ", out_ids.size(1), " columns, need k = ", k);
    auto has_unit_inner_stride = [](const torch::Tensor& t) {
        return t.size(1) <= 1 || t.stride(1) == 1;
    };
    TORCH_CHECK(has_unit_inner_stride(hash_lut), "hash_lut must have unit stride along dim 1");
    TORCH_CHECK(has_unit_inner_stride(out_weights), "out_weights must have unit stride along dim 1");
    TORCH_CHECK(has_unit_inner_stride(out_ids), "out_ids must have unit stride along dim 1");

    // The token count becomes gridDim.x (max 2^31 - 1) and vocab_size is passed
    // to the kernel as int32.
    constexpr int64_t kInt32Max = std::numeric_limits<int32_t>::max();
    TORCH_CHECK(N <= kInt32Max, "N = ", N, " exceeds the maximum 1-D grid size");
    TORCH_CHECK(vocab_size <= kInt32Max, "vocab_size = ", vocab_size, " does not fit int32");

    // Device placement: raw pointers are handed to the kernel, so every tensor
    // must live on the same CUDA device.
    TORCH_CHECK(token_ids.is_cuda(), "token_ids must be a CUDA tensor");
    const c10::Device device = token_ids.device();
    TORCH_CHECK(hash_lut.device() == device, "hash_lut must be on the same device as token_ids");
    TORCH_CHECK(out_weights.device() == device, "out_weights must be on the same device as token_ids");
    TORCH_CHECK(out_ids.device() == device, "out_ids must be on the same device as token_ids");

    // Make token_ids' device current, then launch on PyTorch's current stream
    // for that device (not the legacy default stream). This preserves ordering
    // with surrounding PyTorch work and keeps CUDA-graph capture valid.
    const c10::cuda::CUDAGuard device_guard(device);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    const torch::Tensor ids_in = token_ids.contiguous();  // no-op if already contiguous

    // One block per token, k threads per block. With k = 6 each block is a
    // single warp with only 6 active lanes. That is simple and fine for
    // decode-sized N, but lane utilization is low; flattened token x slot
    // indexing may be worth it for large prefill N (TODO: profile).
    const dim3 grid(static_cast<uint32_t>(N));
    const dim3 block(static_cast<uint32_t>(k));
    hash_router_kernel<<<grid, block, 0, stream>>>(
        ids_in.data_ptr<int32_t>(),
        hash_lut.data_ptr<int32_t>(),
        hash_lut.stride(0),
        static_cast<int32_t>(k),
        static_cast<int32_t>(vocab_size),
        out_weights.data_ptr<float>(),
        out_weights.stride(0),
        out_ids.data_ptr<int32_t>(),
        out_ids.stride(0)
    );
    // Surface launch-configuration errors here instead of at a later sync.
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    return std::make_tuple(out_weights, out_ids);
}
// Python binding. The argument names enable keyword calls and help() output;
// positional calls behave exactly as before.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("hash_router", &hash_router_cuda,
          "Hash router: for each token, gather k expert IDs from an int32 LUT "
          "and write uniform 1/k float32 weights into the pre-allocated outputs "
          "(out-of-range token IDs get expert 0, weight 0). "
          "Returns (out_weights, out_ids).",
          pybind11::arg("token_ids"),
          pybind11::arg("hash_lut"),
          pybind11::arg("k"),
          pybind11::arg("out_weights"),
          pybind11::arg("out_ids"));
}