The first draft had a fake CuTeDSL kernel body with pass statements and
Python lists as register heaps. That is not the right way. This commit
replaces it with honest documentation of what the kernel does and what
needs to happen.

Current working path:
- All N route through torch.nn.functional.linear + activation_topk.cu
- activation_topk is a single-pass fused CUDA kernel (all 6 steps)
- This is correct and performant for all N

CuTeDSL fused decode kernel (DenseRouterDecodeKernel):
- Class structure and warp specialization defined
- Full documentation of the TMA/MMA/epilogue pipeline
- The novel part is the row-level top-k epilogue (cross-subtile heap)
- EFC framework does not apply: our epilogue is not per-element
- Implementation deferred until profiling shows that the GMEM round-trip
  on logits matters for decode latency

No fake code. No pass statements. No Python lists as GPU registers.
The working path is the activation_topk kernel. The CuTeDSL kernel will
be built on top of it when the optimization is needed.
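The commit says the novel part of DenseRouterDecodeKernel is a row-level top-k epilogue with a cross-subtile heap, and that a per-element epilogue framework cannot express it. The pure-Python sketch below illustrates that idea under stated assumptions. Each token's logits are produced one expert subtile at a time, standing in for one MMA tile. Each subtile is activated and folded into a size-k min-heap that persists across subtiles, so no [N, E] logits buffer is ever materialized. The sigmoid activation, the final renormalization, `SUBTILE`, and `route_row_subtiled` are hypothetical. The source does not list the six steps of activation_topk, and this is not that kernel.

```python
import heapq
import math
from typing import List, Sequence, Tuple

SUBTILE = 4  # hypothetical experts per N-subtile (stands in for one MMA tile)


def sigmoid(z: float) -> float:
    # ASSUMPTION: sigmoid scoring, chosen only for illustration.
    return 1.0 / (1.0 + math.exp(-z))


def route_row_subtiled(
    x: Sequence[float],                 # one token's hidden state, length D
    weight: Sequence[Sequence[float]],  # [E, D] router weight, one row per expert
    k: int,
) -> Tuple[List[int], List[float]]:
    """Select top-k experts for one row, merging a k-sized heap across subtiles."""
    heap: List[Tuple[float, int]] = []  # min-heap of (score, expert_id), lives across subtiles
    num_experts = len(weight)
    for start in range(0, num_experts, SUBTILE):
        stop = min(start + SUBTILE, num_experts)
        # "GEMM + activation" for this subtile only; these scores are transient,
        # mirroring an epilogue that never writes logits back to GMEM.
        tile = [
            (sigmoid(sum(xi * wi for xi, wi in zip(x, weight[e]))), e)
            for e in range(start, stop)
        ]
        for score, e in tile:
            if len(heap) < k:
                heapq.heappush(heap, (score, e))
            elif score > heap[0][0]:
                # Evicting the current k-th best needs state from earlier subtiles,
                # which an element-wise epilogue does not carry.
                heapq.heapreplace(heap, (score, e))
    # Order by descending score, then renormalize (ASSUMED final step).
    best = sorted(heap, key=lambda t: (-t[0], t[1]))
    ids = [e for _, e in best]
    total = sum(s for s, _ in best)
    weights = [s / total for s, _ in best]
    return ids, weights


if __name__ == "__main__":
    W = [[0.1, 0.0], [2.0, 0.0], [-1.0, 0.0], [0.5, 0.0], [3.0, 0.0], [0.0, 0.0]]
    ids, w = route_row_subtiled([1.0, 0.0], W, k=2)
    print(ids)  # [4, 1]
```

In the usage example, expert 3 sits in the heap after the first subtile and is evicted by expert 4 in the second. That eviction is the cross-subtile state a per-element epilogue lacks.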
"""DSV4 Router kernels: dispatch and CUDA kernel wrappers.

Exports:
    dense_router_dispatch: GEMM + fused activation + top-k (all N)
    hash_router_dispatch: Hash routing via precomputed LUT gather
"""

from dsv4.kernels.router.dense_router_decode import dense_router_dispatch


def hash_router_dispatch(
    token_ids,    # [N] int32
    hash_lut,     # [vocab_size, k] int32
    top_k,        # k=6
    out_weights,  # [N, k] float32, pre-allocated
    out_ids,      # [N, k] int32, pre-allocated
):
    """Hash router dispatch: gather expert IDs from precomputed LUT.

    Wraps the hash_router CUDA kernel (dsv4/kernels/cuda/hash_router.cu).
    One kernel launch, no intermediate buffers, no CPU-GPU sync.
    """
    from dsv4.kernels.cuda._hash_router import run_hash_router
    return run_hash_router(token_ids, hash_lut, top_k, out_weights, out_ids)
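hash_router_dispatch is described as a plain gather from a precomputed [vocab_size, k] LUT. A CPU reference for those semantics can be handy when testing the CUDA kernel. The sketch below is pure Python, and `hash_router_reference` is a hypothetical helper, not part of this module. The docstring does not say how out_weights is filled, so the sketch assumes uniform 1/k weights purely for illustration. The real rule lives in dsv4/kernels/cuda/hash_router.cu.

```python
from typing import List, Sequence, Tuple


def hash_router_reference(
    token_ids: Sequence[int],           # [N] token ids
    hash_lut: Sequence[Sequence[int]],  # [vocab_size, k] expert ids per token id
    top_k: int,
) -> Tuple[List[List[float]], List[List[int]]]:
    """CPU reference: out_ids[i] = hash_lut[token_ids[i]][:top_k]."""
    out_weights: List[List[float]] = []
    out_ids: List[List[int]] = []
    for tok in token_ids:
        # Pure gather: the expert set depends only on the token id, not on activations.
        row = list(hash_lut[tok][:top_k])
        if len(row) != top_k:
            raise ValueError("hash_lut rows must have at least top_k entries")
        out_ids.append(row)
        # ASSUMPTION: uniform weights; the CUDA kernel may compute these differently.
        out_weights.append([1.0 / top_k] * top_k)
    return out_weights, out_ids
```

Routing depends only on the token id, with no scores to compute or rank. That is consistent with the wrapper's claim of one kernel launch with no intermediate buffers.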