""" sitecustomize.py - Auto-loaded by Python before any other imports. Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator. Also swaps the allocator to cudaMallocManaged and patches MemorySnapshot. This MUST run before any torch.cuda calls in any subprocess. Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode). Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py """ import os import sys # Only activate in CMM mode _MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB') if _MANAGED_TOTAL: import torch # Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops _lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so') if os.path.exists(_lib_path): try: import ctypes lib = ctypes.CDLL(_lib_path) if hasattr(lib, 'managed_malloc') and hasattr(lib, 'managed_free'): alloc = torch.cuda.memory.CUDAPluggableAllocator( _lib_path, 'managed_malloc', 'managed_free' ) torch.cuda.memory.change_current_allocator(alloc) print(f"[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr) except Exception as e: print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr) else: print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr) # Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support _original_reset_peak = torch.cuda.reset_peak_memory_stats _original_memory_stats = torch.cuda.memory_stats _original_max_memory_allocated = torch.cuda.max_memory_allocated def _patched_reset_peak_memory_stats(device=None): """No-op: CUDAPluggableAllocator doesn't support resetPeakStats.""" pass def _patched_memory_stats(device=None): """Return minimal stats dict to avoid crashes.""" try: return _original_memory_stats(device) except RuntimeError: return { "allocated_bytes.all.current": 0, "allocated_bytes.all.peak": 0, "reserved_bytes.all.current": 0, "reserved_bytes.all.peak": 0, "num_alloc_retries": 0, "num_ooms": 0, } def _patched_max_memory_allocated(device=None): """Return 0 since we can't track peak with pluggable allocator.""" try: return _original_max_memory_allocated(device) except RuntimeError: return 0 torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats torch.cuda.memory_stats = _patched_memory_stats torch.cuda.max_memory_allocated = _patched_max_memory_allocated # Patch accelerator aliases (PyTorch 2.11+) if hasattr(torch.accelerator, 'reset_peak_memory_stats'): torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats if hasattr(torch.accelerator, 'memory_stats'): torch.accelerator.memory_stats = _patched_memory_stats if hasattr(torch.accelerator, 'max_memory_allocated'): torch.accelerator.max_memory_allocated = _patched_max_memory_allocated # Step 3: Patch MemorySnapshot.measure() to report managed memory try: from vllm.utils.mem_utils import MemorySnapshot _original_measure = MemorySnapshot.measure def _patched_measure(self): _original_measure(self) managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB') if managed_total_gb: managed_total = int(float(managed_total_gb) * (1024**3)) self.total_memory = managed_total self.free_memory = managed_total - self.cuda_memory MemorySnapshot.measure = _patched_measure except ImportError: pass # vllm not loaded yet, will be patched by vllm_managed_mem.py # Step 4: Patch request_memory to skip free-memory check for managed try: import vllm.v1.worker.utils as worker_utils _original_request_memory = worker_utils.request_memory def _patched_request_memory(init_snapshot, cache_config): managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB') if managed_total_gb: total_bytes = float(managed_total_gb) * (1024**3) gpu_util = cache_config.gpu_memory_utilization requested = int(total_bytes * gpu_util) return requested else: return _original_request_memory(init_snapshot, cache_config) worker_utils.request_memory = _patched_request_memory except ImportError: pass # vllm not loaded yet print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)", file=sys.stderr)