39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
|
|
"""Python wrapper for the hash_router CUDA kernel.

Lazy-loads the hash_router extension (same pattern as dsv4/ops/topk.py).
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# Cached extension module; stays None until the first successful load.
_kernel_module = None


def _get_kernel_module():
    """Lazy-load the hash_router CUDA extension.

    On the first call the extension is JIT-compiled from ``hash_router.cu``
    (located in the same directory as this file) with
    ``torch.utils.cpp_extension.load`` and stored in the module-level
    ``_kernel_module``; later calls return the cached module directly.

    Returns:
        The loaded extension module (callers use its ``hash_router`` entry).
    """
    global _kernel_module
    if _kernel_module is not None:
        return _kernel_module

    # Imported here so that importing this wrapper does not pull in
    # cpp_extension until the kernel is actually requested.
    from torch.utils.cpp_extension import load

    # Resolve to an absolute directory: os.path.dirname(__file__) may be ""
    # or relative, which would make the source path depend on the current
    # working directory.
    kernel_dir = os.path.dirname(os.path.abspath(__file__))
    _kernel_module = load(
        name="hash_router",
        sources=[os.path.join(kernel_dir, "hash_router.cu")],
        # NOTE(review): `code=` lists only sm_100a (no PTX), so the built
        # extension presumably won't run on other GPU architectures.
        extra_cuda_cflags=["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"],
        verbose=False,
    )
    return _kernel_module
|
||
|
|
|
||
|
|
|
||
|
|
def run_hash_router(
    token_ids: torch.Tensor,  # [N] int32
    hash_lut: torch.Tensor,  # [vocab_size, k] int32
    top_k: int,
    out_weights: torch.Tensor,  # [N, k] float32, pre-allocated
    out_ids: torch.Tensor,  # [N, k] int32, pre-allocated
):
    """Run the hash router kernel: gather expert IDs from LUT, write 1/k weights.

    Thin dispatch to the lazily-loaded extension's ``hash_router`` entry
    point. All arguments are forwarded unchanged and in the same order.

    Args:
        token_ids: Token IDs to route (expected ``[N]`` int32).
        hash_lut: Lookup table of expert IDs (expected ``[vocab_size, k]`` int32).
        top_k: Number of experts per token.
        out_weights: Pre-allocated output buffer for routing weights
            (expected ``[N, k]`` float32).
        out_ids: Pre-allocated output buffer for expert IDs
            (expected ``[N, k]`` int32).

    Returns:
        Whatever the extension's ``hash_router`` returns. NOTE(review): the
        pre-allocated ``out_*`` buffers suggest results are written in place.
        Confirm this against hash_router.cu.
    """
    hash_router_fn = _get_kernel_module().hash_router
    return hash_router_fn(token_ids, hash_lut, top_k, out_weights, out_ids)
|