Files
grace-gpu-containers/vllm/sitecustomize.py
biondizzle aadde3ddf9 CMM: Fix OOM and subprocess crashes for GH200 EGM
Key changes:
- managed_alloc.cu: Add cudaMemPrefetchAsync to migrate pages to GPU
  immediately (prevents OOM from system RAM pinning on EGM systems
  where only ~102 GiB RAM remains). Add cudaMemAdviseSetAccessedBy
  for CPU so reads go over C2C NVLink without page migration.
- vllm_managed_mem.py: Rewrite with idempotent patches, proper
  MemorySnapshot.measure() override, and torch.cuda tracking stubs
  for CUDAPluggableAllocator compatibility.
- sitecustomize.py: Auto-loaded by Python in ALL subprocesses
  (including vLLM EngineCore). Applies allocator swap, torch patches,
  MemorySnapshot override, and request_memory override before any
  CUDA operations in spawned processes.
- Dockerfile: Install sitecustomize.py into Python dist-packages.
- README.md: Full rewrite with EGM problem statement, memory layout,
  architecture diagram, and build pipeline documentation.
2026-04-09 23:25:48 +00:00

115 lines
4.8 KiB
Python

"""
sitecustomize.py - Auto-loaded by Python before any other imports.
Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator.
Also swaps the allocator to cudaMallocManaged and patches MemorySnapshot.
This MUST run before any torch.cuda calls in any subprocess.
Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode).
Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
"""
import os
import sys
# Only activate in CMM mode
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if _MANAGED_TOTAL:
import torch
# Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops.
# The library path can be overridden through MANAGED_ALLOC_LIB.
_lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so')
if os.path.exists(_lib_path):
    try:
        import ctypes

        # Probe the shared object first so that a library without the two
        # expected entry points is reported clearly instead of being
        # skipped without any message.
        lib = ctypes.CDLL(_lib_path)
        _missing = [
            symbol
            for symbol in ('managed_malloc', 'managed_free')
            if not hasattr(lib, symbol)
        ]
        if _missing:
            print(
                f"[sitecustomize] WARNING: {_lib_path} is missing symbol(s) "
                f"{', '.join(_missing)}; allocator not swapped",
                file=sys.stderr,
            )
        else:
            alloc = torch.cuda.memory.CUDAPluggableAllocator(
                _lib_path, 'managed_malloc', 'managed_free'
            )
            torch.cuda.memory.change_current_allocator(alloc)
            print("[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr)
    except Exception as e:
        # Best effort: a failed swap must never abort interpreter start-up,
        # so every failure is downgraded to a warning on stderr.
        print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr)
else:
    print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr)
# Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support.
# The originals are captured first so the wrappers below can still delegate
# to them and only fall back when they raise.
_original_reset_peak = torch.cuda.reset_peak_memory_stats
_original_memory_stats = torch.cuda.memory_stats
_original_max_memory_allocated = torch.cuda.max_memory_allocated


def _patched_reset_peak_memory_stats(device=None):
    """No-op: CUDAPluggableAllocator doesn't support resetPeakStats.

    Args:
        device: Accepted for signature compatibility; ignored.
    """


def _patched_memory_stats(device=None):
    """Return allocator stats, or a minimal zeroed dict to avoid crashes.

    Args:
        device: Forwarded unchanged to the original torch.cuda.memory_stats.

    Returns:
        The original function's result; if it raises RuntimeError, a dict
        with the six counters below, all set to 0.
    """
    try:
        return _original_memory_stats(device)
    except RuntimeError:
        return {
            "allocated_bytes.all.current": 0,
            "allocated_bytes.all.peak": 0,
            "reserved_bytes.all.current": 0,
            "reserved_bytes.all.peak": 0,
            "num_alloc_retries": 0,
            "num_ooms": 0,
        }


def _patched_max_memory_allocated(device=None):
    """Return the peak allocation, or 0 when it cannot be tracked.

    Args:
        device: Forwarded unchanged to the original
            torch.cuda.max_memory_allocated.

    Returns:
        The original function's result; 0 if it raises RuntimeError.
    """
    try:
        return _original_max_memory_allocated(device)
    except RuntimeError:
        return 0


torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
torch.cuda.memory_stats = _patched_memory_stats
torch.cuda.max_memory_allocated = _patched_max_memory_allocated

# Patch accelerator aliases (PyTorch 2.11+).
# The torch.accelerator namespace itself is looked up defensively: on a
# PyTorch build without it, a bare `torch.accelerator` access would raise
# AttributeError and abort the remaining patches in this file.
_accelerator = getattr(torch, 'accelerator', None)
if _accelerator is not None:
    for _name, _replacement in (
        ('reset_peak_memory_stats', _patched_reset_peak_memory_stats),
        ('memory_stats', _patched_memory_stats),
        ('max_memory_allocated', _patched_max_memory_allocated),
    ):
        # Only overwrite aliases that already exist; never create new ones.
        if hasattr(_accelerator, _name):
            setattr(_accelerator, _name, _replacement)
# Step 3: Patch MemorySnapshot.measure() to report managed memory
try:
    from vllm.utils.mem_utils import MemorySnapshot
except ImportError:
    # vLLM (or this module path) cannot be imported in this interpreter, so
    # there is nothing to patch here. Per the original note, the override is
    # expected to be applied by vllm_managed_mem.py in that case.
    pass
else:
    import functools

    _original_measure = MemorySnapshot.measure

    # Wrapper around MemorySnapshot.measure: run the original measurement,
    # then overwrite total_memory / free_memory with figures derived from
    # MANAGED_MEMORY_TOTAL_GB. Arguments and the return value are forwarded
    # untouched so the wrapper keeps working if measure() takes parameters or
    # returns something. (Comments rather than a docstring, because
    # functools.wraps copies the original's docstring onto the wrapper.)
    @functools.wraps(_original_measure)
    def _patched_measure(self, *args, **kwargs):
        result = _original_measure(self, *args, **kwargs)
        # Re-read the variable on every call, as the original patch did.
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            managed_total = int(float(managed_total_gb) * (1024**3))
            self.total_memory = managed_total
            # cuda_memory is whatever the original measure() just recorded.
            self.free_memory = managed_total - self.cuda_memory
        return result

    MemorySnapshot.measure = _patched_measure
# Step 4: Patch request_memory to skip free-memory check for managed
try:
    import vllm.v1.worker.utils as worker_utils
except ImportError:
    # vLLM (or this module path) cannot be imported in this interpreter.
    pass
else:
    # Look the function up defensively: a vLLM build without request_memory
    # would otherwise raise AttributeError, which `except ImportError` does
    # not catch, and abort the rest of this file.
    _original_request_memory = getattr(worker_utils, 'request_memory', None)
    if _original_request_memory is None:
        print(
            "[sitecustomize] WARNING: vllm.v1.worker.utils.request_memory "
            "not found; free-memory check not patched",
            file=sys.stderr,
        )
    else:
        def _patched_request_memory(init_snapshot, cache_config):
            """Return the memory budget in bytes, sized from the managed pool.

            Args:
                init_snapshot: Only used when MANAGED_MEMORY_TOTAL_GB is unset,
                    in which case it is handed to the original function.
                cache_config: Object exposing ``gpu_memory_utilization``.

            Returns:
                ``int(MANAGED_MEMORY_TOTAL_GB * 1024**3 * gpu_memory_utilization)``
                when the variable is set; otherwise the original
                request_memory result.
            """
            managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
            if not managed_total_gb:
                return _original_request_memory(init_snapshot, cache_config)
            total_bytes = float(managed_total_gb) * (1024**3)
            return int(total_bytes * cache_config.gpu_memory_utilization)

        worker_utils.request_memory = _patched_request_memory
# Closing status line on stderr; it reports the raw MANAGED_MEMORY_TOTAL_GB
# value that activated CMM mode.
print(
    f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)",
    file=sys.stderr,
)