"""DSV4 Dense Router — fused BF16 GEMM + sqrt(softplus) + bias + top-k for decode. Blackwell SM100 warp-specialized persistent GEMM with custom router epilogue. See dense_router_decode_epilogue.py for the epilogue implementation. """ from __future__ import annotations from typing import Tuple, Optional import torch def dense_router_dispatch( hidden_states: torch.Tensor, # [N, hidden_size] BF16 W_gate: torch.Tensor, # [hidden_size, num_experts] BF16 e_bias: torch.Tensor, # [num_experts] FP32 routed_scaling_factor: float, top_k: int, out_weights: torch.Tensor, # [N, top_k] FP32, pre-allocated out_ids: torch.Tensor, # [N, top_k] int32, pre-allocated ): """Dispatch the dense router kernel. For decode (N <= 64): uses the fused CuTeDSL kernel. For prefill (N > 64): uses torch.nn.functional.linear + activation_topk. """ N = hidden_states.shape[0] if N <= 64: try: _run_fused_decode( hidden_states, W_gate, e_bias, routed_scaling_factor, top_k, out_weights, out_ids, ) return except (ImportError, NotImplementedError): pass # fall through to prefill path _run_prefill_path( hidden_states, W_gate, e_bias, routed_scaling_factor, top_k, out_weights, out_ids, ) def _run_prefill_path( hidden_states, W_gate, e_bias, routed_scaling_factor, top_k, out_weights, out_ids, ): """GEMM via torch.nn.functional.linear, then fused activation + top-k.""" logits = torch.nn.functional.linear(hidden_states.float(), W_gate.float()) from dsv4.kernels.router._activation_topk import run_fused_activation_topk run_fused_activation_topk( logits, e_bias, routed_scaling_factor, top_k, out_weights, out_ids, ) def _run_fused_decode( hidden_states, W_gate, e_bias, routed_scaling_factor, top_k, out_weights, out_ids, ): """Run the fused CuTeDSL decode kernel (BF16 GEMM + epilogue in one launch).""" from dsv4.kernels.router.dense_router_decode_kernel import DenseRouterDecodeKernel N = hidden_states.shape[0] E = W_gate.shape[1] K = W_gate.shape[0] kernel = DenseRouterDecodeKernel( mma_tiler_mn=(128, 128), cluster_shape_mn=(1, 1), top_k=top_k, ) kernel.run( hidden_states, W_gate, e_bias, out_weights, out_ids, N, E, K, routed_scaling_factor, top_k, )