Key changes:
- managed_alloc.cu: Add cudaMemPrefetchAsync to migrate pages to the GPU immediately (prevents OOM from system RAM pinning on EGM systems where only ~102 GiB of RAM remains). Add cudaMemAdviseSetAccessedBy for the CPU so reads go over C2C NVLink without page migration.
- vllm_managed_mem.py: Rewrite with idempotent patches, a proper MemorySnapshot.measure() override, and torch.cuda tracking stubs for CUDAPluggableAllocator compatibility.
- sitecustomize.py: Auto-loaded by Python in ALL subprocesses (including the vLLM EngineCore). Applies the allocator swap, torch patches, MemorySnapshot override, and request_memory override before any CUDA operations in spawned processes.
- Dockerfile: Install sitecustomize.py into Python dist-packages.
- README.md: Full rewrite with the EGM problem statement, memory layout, architecture diagram, and build pipeline documentation.
115 lines
4.8 KiB
Python
115 lines
4.8 KiB
Python
"""
|
|
sitecustomize.py - Auto-loaded by Python before any other imports.
Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator.
Also swaps the allocator to cudaMallocManaged and patches vLLM's
MemorySnapshot.measure() and request_memory().

This MUST run before any torch.cuda calls in any subprocess.
Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode).

Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
|
|
"""
|
|
import os
|
|
import sys
|
|
|
|
# Only activate in CMM mode
|
|
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
|
|
if _MANAGED_TOTAL:
|
|
import torch
|
|
|
|
# Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops
#
# The shared library is taken from MANAGED_ALLOC_LIB (default:
# /usr/local/lib/libmanaged_alloc.so) and must export `managed_malloc` and
# `managed_free`. Every failure in this step is reported on stderr instead of
# being raised.
_lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so')
if os.path.exists(_lib_path):
    try:
        import ctypes

        # Probe the library with ctypes first so a missing entry point is
        # detected here, before the path is handed to PyTorch.
        lib = ctypes.CDLL(_lib_path)
        if hasattr(lib, 'managed_malloc') and hasattr(lib, 'managed_free'):
            alloc = torch.cuda.memory.CUDAPluggableAllocator(
                _lib_path, 'managed_malloc', 'managed_free'
            )
            torch.cuda.memory.change_current_allocator(alloc)
            print("[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr)
        else:
            # Without this branch the swap would be skipped with no trace,
            # leaving the default allocator active in a process that the
            # rest of this file configures for managed memory.
            print(
                f"[sitecustomize] WARNING: {_lib_path} does not export "
                "managed_malloc/managed_free; allocator not swapped",
                file=sys.stderr,
            )
    except Exception as e:
        # Best effort: loading the library or swapping the allocator failed;
        # report it and continue with whatever allocator is active.
        print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr)
else:
    print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr)
|
|
|
|
# Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support
#
# Capture the stock torch.cuda entry points first, so the replacement
# functions can still delegate to them after torch.cuda is re-pointed.
(
    _original_reset_peak,
    _original_memory_stats,
    _original_max_memory_allocated,
) = (
    torch.cuda.reset_peak_memory_stats,
    torch.cuda.memory_stats,
    torch.cuda.max_memory_allocated,
)
|
|
|
|
def _patched_reset_peak_memory_stats(device=None):
    """Accept a peak-stats reset request and do nothing.

    Used in place of the real reset because CUDAPluggableAllocator doesn't
    support resetPeakStats.

    Args:
        device: Accepted for signature compatibility; never read.

    Returns:
        None.
    """
    return None
|
|
|
|
def _patched_memory_stats(device=None):
    """Return allocator statistics, or zeroed placeholders when unavailable.

    Delegates to ``_original_memory_stats(device)``. If that call raises
    ``RuntimeError``, a freshly built dict of six zero-valued counters is
    returned instead, so the caller does not crash.

    Args:
        device: Forwarded unchanged to the original function.

    Returns:
        The original function's result, or the zero-filled fallback dict.
    """
    try:
        stats = _original_memory_stats(device)
    except RuntimeError:
        placeholder_keys = (
            "allocated_bytes.all.current",
            "allocated_bytes.all.peak",
            "reserved_bytes.all.current",
            "reserved_bytes.all.peak",
            "num_alloc_retries",
            "num_ooms",
        )
        # A new dict per call, so callers can never share or mutate a
        # common fallback object.
        stats = dict.fromkeys(placeholder_keys, 0)
    return stats
|
|
|
|
def _patched_max_memory_allocated(device=None):
    """Return the peak allocated bytes, or 0 when it cannot be queried.

    Delegates to ``_original_max_memory_allocated(device)`` and substitutes
    0 if that call raises ``RuntimeError`` (peak usage can't be tracked with
    the pluggable allocator).

    Args:
        device: Forwarded unchanged to the original function.

    Returns:
        The original function's result, or 0 on ``RuntimeError``.
    """
    try:
        peak = _original_max_memory_allocated(device)
    except RuntimeError:
        peak = 0
    return peak
|
|
|
|
# Re-point the torch.cuda entry points at the tolerant replacements.
torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
torch.cuda.memory_stats = _patched_memory_stats
torch.cuda.max_memory_allocated = _patched_max_memory_allocated

# Patch accelerator aliases (PyTorch 2.11+)
#
# The torch.accelerator namespace itself may be missing on older PyTorch
# releases, so look it up defensively: a bare `torch.accelerator` access
# would raise AttributeError there and abort the remaining steps of this file.
_accelerator = getattr(torch, 'accelerator', None)
if _accelerator is not None:
    for _alias_name, _replacement in (
        ('reset_peak_memory_stats', _patched_reset_peak_memory_stats),
        ('memory_stats', _patched_memory_stats),
        ('max_memory_allocated', _patched_max_memory_allocated),
    ):
        # Only override aliases this PyTorch build actually exposes.
        if hasattr(_accelerator, _alias_name):
            setattr(_accelerator, _alias_name, _replacement)
|
|
|
|
# Step 3: Patch MemorySnapshot.measure() to report managed memory
try:
    from vllm.utils.mem_utils import MemorySnapshot
except ImportError:
    # vllm could not be imported here; vllm_managed_mem.py is expected to
    # apply this patch instead.
    pass
else:
    _original_measure = MemorySnapshot.measure

    def _patched_measure(self):
        """Measure as usual, then report the managed pool as total memory.

        Runs the original ``measure``. When MANAGED_MEMORY_TOTAL_GB is set,
        ``total_memory`` is overwritten with that many GiB in bytes and
        ``free_memory`` with that size minus ``self.cuda_memory``.
        """
        _original_measure(self)
        env_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if not env_total_gb:
            return
        pool_bytes = int(float(env_total_gb) * (1024**3))
        self.total_memory = pool_bytes
        self.free_memory = pool_bytes - self.cuda_memory

    MemorySnapshot.measure = _patched_measure
|
|
|
|
# Step 4: Patch request_memory to skip free-memory check for managed
try:
    import vllm.v1.worker.utils as worker_utils
except ImportError:
    pass  # vllm could not be imported here; nothing to patch
else:
    _original_request_memory = worker_utils.request_memory

    def _patched_request_memory(init_snapshot, cache_config):
        """Size the memory request from the managed pool.

        When MANAGED_MEMORY_TOTAL_GB is set, returns that many GiB in bytes
        multiplied by ``cache_config.gpu_memory_utilization``, truncated to
        an int; ``init_snapshot`` is not consulted. Otherwise defers to the
        original ``request_memory``.
        """
        env_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if not env_total_gb:
            return _original_request_memory(init_snapshot, cache_config)
        pool_bytes = float(env_total_gb) * (1024**3)
        return int(pool_bytes * cache_config.gpu_memory_utilization)

    worker_utils.request_memory = _patched_request_memory
|
|
|
|
# Final status line on stderr, echoing the configured managed pool size.
print(
    f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)",
    file=sys.stderr,
)
|