diff --git a/dsv4/kernels/cuda/__init__.py b/dsv4/kernels/cuda/__init__.py index d4f9bdf9..9b48177f 100644 --- a/dsv4/kernels/cuda/__init__.py +++ b/dsv4/kernels/cuda/__init__.py @@ -1,75 +1,2 @@ -"""CUDA kernel loader with compile-once caching. - -Compiles .cu kernels on first call, caches the loaded module for subsequent calls. -Eliminates the JIT recompilation overhead from torch.utils.cpp_extension.load -being called on every kernel invocation (was ~100ms per call, called ~500x per token). - -Usage: - from dsv4.kernels.cuda.loader import get_cuda_module - mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) - result = mod.fused_amax_quantize_nvfp4(x, divisor) -""" -import os -import hashlib -import torch -from torch.utils.cpp_extension import load - -_KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) -_CACHE_DIR = os.path.join(_KERNEL_DIR, "_build_cache") -_LOADED_MODULES = {} - - -def get_cuda_module(name, sources, extra_cuda_cflags=None): - """Load a CUDA kernel module, compiling once and caching forever. - - Args: - name: Module name (used for caching key). - sources: List of .cu filenames relative to the kernels/cuda/ directory. - extra_cuda_cflags: Optional list of extra CUDA compiler flags. - - Returns: - The loaded Python module with the kernel functions. - """ - if name in _LOADED_MODULES: - return _LOADED_MODULES[name] - - source_paths = [os.path.join(_KERNEL_DIR, s) for s in sources] - - # Build a cache key from source file contents + compile flags - hasher = hashlib.md5() - for sp in source_paths: - hasher.update(open(sp, 'rb').read()) - cflags = extra_cuda_cflags or [] - for cf in cflags: - hasher.update(cf.encode()) - cache_key = f"{name}_{hasher.hexdigest()}" - - # Ensure cache directory exists - os.makedirs(_CACHE_DIR, exist_ok=True) - - cflags = cflags or [ - "-gencode=arch=compute_100a,code=sm_100a", - "-O3", - "--use_fast_math", - ] - - mod = load( - name=cache_key, - sources=source_paths, - extra_cuda_cflags=cflags, - build_directory=_CACHE_DIR, - verbose=False, - ) - - _LOADED_MODULES[name] = mod - return mod - - -def preload_all(): - """Preload all CUDA kernels at startup (before the hot path).""" - # Fused amax + quantize — THE critical kernel for P0 - get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) - # Standalone quantize (used by weight quantization, not hot path) - get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"]) - # Sampler - get_cuda_module("sampler", ["sampler.cu"]) +"""CUDA kernel loader — re-exports from loader.py for convenience.""" +from dsv4.kernels.cuda.loader import get_cuda_module, preload_all diff --git a/dsv4/kernels/cuda/loader.py b/dsv4/kernels/cuda/loader.py new file mode 100644 index 00000000..d4f9bdf9 --- /dev/null +++ b/dsv4/kernels/cuda/loader.py @@ -0,0 +1,75 @@ +"""CUDA kernel loader with compile-once caching. + +Compiles .cu kernels on first call, caches the loaded module for subsequent calls. +Eliminates the JIT recompilation overhead from torch.utils.cpp_extension.load +being called on every kernel invocation (was ~100ms per call, called ~500x per token). + +Usage: + from dsv4.kernels.cuda.loader import get_cuda_module + mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) + result = mod.fused_amax_quantize_nvfp4(x, divisor) +""" +import os +import hashlib +import torch +from torch.utils.cpp_extension import load + +_KERNEL_DIR = os.path.dirname(os.path.abspath(__file__)) +_CACHE_DIR = os.path.join(_KERNEL_DIR, "_build_cache") +_LOADED_MODULES = {} + + +def get_cuda_module(name, sources, extra_cuda_cflags=None): + """Load a CUDA kernel module, compiling once and caching forever. + + Args: + name: Module name (used for caching key). + sources: List of .cu filenames relative to the kernels/cuda/ directory. + extra_cuda_cflags: Optional list of extra CUDA compiler flags. + + Returns: + The loaded Python module with the kernel functions. + """ + if name in _LOADED_MODULES: + return _LOADED_MODULES[name] + + source_paths = [os.path.join(_KERNEL_DIR, s) for s in sources] + + # Build a cache key from source file contents + compile flags + hasher = hashlib.md5() + for sp in source_paths: + hasher.update(open(sp, 'rb').read()) + cflags = extra_cuda_cflags or [] + for cf in cflags: + hasher.update(cf.encode()) + cache_key = f"{name}_{hasher.hexdigest()}" + + # Ensure cache directory exists + os.makedirs(_CACHE_DIR, exist_ok=True) + + cflags = cflags or [ + "-gencode=arch=compute_100a,code=sm_100a", + "-O3", + "--use_fast_math", + ] + + mod = load( + name=cache_key, + sources=source_paths, + extra_cuda_cflags=cflags, + build_directory=_CACHE_DIR, + verbose=False, + ) + + _LOADED_MODULES[name] = mod + return mod + + +def preload_all(): + """Preload all CUDA kernels at startup (before the hot path).""" + # Fused amax + quantize — THE critical kernel for P0 + get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"]) + # Standalone quantize (used by weight quantization, not hot path) + get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"]) + # Sampler + get_cuda_module("sampler", ["sampler.cu"])