CMM: Fix OOM and subprocess crashes for GH200 EGM
Key changes:

- managed_alloc.cu: Add cudaMemPrefetchAsync to migrate pages to the GPU immediately (prevents OOM from system-RAM pinning on EGM systems where only ~102 GiB of RAM remains). Add cudaMemAdviseSetAccessedBy for the CPU so reads go over C2C NVLink without page migration (see the sketch below).
- vllm_managed_mem.py: Rewrite with idempotent patches, a proper MemorySnapshot.measure() override, and torch.cuda tracking stubs for CUDAPluggableAllocator compatibility.
- sitecustomize.py: Auto-loaded by Python in ALL subprocesses (including the vLLM EngineCore). Applies the allocator swap, torch patches, MemorySnapshot override, and request_memory override before any CUDA operations in spawned processes.
- Dockerfile: Install sitecustomize.py into Python dist-packages.
- README.md: Full rewrite with an EGM problem statement, memory layout, architecture diagram, and build-pipeline documentation.
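The managed_alloc.cu change is not part of the diff below. As a rough illustration only, the allocate-prefetch-advise pattern described above could look like the following ctypes sketch against the CUDA runtime; the soname, device ordinal, and helper name are assumptions, not the actual managed_alloc.cu code.

# Illustrative sketch only; the real implementation is CUDA C++ in managed_alloc.cu.
import ctypes

cudart = ctypes.CDLL("libcudart.so.12")      # soname/path may differ per image
cudaMemAttachGlobal = 0x01
cudaMemAdviseSetAccessedBy = 5               # CUDA memory-advise enum value
cudaCpuDeviceId = -1

def managed_alloc(size, device=0):
    """Allocate managed memory, push its pages to HBM up front, and let the
    CPU read the range over NVLink-C2C without migrating pages back."""
    ptr = ctypes.c_void_p()
    rc = cudart.cudaMallocManaged(ctypes.byref(ptr), ctypes.c_size_t(size),
                                  cudaMemAttachGlobal)
    assert rc == 0, f"cudaMallocManaged -> {rc}"
    # Prefetch so pages land on the GPU immediately instead of pinning the
    # ~102 GiB of system RAM that EGM leaves behind.
    rc = cudart.cudaMemPrefetchAsync(ptr, ctypes.c_size_t(size),
                                     ctypes.c_int(device), None)
    assert rc == 0, f"cudaMemPrefetchAsync -> {rc}"
    # Mark the range as accessed-by-CPU: reads cross C2C without page migration.
    rc = cudart.cudaMemAdvise(ptr, ctypes.c_size_t(size),
                              cudaMemAdviseSetAccessedBy, cudaCpuDeviceId)
    assert rc == 0, f"cudaMemAdvise -> {rc}"
    return ptr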
@@ -19,11 +19,20 @@ import os
import sys
import ctypes


def get_total_managed_memory_gb():
    """
    Calculate total memory available via managed allocations on GH200.
    Parses /proc/iomem to find NVIDIA EGM regions + HBM.
    First checks MANAGED_MEMORY_TOTAL_GB env var (explicit override).
    Then tries /proc/iomem (requires --privileged container).
    Falls back to HBM only if neither works.
    """
    # Explicit override takes priority
    env_val = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
    if env_val:
        return float(env_val)

    # Try /proc/iomem for EGM regions (requires --privileged or host mount)
    egm_bytes = 0
    try:
        with open('/proc/iomem', 'r') as f:
@@ -34,19 +43,26 @@ def get_total_managed_memory_gb():
                start = int(start_s, 16)
                end = int(end_s, 16)
                egm_bytes += (end - start + 1)
    except (PermissionError, FileNotFoundError):
    except (PermissionError, FileNotFoundError, ValueError):
        pass

    egm_gb = egm_bytes / (1024**3)
    # HBM is always there via normal cudaMalloc path
    # cudaMallocManaged can span both HBM + EGM
    return egm_gb
    if egm_gb > 0:
        # EGM detected via iomem, add HBM
        import torch
        _, hbm_total = torch.cuda.mem_get_info(0)
        hbm_gb = hbm_total / (1024**3)
        return egm_gb + hbm_gb

    # No EGM detected - return 0 (caller should handle)
    return 0


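The hunk above elides the actual /proc/iomem walk. Purely as an illustration, a helper along these lines could sum matching regions; the helper name and the label it matches on ('EGM') are assumptions, not code from this commit.

# Hypothetical helper, not part of this diff: sums /proc/iomem ranges whose
# label contains a given substring. Verify the real EGM label on your host.
def _sum_iomem_regions(match_substring='EGM'):
    total = 0
    with open('/proc/iomem', 'r') as f:
        for line in f:
            # entries look like "  100000000-17fffffff : System RAM"
            addrs, _, label = line.partition(':')
            if match_substring not in label:
                continue
            start_s, _, end_s = addrs.strip().partition('-')
            total += int(end_s, 16) - int(start_s, 16) + 1
    return total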
def swap_allocator():
    """
    Replace PyTorch's default CUDA allocator with our managed memory allocator.
    This MUST happen before any CUDA tensors are created.
    If sitecustomize.py already swapped it, this is a no-op.
    """
    lib_path = os.environ.get(
        'MANAGED_ALLOC_LIB',
@@ -64,43 +80,125 @@ def swap_allocator():
        lib = ctypes.CDLL(lib_path)
        assert hasattr(lib, 'managed_malloc'), "managed_malloc symbol not found"
        assert hasattr(lib, 'managed_free'), "managed_free symbol not found"
        print(f"[managed_mem] Loaded allocator from {lib_path}", file=sys.stderr)
    except Exception as e:
        print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
        sys.exit(1)

    import torch
    # This must happen before ANY cuda operation
    alloc = torch.cuda.memory.CUDAPluggableAllocator(
        lib_path, 'managed_malloc', 'managed_free'
    )
    torch.cuda.memory.change_current_allocator(alloc)
    try:
        alloc = torch.cuda.memory.CUDAPluggableAllocator(
            lib_path, 'managed_malloc', 'managed_free'
        )
        torch.cuda.memory.change_current_allocator(alloc)
        print(f"[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
    except RuntimeError as e:
        if "already initialized" in str(e):
            print(f"[managed_mem] Allocator already initialized (sitecustomize)",
                  file=sys.stderr)
        else:
            raise

    egm_gb = get_total_managed_memory_gb()
    print(f"[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
    print(f"[managed_mem] Detected ~{egm_gb:.0f} GiB EGM memory", file=sys.stderr)
    print(f"[managed_mem] Total addressable: ~{egm_gb + 96:.0f} GiB "
          f"(EGM + HBM)", file=sys.stderr)

def patch_memory_snapshot():
    """
    Patch MemorySnapshot.measure() to report the full managed memory
    instead of just HBM.

    This is the core fix: cudaMemGetInfo only reports HBM (~96 GiB),
    but with cudaMallocManaged we can address HBM + EGM (~474 GiB).
    We override the snapshot so all downstream code (request_memory,
    KV cache sizing, etc.) sees the full managed memory capacity.
    """
    from vllm.utils.mem_utils import MemorySnapshot

    _original_measure = MemorySnapshot.measure

    def patched_measure(self):
        _original_measure(self)
        # Override total_memory with managed memory capacity
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            managed_total = int(float(managed_total_gb) * (1024**3))
            self.total_memory = managed_total
            # free_memory = total - what CUDA is using
            self.free_memory = managed_total - self.cuda_memory
            print(f"[managed_mem] MemorySnapshot patched: "
                  f"total={self.total_memory / (1024**3):.0f} GiB, "
                  f"free={self.free_memory / (1024**3):.0f} GiB",
                  file=sys.stderr)

    MemorySnapshot.measure = patched_measure
    print(f"[managed_mem] Patched MemorySnapshot.measure()", file=sys.stderr)


def patch_torch_memory_tracking():
    """
    Patch PyTorch memory tracking functions that CUDAPluggableAllocator
    doesn't support. If sitecustomize already patched them, this is a no-op.
    """
    import torch

    # Check if already patched by sitecustomize
    if torch.cuda.reset_peak_memory_stats.__name__ == '_patched_reset_peak_memory_stats':
        print(f"[managed_mem] torch.cuda already patched (sitecustomize)", file=sys.stderr)
        return

    _original_reset_peak = torch.cuda.reset_peak_memory_stats
    _original_memory_stats = torch.cuda.memory_stats
    _original_max_memory_allocated = torch.cuda.max_memory_allocated

    def _patched_reset_peak_memory_stats(device=None):
        """No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
        pass

    def _patched_memory_stats(device=None):
        """Return minimal stats dict to avoid crashes."""
        try:
            return _original_memory_stats(device)
        except RuntimeError:
            return {
                "allocated_bytes.all.current": 0,
                "allocated_bytes.all.peak": 0,
                "reserved_bytes.all.current": 0,
                "reserved_bytes.all.peak": 0,
                "num_alloc_retries": 0,
                "num_ooms": 0,
            }

    def _patched_max_memory_allocated(device=None):
        """Return 0 since we can't track peak with pluggable allocator."""
        try:
            return _original_max_memory_allocated(device)
        except RuntimeError:
            return 0

    torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
    torch.cuda.memory_stats = _patched_memory_stats
    torch.cuda.max_memory_allocated = _patched_max_memory_allocated

    if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
        torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
    if hasattr(torch.accelerator, 'memory_stats'):
        torch.accelerator.memory_stats = _patched_memory_stats
    if hasattr(torch.accelerator, 'max_memory_allocated'):
        torch.accelerator.max_memory_allocated = _patched_max_memory_allocated

    print(f"[managed_mem] Patched torch.cuda memory tracking (no-op stubs)",
          file=sys.stderr)


def patch_vllm_memory_check():
    """
    Monkey-patch vLLM's memory validation to understand managed memory.

    vLLM checks cudaMemGetInfo and rejects startup if free < requested.
    With managed memory, the real capacity is much larger than what
    cudaMemGetInfo reports (it only sees HBM + a tiny EGM sliver).
    With managed memory, cudaMemGetInfo only reports HBM, so the free-memory
    check would fail. When MANAGED_MEMORY_TOTAL_GB is set, we skip that check.
    """
    import vllm.v1.worker.utils as worker_utils

    _original_request_memory = worker_utils.request_memory

    def patched_request_memory(init_snapshot, cache_config):
        """
        Override memory validation:
        - Read the MANAGED_MEMORY_TOTAL_GB env var (set by us)
        - If set, skip the free-memory check and just return the requested size
        """
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            total_bytes = float(managed_total_gb) * (1024**3)
@@ -114,40 +212,7 @@ def patch_vllm_memory_check():
        return _original_request_memory(init_snapshot, cache_config)

    worker_utils.request_memory = patched_request_memory
    print(f"[managed_mem] Patched vLLM memory check", file=sys.stderr)


def patch_vllm_memory_snapshot():
    """
    Patch the memory snapshot to report managed-aware totals.
    Without this, vLLM's worker sees ~95 GiB total and panics.
    """
    import vllm.v1.worker.gpu_worker as gpu_worker

    _original_init_device = gpu_worker.GpuWorker.init_device

    def patched_init_device(self):
        """Patch init_device to override memory reporting."""
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            # Set a permissive gpu_memory_utilization on the cache_config
            # so the original code doesn't reject us
            original_util = self.cache_config.gpu_memory_utilization
            self.cache_config.gpu_memory_utilization = 0.5  # always passes

            _original_init_device(self)

            # Now fix up the requested_memory with the real value
            total_bytes = float(managed_total_gb) * (1024**3)
            self.requested_memory = int(total_bytes * original_util)
            self.cache_config.gpu_memory_utilization = original_util
            print(f"[managed_mem] Worker memory overridden to "
                  f"{self.requested_memory / (1024**3):.1f} GiB", file=sys.stderr)
        else:
            _original_init_device(self)

    gpu_worker.GpuWorker.init_device = patched_init_device
    print(f"[managed_mem] Patched vLLM worker init_device", file=sys.stderr)
    print(f"[managed_mem] Patched vLLM request_memory", file=sys.stderr)


def main():
@@ -155,17 +220,28 @@ def main():
    swap_allocator()

    # Step 2: Calculate total managed memory and export it
    egm_gb = get_total_managed_memory_gb()
    total_managed_gb = egm_gb + 96  # EGM + HBM
    total_managed_gb = get_total_managed_memory_gb()
    if total_managed_gb <= 0:
        print(f"[managed_mem] WARNING: No managed memory detected! "
              f"Set MANAGED_MEMORY_TOTAL_GB env var explicitly.",
              file=sys.stderr)
        sys.exit(1)

    os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(total_managed_gb)
    print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
          file=sys.stderr)

    # Step 3: Patch vLLM memory checks
    patch_vllm_memory_check()
    patch_vllm_memory_snapshot()
    # Step 3: Patch PyTorch memory tracking (pluggable allocator doesn't support all ops)
    patch_torch_memory_tracking()

    # Step 4: Launch vLLM's API server with remaining args
    # Step 4: Patch MemorySnapshot.measure() to report full managed memory
    # This is critical - without it, all downstream code only sees HBM
    patch_memory_snapshot()

    # Step 5: Patch request_memory as a safety net
    patch_vllm_memory_check()

    # Step 6: Launch vLLM's API server with remaining args
    sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
    print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
          file=sys.stderr)
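The sitecustomize.py that the Dockerfile installs is not shown in this diff. As a hedged sketch only (the guard, ordering, and the assumption that the helpers above are importable as a vllm_managed_mem module are mine, based on the commit message), it could look roughly like this:

# sitecustomize.py -- illustrative sketch, not the file shipped by this commit.
# Python imports sitecustomize automatically at interpreter startup in every
# process, including vLLM's spawned EngineCore workers, so the patches run
# before any CUDA work happens in those subprocesses.
import os
import sys

if os.environ.get('MANAGED_MEMORY_TOTAL_GB'):
    try:
        import vllm_managed_mem as mm
        mm.swap_allocator()               # cudaMallocManaged-backed allocator
        mm.patch_torch_memory_tracking()  # no-op stubs for unsupported stats
        mm.patch_memory_snapshot()        # report EGM + HBM, not just HBM
        mm.patch_vllm_memory_check()      # skip the free-memory rejection
    except Exception as exc:              # never break unrelated Python startups
        print(f"[managed_mem] sitecustomize skipped: {exc}", file=sys.stderr)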