"""Python-side entry point for the fused activation + top-k CUDA kernel.

The CUDA extension is compiled/loaded lazily on first use (the same pattern
as dsv4/ops/topk.py).  The public function run_fused_activation_topk() is the
one called by dense_router_dispatch.
"""

import os

import torch

# Process-wide cache for the loaded extension; None until the first request.
_kernel_module = None


def _get_kernel_module():
    """Return the fused_activation_topk extension, building it on first call.

    The loaded module is stored in the module-level ``_kernel_module`` so
    that subsequent calls return the cached object without invoking the
    extension loader again.
    """
    global _kernel_module
    if _kernel_module is None:
        # Imported here rather than at module scope so that merely importing
        # this file does not pull in the extension build machinery.
        from torch.utils.cpp_extension import load

        # The .cu source lives in the sibling "cuda" directory.
        cuda_dir = os.path.join(os.path.dirname(__file__), "..", "cuda")
        cu_source = os.path.join(cuda_dir, "activation_topk.cu")
        nvcc_flags = [
            "-O3",
            "--generate-code=arch=compute_100a,code=[sm_100a]",
        ]
        _kernel_module = load(
            name="fused_activation_topk",
            sources=[cu_source],
            extra_cuda_cflags=nvcc_flags,
            verbose=False,
        )
    return _kernel_module


def run_fused_activation_topk(
    logits: torch.Tensor,       # [N, E] FP32
    e_bias: torch.Tensor,       # [E] FP32
    routed_scaling_factor: float,
    top_k: int,
    out_weights: torch.Tensor,  # [N, top_k] FP32, pre-allocated
    out_ids: torch.Tensor,      # [N, top_k] int32, pre-allocated
):
    """Invoke the fused activation + top-k + renormalization kernel.

    Documented kernel contract:
        act      = sqrt(softplus(logits))
        score    = act + e_bias
        topk_ids = argtopk(score, k=top_k)      (tie-break: lower index wins)
        raw_w    = gather(act, topk_ids)        (unbiased activation)
        topk_w   = raw_w / sum(raw_w) * scaling (renormalized)

    Args:
        logits: Router logits, expected as [N, E] FP32.
        e_bias: Per-expert bias added to the activation for selection only,
            expected as [E] FP32.
        routed_scaling_factor: Scale applied to the renormalized weights;
            coerced to a Python float before being handed to the extension.
        top_k: Number of experts selected per row.
        out_weights: Pre-allocated [N, top_k] FP32 buffer (presumably filled
            in place by the kernel -- confirm against activation_topk.cu).
        out_ids: Pre-allocated [N, top_k] int32 buffer (presumably filled in
            place by the kernel -- confirm against activation_topk.cu).

    Returns:
        Whatever the extension's ``fused_activation_topk`` function returns.
    """
    # Resolve the extension entry point first; this triggers the lazy build
    # on the very first call.
    kernel_fn = _get_kernel_module().fused_activation_topk
    scaling = float(routed_scaling_factor)
    return kernel_fn(logits, e_bias, scaling, top_k, out_weights, out_ids)