def _cleanup_stale_lock(build_dir=None, timeout_s=None):
    """Remove a stale ``lock`` file left behind by an interrupted build.

    torch.utils.cpp_extension.load creates a 'lock' file in the build
    directory during compilation. If the compiling process is killed
    (OOM, timeout, user interrupt), the lock file is never removed and
    subsequent processes spin forever waiting for it.

    The lock is considered stale when its modification time is more than
    ``timeout_s`` seconds in the past; a stale lock is deleted and a
    message is printed. A missing lock (or missing build directory) is a
    silent no-op. Any other filesystem error is reported on stdout and
    otherwise ignored, so this best-effort cleanup never aborts the load.

    NOTE(review): staleness is judged purely by the file's mtime, so a
    healthy compilation that runs longer than ``timeout_s`` would also
    have its lock removed -- confirm the timeout exceeds the slowest
    expected build.

    Args:
        build_dir: Directory expected to contain the ``lock`` file.
            Defaults to the module-level ``_CACHE_DIR``.
        timeout_s: Maximum lock age in seconds before it is treated as
            stale. Defaults to the module-level ``_STALE_LOCK_TIMEOUT_S``.

    Returns:
        None.
    """
    # Resolve defaults at call time so the zero-argument call keeps using
    # the module-level configuration.
    if build_dir is None:
        build_dir = _CACHE_DIR
    if timeout_s is None:
        timeout_s = _STALE_LOCK_TIMEOUT_S

    lock_path = os.path.join(build_dir, "lock")
    try:
        # EAFP: stat and remove directly instead of exists()-then-act, so
        # a lock released by another process in between is handled by the
        # exception path rather than a racy pre-check.
        lock_age = time.time() - os.path.getmtime(lock_path)
        if lock_age <= timeout_s:
            return
        os.remove(lock_path)
    except (FileNotFoundError, NotADirectoryError):
        # No lock (or no build directory yet), or it vanished between the
        # stat and the remove: nothing to clean up.
        return
    except OSError as exc:
        # E.g. permission denied. Surface it: if the lock cannot be
        # removed, the subsequent load may still block on it.
        print(
            f"[loader] Could not remove stale lock file {lock_path}: {exc}",
            flush=True,
        )
        return

    print(f"[loader] Removed stale lock file (age={lock_age:.0f}s)", flush=True)
@@ -33,6 +62,9 @@ def get_cuda_module(name, sources, extra_cuda_cflags=None): if name in _LOADED_MODULES: return _LOADED_MODULES[name] + # Clean up stale lock files from crashed previous compilations + _cleanup_stale_lock() + source_paths = [os.path.join(_KERNEL_DIR, s) for s in sources] # Build a cache key from source file contents + compile flags