CMM: Fix OOM and subprocess crashes for GH200 EGM
Key changes: - managed_alloc.cu: Add cudaMemPrefetchAsync to migrate pages to GPU immediately (prevents OOM from system RAM pinning on EGM systems where only ~102 GiB RAM remains). Add cudaMemAdviseSetAccessedBy for CPU so reads go over C2C NVLink without page migration. - vllm_managed_mem.py: Rewrite with idempotent patches, proper MemorySnapshot.measure() override, and torch.cuda tracking stubs for CUDAPluggableAllocator compatibility. - sitecustomize.py: Auto-loaded by Python in ALL subprocesses (including vLLM EngineCore). Applies allocator swap, torch patches, MemorySnapshot override, and request_memory override before any CUDA operations in spawned processes. - Dockerfile: Install sitecustomize.py into Python dist-packages. - README.md: Full rewrite with EGM problem statement, memory layout, architecture diagram, and build pipeline documentation.
This commit is contained in:
114
vllm/sitecustomize.py
Normal file
114
vllm/sitecustomize.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
sitecustomize.py - Auto-loaded by Python before any other imports.
|
||||
Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator.
|
||||
Also swaps the allocator to cudaMallocManaged and patches MemorySnapshot.
|
||||
|
||||
This MUST run before any torch.cuda calls in any subprocess.
|
||||
Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode).
|
||||
|
||||
Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Only activate in CMM mode
|
||||
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
|
||||
if _MANAGED_TOTAL:
|
||||
import torch
|
||||
|
||||
# Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops
|
||||
_lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so')
|
||||
if os.path.exists(_lib_path):
|
||||
try:
|
||||
import ctypes
|
||||
lib = ctypes.CDLL(_lib_path)
|
||||
if hasattr(lib, 'managed_malloc') and hasattr(lib, 'managed_free'):
|
||||
alloc = torch.cuda.memory.CUDAPluggableAllocator(
|
||||
_lib_path, 'managed_malloc', 'managed_free'
|
||||
)
|
||||
torch.cuda.memory.change_current_allocator(alloc)
|
||||
print(f"[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr)
|
||||
else:
|
||||
print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr)
|
||||
|
||||
# Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support
|
||||
_original_reset_peak = torch.cuda.reset_peak_memory_stats
|
||||
_original_memory_stats = torch.cuda.memory_stats
|
||||
_original_max_memory_allocated = torch.cuda.max_memory_allocated
|
||||
|
||||
def _patched_reset_peak_memory_stats(device=None):
|
||||
"""No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
|
||||
pass
|
||||
|
||||
def _patched_memory_stats(device=None):
|
||||
"""Return minimal stats dict to avoid crashes."""
|
||||
try:
|
||||
return _original_memory_stats(device)
|
||||
except RuntimeError:
|
||||
return {
|
||||
"allocated_bytes.all.current": 0,
|
||||
"allocated_bytes.all.peak": 0,
|
||||
"reserved_bytes.all.current": 0,
|
||||
"reserved_bytes.all.peak": 0,
|
||||
"num_alloc_retries": 0,
|
||||
"num_ooms": 0,
|
||||
}
|
||||
|
||||
def _patched_max_memory_allocated(device=None):
|
||||
"""Return 0 since we can't track peak with pluggable allocator."""
|
||||
try:
|
||||
return _original_max_memory_allocated(device)
|
||||
except RuntimeError:
|
||||
return 0
|
||||
|
||||
torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
|
||||
torch.cuda.memory_stats = _patched_memory_stats
|
||||
torch.cuda.max_memory_allocated = _patched_max_memory_allocated
|
||||
|
||||
# Patch accelerator aliases (PyTorch 2.11+)
|
||||
if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
|
||||
torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
|
||||
if hasattr(torch.accelerator, 'memory_stats'):
|
||||
torch.accelerator.memory_stats = _patched_memory_stats
|
||||
if hasattr(torch.accelerator, 'max_memory_allocated'):
|
||||
torch.accelerator.max_memory_allocated = _patched_max_memory_allocated
|
||||
|
||||
# Step 3: Patch MemorySnapshot.measure() to report managed memory
|
||||
try:
|
||||
from vllm.utils.mem_utils import MemorySnapshot
|
||||
_original_measure = MemorySnapshot.measure
|
||||
|
||||
def _patched_measure(self):
|
||||
_original_measure(self)
|
||||
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
|
||||
if managed_total_gb:
|
||||
managed_total = int(float(managed_total_gb) * (1024**3))
|
||||
self.total_memory = managed_total
|
||||
self.free_memory = managed_total - self.cuda_memory
|
||||
|
||||
MemorySnapshot.measure = _patched_measure
|
||||
except ImportError:
|
||||
pass # vllm not loaded yet, will be patched by vllm_managed_mem.py
|
||||
|
||||
# Step 4: Patch request_memory to skip free-memory check for managed
|
||||
try:
|
||||
import vllm.v1.worker.utils as worker_utils
|
||||
_original_request_memory = worker_utils.request_memory
|
||||
|
||||
def _patched_request_memory(init_snapshot, cache_config):
|
||||
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
|
||||
if managed_total_gb:
|
||||
total_bytes = float(managed_total_gb) * (1024**3)
|
||||
gpu_util = cache_config.gpu_memory_utilization
|
||||
requested = int(total_bytes * gpu_util)
|
||||
return requested
|
||||
else:
|
||||
return _original_request_memory(init_snapshot, cache_config)
|
||||
|
||||
worker_utils.request_memory = _patched_request_memory
|
||||
except ImportError:
|
||||
pass # vllm not loaded yet
|
||||
|
||||
print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)", file=sys.stderr)
|
||||
Reference in New Issue
Block a user