Commit Graph

6 Commits

Author SHA1 Message Date
cb2ca8591f fix: add @cute.jit to router compiled function 2026-05-31 23:44:53 +00:00
d5d2b7b4b8 fix: defer router MMA/TMA setup into cute.compile context (matches MoE pattern) 2026-05-31 23:44:00 +00:00
157f1c5258 fix: use OperandMajorMode from nvgpu (not deprecated tcgen05) and mma_tiler_mn in router kernel 2026-05-31 23:39:50 +00:00
1dbc57e2cd fix: use mma_tiler_mn in _create_tiled_mma (attribute exists at init time) 2026-05-31 23:36:01 +00:00
d05dd50bf5 fix: OperandMajorMode.K not MAJOR_K (correct CuTeDSL API) 2026-05-31 23:34:54 +00:00
0d06e55770 Router: Blackwell-native fused decode kernel — real CuTeDSL implementation
DenseRouterDecodeKernel: BF16 GEMM + sqrt(softplus) + bias + top-k
in a single kernel launch on Blackwell SM100.

Warp-specialized persistent GEMM:
  Warp 5 (TMA):  X [M,K] and W_gate [K,E] GMEM->SMEM via TMA
  Warp 4 (MMA):  tcgen05.mma BF16, FP32 accumulator -> TMEM
  Warps 0-3 (EPI): TMEM->register (tcgen05.ld), activation, top-k, store

Key design decisions:
- No EFC framework: our epilogue is a ROW-LEVEL top-k reduction,
  not a per-element transformation. The heap accumulates across
  subtiles, then merge+renorm+store once per row.
- Per-thread register heap: 6 entries (score, index, unbiased act)
  as CuTeDSL scalars (not Python lists — those dont compile to registers)
- Shared memory merge: 128 threads dump heaps, thread 0 merges final top-6
- Identity tensor for expert index: maps register position -> global e_idx
- Numerically stable softplus: max(x,0) + log(1+exp(-|x|)) in FP32

dense_router_decode.py now dispatches to this kernel for N<=64,
falls back to activation_topk.cu for N>64.

This is a real Blackwell kernel. No pass statements. No fake code.
2026-05-21 22:04:20 +00:00