nvfp4-megamoe-kernel/dsv4/kernels/cuda/loader.py
biondizzle cacf64232e CRITICAL FIX: fused_amax_quantize cross-CTA race condition
The single-kernel approach used __syncthreads() for cross-CTA amax
reduction, but __syncthreads() only syncs within a CTA (same blockIdx).
CTA 0 reading s_amax[1] before CTA 1 writes = race condition = garbage gsa.

Result: residual |X| exploded to 10^37 by L0. F_attn and F_ffn were 0.0.

Fix: Two-kernel approach (correct, zero CPU syncs):
  Kernel 1: amax_gsa.cu — computes gsa on GPU, returns GPU tensor
  Kernel 2: quantize_nvfp4_from_buffer — reads gsa from GPU buffer

The fused_amax_quantize.cu now exports quantize_nvfp4_from_buffer and
deinterleave_quantize_from_buffer (gsa from GPU buffer, not kernel param).

Same P0 win: zero .item() syncs. Two kernel launches instead of one,
but correctness > shaving one launch.
2026-06-01 21:26:51 +00:00
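
The ordering problem described in the commit message can be sketched in plain Python. This is a simulation only, not the CUDA code: lists stand in for GPU memory, each loop iteration stands in for one CTA, and all function names are hypothetical. Stale slots hold 0.0 here, whereas on a GPU they would hold whatever was in memory. The sketch models only the amax reduction, not how gsa is derived from it.

```python
# Illustrative simulation only: plain Python lists stand in for GPU memory,
# and each "CTA" is one loop iteration. All names here are hypothetical.

def block_amax(block):
    """Per-CTA reduction: largest absolute value in one block of data."""
    return max(abs(v) for v in block)


def fused_single_pass(blocks):
    """Broken pattern: each CTA reads every partial right after writing its own.

    CTAs run one after another here, so CTA 0 reads slots that later CTAs
    have not written yet, which mirrors the cross-CTA race described above.
    """
    partials = [0.0] * len(blocks)  # stands in for the s_amax scratch array
    seen = []
    for cta, block in enumerate(blocks):
        partials[cta] = block_amax(block)
        # There is no grid-wide barrier here; other slots may still be stale.
        seen.append(max(partials))
    return seen


def two_pass(blocks):
    """Fixed pattern: pass 1 finishes the global amax before pass 2 reads it."""
    # "Kernel 1": every partial is computed and reduced into a one-element buffer.
    amax_buffer = [max(block_amax(b) for b in blocks)]
    # "Kernel 2": every CTA reads the same, fully reduced value from the buffer.
    return [amax_buffer[0] for _ in blocks]


blocks = [[0.5, -1.0], [3.0, -8.0], [2.0, 4.0]]
stale = fused_single_pass(blocks)
fixed = two_pass(blocks)
print(stale)  # [1.0, 8.0, 8.0]  (CTA 0 saw an incomplete reduction)
print(fixed)  # [8.0, 8.0, 8.0]
```

On the GPU, the boundary between two kernels launched on the same stream provides the grid-wide ordering that `__syncthreads()` cannot, which is why the two-kernel split needs no CPU-side sync.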


"""CUDA kernel loader with compile-once caching.
Compiles .cu kernels on first call, caches the loaded module for subsequent calls.
Eliminates the JIT recompilation overhead from torch.utils.cpp_extension.load
being called on every kernel invocation (was ~100ms per call, called ~500x per token).

Usage:
    from dsv4.kernels.cuda.loader import get_cuda_module
    mod = get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
    result = mod.fused_amax_quantize_nvfp4(x, divisor)
"""
import os
import hashlib

import torch
from torch.utils.cpp_extension import load

_KERNEL_DIR = os.path.dirname(os.path.abspath(__file__))
_CACHE_DIR = os.path.join(_KERNEL_DIR, "_build_cache")
_LOADED_MODULES = {}

def get_cuda_module(name, sources, extra_cuda_cflags=None):
    """Load a CUDA kernel module, compiling once and caching for the life of the process.

    Args:
        name: Module name (used as the in-process cache key).
        sources: List of .cu filenames relative to the kernels/cuda/ directory.
        extra_cuda_cflags: Optional list of CUDA compiler flags. When given,
            these replace the default flags below rather than extending them.

    Returns:
        The loaded Python module with the kernel functions.
    """
    if name in _LOADED_MODULES:
        return _LOADED_MODULES[name]

    source_paths = [os.path.join(_KERNEL_DIR, s) for s in sources]

    # Resolve the flags before hashing so the defaults are part of the cache key too.
    cflags = extra_cuda_cflags or [
        "-gencode=arch=compute_100a,code=sm_100a",
        "-O3",
        "--use_fast_math",
    ]

    # Build a cache key from source file contents + compile flags
    hasher = hashlib.md5()
    for sp in source_paths:
        # Close each file promptly instead of leaking the handle.
        with open(sp, "rb") as f:
            hasher.update(f.read())
    for cf in cflags:
        hasher.update(cf.encode())
    cache_key = f"{name}_{hasher.hexdigest()}"

    # Ensure cache directory exists
    os.makedirs(_CACHE_DIR, exist_ok=True)

    mod = load(
        name=cache_key,
        sources=source_paths,
        extra_cuda_cflags=cflags,
        build_directory=_CACHE_DIR,
        verbose=False,
    )
    _LOADED_MODULES[name] = mod
    return mod


def preload_all():
    """Preload all CUDA kernels at startup (before the hot path)."""
    # amax_gsa: computes gsa on GPU (no .item())
    get_cuda_module("amax_gsa", ["amax_gsa.cu"])
    # quantize-from-buffer: reads gsa from GPU buffer (no .item())
    get_cuda_module("fused_amax_quantize", ["fused_amax_quantize.cu"])
    # Standalone quantize (for when gsa is known, not hot path)
    get_cuda_module("quantize_nvfp4", ["quantize_nvfp4.cu"])
    # Sampler
    get_cuda_module("sampler", ["sampler.cu"])
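
The compile-once pattern used by the loader above can be tried without a GPU or compiler. The following is a minimal sketch under stated assumptions: `fake_build`, `content_key`, and `get_module` are hypothetical names, `fake_build` stands in for the expensive `load(...)` call, and a dict of filename to source text stands in for the `.cu` files on disk.

```python
import hashlib

_LOADED = {}       # in-process cache: name -> built object
build_calls = []   # records every "compile" so the caching is observable


def fake_build(cache_key, sources, flags):
    """Hypothetical stand-in for the expensive compile-and-load step."""
    build_calls.append(cache_key)
    return {"key": cache_key, "flags": list(flags)}


def content_key(name, sources, flags):
    """Derive a build name from the source contents and the compile flags."""
    hasher = hashlib.md5()
    for text in sources.values():
        hasher.update(text.encode())
    for flag in flags:
        hasher.update(flag.encode())
    return f"{name}_{hasher.hexdigest()}"


def get_module(name, sources, flags=("-O3",)):
    """Build on the first request; return the cached object afterwards."""
    if name in _LOADED:
        return _LOADED[name]
    mod = fake_build(content_key(name, sources, flags), sources, flags)
    _LOADED[name] = mod
    return mod


src = {"k.cu": "__global__ void k() {}"}
first = get_module("k", src)
second = get_module("k", src)
print(first is second, len(build_calls))  # True 1

# Editing the source changes the content-derived build name ...
edited = {"k.cu": "// edited"}
keys_differ = content_key("k", src, ("-O3",)) != content_key("k", edited, ("-O3",))
# ... but the in-process cache is keyed by name alone, so a running process
# keeps returning the module it already loaded.
third = get_module("k", edited)
print(keys_differ, third is first)  # True True
```

The last two lines show a property the loader shares with this sketch: the content hash only protects the on-disk build, so picking up an edited kernel requires a fresh process (or clearing the in-process dict).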