- compressor_reduce_quant.cu: Single-kernel CSA/HCA compress + RMSNorm + NVFP4 quantize. No intermediate BF16: FP32 → E2M1 + E4M3 + FP32 gsa in one kernel. Shared memory: ~2.5 KB per CTA (FP32 staging + nibble buffer).
- dequant_nvfp4.cu: NVFP4 → BF16 dequantization kernels. Full dequant (HCA dense gather) and selective dequant (CSA top-k gather). Single kernel launch per gather operation.
- production_compress.py: Added csa_compress_production_nvfp4() and hca_compress_production_nvfp4() — production path for KV-1/KV-2.
- loader.py: Preload the dequant_nvfp4 and compressor_reduce_quant modules.
- test_kv_compress_quant.py: Unit tests verifying cos >= 0.999 between the BF16 reference and the NVFP4 round-trip path.
82 lines
2.7 KiB
Python
82 lines
2.7 KiB
Python
"""CUDA kernel loader with compile-once caching.

Compiles .cu kernels on first call, caches the loaded module for subsequent calls.
Eliminates the JIT recompilation overhead from torch.utils.cpp_extension.load
being called on every kernel invocation (was ~100ms per call, called ~500x per token).

Usage:
    from dsv4.kernels.cuda.loader import get_cuda_module
    mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
    result = mod.fused_amax_quantize_nvfp4(x, divisor)
"""
|
|
import os
|
|
import hashlib
|
|
import torch
|
|
from torch.utils.cpp_extension import load
|
|
|
|
_KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
_CACHE_DIR = os.path.join(_KERNEL_DIR, "_build_cache")
|
|
_LOADED_MODULES = {}
|
|
|
|
|
|
def get_cuda_module(name, sources, extra_cuda_cflags=None):
    """Load a CUDA kernel module, compiling once and caching the result.

    The first call for a given ``name`` builds the extension with
    ``torch.utils.cpp_extension.load`` into ``_CACHE_DIR`` and stores the
    loaded module in ``_LOADED_MODULES``. Later calls with the same ``name``
    return the stored module immediately, without reading any file.

    Args:
        name: Module name. It is the only key of the in-process cache, so a
            later call that reuses a name gets the first-loaded module back
            even if ``sources`` or ``extra_cuda_cflags`` differ.
        sources: List of .cu filenames relative to the directory containing
            this file (kernels/cuda/).
        extra_cuda_cflags: Optional list of extra CUDA compiler flags. When
            omitted or empty, the default flags (sm_100a gencode, ``-O3``,
            ``--use_fast_math``) are used instead.

    Returns:
        The loaded Python module with the kernel functions.

    Raises:
        OSError: If a source file cannot be opened or read.
    """
    # Fast path: already loaded in this process.
    if name in _LOADED_MODULES:
        return _LOADED_MODULES[name]

    source_paths = [os.path.join(_KERNEL_DIR, s) for s in sources]

    # Build a cache key from source file contents + caller-supplied flags.
    # Only the flags passed by the caller feed the digest; the default flags
    # applied further down are not part of it.
    hasher = hashlib.md5()
    for sp in source_paths:
        # Read inside a context manager so the handle is closed right away
        # instead of lingering until garbage collection.
        with open(sp, "rb") as src_file:
            hasher.update(src_file.read())
    cflags = extra_cuda_cflags or []
    for cf in cflags:
        hasher.update(cf.encode())
    # The digest is embedded in the extension name, so changed sources or
    # flags produce a differently named build.
    cache_key = f"{name}_{hasher.hexdigest()}"

    # Ensure cache directory exists
    os.makedirs(_CACHE_DIR, exist_ok=True)

    # Fall back to the default flags when the caller supplied none.
    cflags = cflags or [
        "-gencode=arch=compute_100a,code=sm_100a",
        "-O3",
        "--use_fast_math",
    ]

    mod = load(
        name=cache_key,
        sources=source_paths,
        extra_cuda_cflags=cflags,
        build_directory=_CACHE_DIR,
        verbose=False,
    )

    _LOADED_MODULES[name] = mod
    return mod
|
|
|
|
|
|
def preload_all():
    """Load every CUDA kernel module up front, before the hot path runs.

    Each entry is passed to ``get_cuda_module`` in the order listed.
    """
    kernels = (
        # amax_gsa — computes gsa on GPU (no .item())
        ("amax_gsa", ["amax_gsa.cu"]),
        # quantize-from-buffer — reads gsa from GPU buffer (no .item())
        ("fused_amax_quantize", ["fused_amax_quantize.cu"]),
        # Standalone quantize (for when gsa is known, not hot path)
        ("quantize_nvfp4", ["quantize_nvfp4.cu"]),
        # Sampler
        ("sampler", ["sampler.cu"]),
        # Dequant NVFP4
        ("dequant_nvfp4", ["dequant_nvfp4.cu"]),
        # Fused compress + quantize
        ("compressor_reduce_quant", ["compressor_reduce_quant.cu"]),
    )
    for module_name, cu_files in kernels:
        get_cuda_module(module_name, cu_files)
|