Commit Graph

15 Commits

Author SHA1 Message Date
ef4c0ad489 Fix BF16 router mma_tiler: use cutlass.Int32 for CuTe DSL compatibility 2026-06-01 07:29:30 +00:00
5591a725e1 fix: router kernel — infer OperandMajorMode from tensor layout (same pattern as MoE GEMM) 2026-06-01 00:59:18 +00:00
c339fe7ad9 fix: router A operand major mode MN (not K) — fixes CuTeDSL local_tile coord error 2026-06-01 00:54:19 +00:00
3b5b9f487c fix: compute num_tma_load_bytes inside cute.compile context 2026-05-31 23:53:13 +00:00
1bc0da0f35 fix: properly scope swap code inside else/guard blocks, replace continue with if guard 2026-05-31 23:51:43 +00:00
d0d765e1f2 fix: replace break statements with flag-based loops in router kernel (CuTeDSL restriction) 2026-05-31 23:50:39 +00:00
210391e571 fix: PersistentTileSchedulerParams constructor takes (problem_shape, cluster_shape) not from_shape 2026-05-31 23:49:12 +00:00
824d054ad7 fix: inside cute.compile args are already CuTe tensors, no conversion needed 2026-05-31 23:47:33 +00:00
6375e54396 fix: use from_dlpack + mark_layout_dynamic instead of non-existent to_cuTe_tensor in router 2026-05-31 23:46:35 +00:00
cb2ca8591f fix: add @cute.jit to router compiled function 2026-05-31 23:44:53 +00:00
d5d2b7b4b8 fix: defer router MMA/TMA setup into cute.compile context (matches MoE pattern) 2026-05-31 23:44:00 +00:00
157f1c5258 fix: use OperandMajorMode from nvgpu (not deprecated tcgen05) and mma_tiler_mn in router kernel 2026-05-31 23:39:50 +00:00
1dbc57e2cd fix: use mma_tiler_mn in _create_tiled_mma (attribute exists at init time) 2026-05-31 23:36:01 +00:00
d05dd50bf5 fix: OperandMajorMode.K not MAJOR_K (correct CuTeDSL API) 2026-05-31 23:34:54 +00:00
0d06e55770 Router: Blackwell-native fused decode kernel — real CuTeDSL implementation
DenseRouterDecodeKernel: BF16 GEMM + sqrt(softplus) + bias + top-k
in a single kernel launch on Blackwell SM100.

Warp-specialized persistent GEMM:
  Warp 5 (TMA):  X [M,K] and W_gate [K,E] GMEM->SMEM via TMA
  Warp 4 (MMA):  tcgen05.mma BF16, FP32 accumulator -> TMEM
  Warps 0-3 (EPI): TMEM->register (tcgen05.ld), activation, top-k, store

Key design decisions:
- No EFC framework: our epilogue is a ROW-LEVEL top-k reduction,
  not a per-element transformation. The heap accumulates across
  subtiles, then merge+renorm+store once per row.
- Per-thread register heap: 6 entries (score, index, unbiased act)
  as CuTeDSL scalars (not Python lists — those dont compile to registers)
- Shared memory merge: 128 threads dump heaps, thread 0 merges final top-6
- Identity tensor for expert index: maps register position -> global e_idx
- Numerically stable softplus: max(x,0) + log(1+exp(-|x|)) in FP32

dense_router_decode.py now dispatches to this kernel for N<=64,
falls back to activation_topk.cu for N>64.

This is a real Blackwell kernel. No pass statements. No fake code.
2026-05-21 22:04:20 +00:00