CMM: Fix OOM and subprocess crashes for GH200 EGM

Key changes:
- managed_alloc.cu: Add cudaMemPrefetchAsync to migrate pages to GPU
  immediately (prevents OOM from system RAM pinning on EGM systems
  where only ~102 GiB RAM remains). Add cudaMemAdviseSetAccessedBy
  for CPU so reads go over C2C NVLink without page migration.
- vllm_managed_mem.py: Rewrite with idempotent patches, proper
  MemorySnapshot.measure() override, and torch.cuda tracking stubs
  for CUDAPluggableAllocator compatibility.
- sitecustomize.py: Auto-loaded by Python in ALL subprocesses
  (including vLLM EngineCore). Applies allocator swap, torch patches,
  MemorySnapshot override, and request_memory override before any
  CUDA operations in spawned processes.
- Dockerfile: Install sitecustomize.py into Python dist-packages.
- README.md: Full rewrite with EGM problem statement, memory layout,
  architecture diagram, and build pipeline documentation.
This commit is contained in:
2026-04-09 23:25:48 +00:00
parent 079eb88d7d
commit aadde3ddf9
5 changed files with 460 additions and 96 deletions

163
README.md
View File

@@ -1,33 +1,152 @@
# Building containers for GH200
# Building vLLM containers for GH200 leveraging cudaMallocManaged when EGM (Extended GPU Memory) is enabled
Currently, prebuilt wheels for `vLLM` and `LMcache` are not available for `aarch64`. This can make setup tedious when working on modern `aarch64` platforms such as NVIDIA GH200.
## The Problem
Further, Nvidia at this time does not provide the `Dockerfile` associated with the NGC containers which makes replacing some of the components (like a newer version of vLLM) tedious.
The GH200 has 96 GiB of HBM (VRAM) and 480 GiB of LPDDR (system memory). The only way for the GPU to access system memory over the C2C NVLink at the full 900 GB/s — without going through the IOMMU — is to enable EGM in the BIOS.
This repository provides a Dockerfile to build a container with vLLM and all its dependencies pre-installed to try out various things such as KV offloading.
This creates three issues:
If you prefer not to build the image yourself, you can pull the ready-to-use image directly from Docker Hub:
1. **EGM requires reserved memory value of 8192** — this is the only value that works
2. **The server loses most of its system memory** — no longer sees 480 GiB; instead sees ~102 GiB. The rest has been handed over to the GPU
3. **vLLM still only sees 96 GiB of VRAM** — the GPU now has access to the full memory space, but the only way to leverage it is to convert all `cudaMalloc` calls to `cudaMallocManaged`. Without this, vLLM's allocator only touches HBM
```bash
docker run --rm -it --gpus all -v "$PWD":"$PWD" -w "$PWD" rajesh550/gh200-vllm:0.11.0 bash
## The Goal
# CUDA 13
docker run --rm -it --gpus all -v "$PWD":"$PWD" -w "$PWD" rajesh550/gh200-vllm:0.11.1rc2 bash
Force vLLM to use `cudaMallocManaged` so it can address the full memory space (HBM + EGM). Before we can do that, we have to make sure vLLM's preflight checks of available VRAM show the new fully allocated amount (~97 GiB HBM + ~378 GiB EGM = ~475 GiB total managed memory).
## GH200 System State with EGM Enabled
```
$ nvidia-smi
+-----------------------------------------------------------------------------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
|=========================================================================================|
| 0 NVIDIA GH200 480GB On | 00000009:01:00.0 Off | Off |
| N/A 31C P0 83W / 700W | 14920MiB / 97871MiB | 0% Default |
+-----------------------------------------------------------------------------------------+
$ free -h
total used free shared buff/cache available
Mem: 102Gi 4.2Gi 83Gi 10Mi 15Gi 93Gi
$ dmidecode -t memory | grep -i size
Size: 480 GB
$ numactl --hardware
available: 10 nodes (0-9)
node 0 cpus: 0-71
node 0 size: 8080 MB
node 2 size: 97280 MB
```
👉 [Docker Hub](https://hub.docker.com/repository/docker/rajesh550/gh200-vllm/general)
Key observations:
- **nvidia-smi** reports 97,871 MiB (~96 GiB) — this is HBM only, NOT the EGM
- **OS** sees only ~102 GiB system RAM — EGM carved out 378 GiB for the GPU
- **dmidecode** confirms the physical DIMM is 480 GiB
- **NUMA** node 2 has 97 GiB (the LPDDR that wasn't handed to EGM), node 0 has 8 GiB (local)
- The EGM memory appears as `System RAM (NVIDIA)` in `/proc/iomem` at addresses `0x400000000000+`
Version info:
## Architecture
### Managed Memory Allocator
The approach uses a PyTorch pluggable allocator (`managed_alloc.cu`) that replaces `cudaMalloc` with `cudaMallocManaged`, enabling transparent page-fault access to both HBM and EGM.
```
┌─────────────────────────────────────────────────┐
│ vLLM │
│ (sees full managed memory) │
├─────────────────────────────────────────────────┤
│ PyTorch Pluggable Allocator │
│ (managed_alloc.cu / .so) │
│ cudaMallocManaged → unified memory pool │
├─────────────────────────────────────────────────┤
│ CUDA Unified Memory │
├──────────────────┬──────────────────────────────┤
│ HBM (~96 GiB) │ EGM (~378 GiB) │
│ Fast / local │ Page-fault over C2C NVLink │
│ │ 900 GB/s bandwidth │
└──────────────────┴──────────────────────────────┘
```
### Launcher
`vllm_managed_mem.py` is the entry point that:
1. Loads the managed allocator `.so` before any CUDA operations
2. Swaps PyTorch's default allocator to `cudaMallocManaged`
3. Patches vLLM's memory validation to understand the larger managed memory space
## Source Repos
| Repo | URL | Branch | Purpose |
|------|-----|--------|---------|
| vLLM (our fork) | `https://sweetapi.com/biondizzle/vllm.git` | `cmm` | vLLM with cudaMallocManaged patches |
| grace-gpu-containers | `https://sweetapi.com/biondizzle/grace-gpu-containers.git` | `cuda-malloc-managed` | Dockerfile + build pipeline |
The `cmm` branch in our vLLM fork is based on tag `v0.19.0` (commit `2a69949bd`).
## Build Pipeline
### Jenkins: `gh200-vllm-build-cmm`
The build chain:
```
Jenkins → grace-gpu-containers (cuda-malloc-managed branch)
→ Dockerfile clones vLLM from Gitea fork (cmm branch)
→ Builds ARM64 image via buildx on remote GH200
→ Pushes to atl.vultrcr.com/vllm
```
**Default parameters:**
- `VLLM_VERSION`: `cmm` (our fork branch)
- `IMAGE_TAG`: `gh200-vllm-cmm`
- `IMAGE_SUFFIX`: `-cmm`
**Image:** `atl.vultrcr.com/vllm/gh200-vllm-cmm:cmm-cmm`
### Dockerfile Build Stages
| Stage | What it builds | Source |
|-------|---------------|--------|
| build-triton | Triton 3.6.0 | PyPI wheel (aarch64) |
| build-triton-kernels | triton_kernels v3.6.0 | Triton repo |
| build-flashinfer | flashinfer v0.6.6 | Source (apache-tvm-ffi required) |
| build-lmcache | LMCache dev | Source |
| build-flash-attention | FlashAttention hopper | Source |
| build-vllm | vLLM cmm branch | Our Gitea fork |
| build-infinistore | InfiniStore | Source |
**Base image:** `nvcr.io/nvidia/pytorch:26.03-py3`
- PyTorch 2.11.0a0, CUDA 13.2.0
- Multi-arch: x86 + ARM SBSA (GH200)
- Target: `9.0a` (Hopper)
## Key Files
| File | Description |
|------|-------------|
| `vllm/Dockerfile` | Multi-stage build for the CMM container |
| `vllm/managed_alloc.cu` | PyTorch pluggable allocator using `cudaMallocManaged` |
| `vllm/vllm_managed_mem.py` | Launcher that patches vLLM for managed memory |
| `lmcache/Dockerfile` | Standalone LMCache build |
## Local Development
```bash
CUDA: 13.0.1
Ubuntu: 24.04
Python: 3.12
PyTorch: 2.9.0+cu130
Triton: 3.5.x
xformers: 0.32.post2+
flashinfer: 0.4.1
flashattention: 3.0.0b1
LMCache: 0.3.7
vLLM: 0.11.1rc3
```
# vLLM fork (working directory)
cd /home/openclaw/dev/vllm
git checkout cmm # our working branch, based on v0.19.0
# grace-gpu-containers
cd /home/openclaw/dev/grace-gpu-containers
git checkout cuda-malloc-managed # CMM Dockerfile lives here
```
## Hard Rules
1. **No downgrades** — CUDA 13+, PyTorch 2.9+, vLLM 0.18.1+
2. **No skipping compilation** — build from source
3. **Clear all changes with Mike before making them** — no autonomous commits
4. **One build at a time** — Mike reports failure → assess → report → Mike decides

View File

@@ -249,15 +249,29 @@ RUN pip uninstall -y pynvml && pip install nvidia-ml-py
# access to both HBM and LPDDR (EGM) memory on GH200 systems.
#
# The managed_alloc.cu provides a PyTorch pluggable allocator that uses
# cudaMallocManaged instead of cudaMalloc. vllm_managed_mem.py is a
# launcher that swaps the allocator before any CUDA operations and patches
# vLLM's memory validation to understand the larger managed memory space.
# cudaMallocManaged instead of cudaMalloc. Key features:
# - cudaMemAdviseSetPreferredLocation(GPU): keep pages on GPU
# - cudaMemAdviseSetAccessedBy(CPU): CPU reads over C2C NVLink without
# migrating pages back to system RAM (prevents OOM on EGM systems)
# - cudaMemPrefetchAsync(GPU): actively migrates pages to GPU immediately,
# so model weight writes go to HBM/EGM, not system RAM
#
# vllm_managed_mem.py is the launcher that swaps the allocator before any
# CUDA operations and patches vLLM's memory validation to understand the
# larger managed memory space.
#
# sitecustomize.py is auto-loaded by Python in ALL subprocesses (including
# vLLM's EngineCore). It applies the allocator swap and torch.cuda patches
# before any CUDA operations in spawned processes.
# ==============================================================================
COPY managed_alloc.cu /tmp/managed_alloc.cu
RUN nvcc -shared -o /usr/local/lib/libmanaged_alloc.so \
/tmp/managed_alloc.cu -Xcompiler -fPIC && rm /tmp/managed_alloc.cu
COPY vllm_managed_mem.py /usr/local/bin/vllm_managed_mem.py
RUN chmod +x /usr/local/bin/vllm_managed_mem.py
# sitecustomize.py is auto-loaded by Python before any other imports.
# This ensures CMM patches apply in ALL subprocesses (EngineCore, etc.)
COPY sitecustomize.py /usr/local/lib/python3.12/dist-packages/sitecustomize.py
# API server entrypoint
# ENTRYPOINT ["vllm", "serve"]

View File

@@ -1,6 +1,15 @@
// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
// Compatible with CUDA 13+ (uses cudaMemLocation API)
//
// Key design decisions for GH200 EGM:
// 1. cudaMallocManaged → allocations can page-fault across HBM + EGM
// 2. cudaMemAdviseSetPreferredLocation(GPU) → driver prefers keeping pages on GPU
// 3. cudaMemAdviseSetAccessedBy(CPU) → CPU can access over C2C NVLink without
// triggering page migration back to system RAM (critical: prevents OOM)
// 4. cudaMemPrefetchAsync(GPU) → actively migrates pages to GPU immediately,
// so subsequent writes go to HBM/EGM, not system RAM (prevents OOM on
// systems where EGM carved out most of system memory)
#include <cuda_runtime.h>
#include <stdio.h>
@@ -28,13 +37,44 @@ void* managed_malloc(size_t size, int device, cudaStream_t stream) {
return nullptr;
}
// Advise the driver to prefer GPU placement initially.
// On GH200 with EGM, the hardware will migrate pages as needed.
// CUDA 13+ uses cudaMemLocation struct instead of int for device
cudaMemLocation location;
location.type = cudaMemLocationTypeDevice;
location.id = device;
cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, location);
cudaMemLocation gpu_loc;
gpu_loc.type = cudaMemLocationTypeDevice;
gpu_loc.id = device;
// Advise: prefer GPU placement. On GH200 with EGM, the hardware will
// migrate pages as needed, but the driver tries to keep them on GPU.
cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, gpu_loc);
// Advise: CPU will access this memory too. On GH200, this sets up
// remote mapping over C2C NVLink so CPU can read/write without
// triggering page migration back to system RAM. This is CRITICAL
// to prevent OOM on EGM systems where most system RAM was carved
// out for the GPU.
cudaMemLocation cpu_loc;
cpu_loc.type = cudaMemLocationTypeHost;
cpu_loc.id = cudaCpuDeviceId;
cudaMemAdvise(ptr, size, cudaMemAdviseSetAccessedBy, cpu_loc);
// Prefetch to GPU immediately. This actively migrates the virtual
// pages to the GPU side so that subsequent writes (e.g., model weight
// loading) go directly to HBM/EGM instead of pinning system RAM.
// Without this, the first write to each page faults into system RAM,
// which causes OOM when the OS only has ~102 GiB after EGM carveout.
//
// The prefetch is asynchronous on the given stream, so it won't block
// the calling thread. Subsequent operations on the same stream will
// wait for the prefetch to complete.
if (size > 0) {
err = cudaMemPrefetchAsync(ptr, size, gpu_loc, stream);
if (err != cudaSuccess) {
// Non-fatal: prefetch failure shouldn't prevent allocation.
// Pages will still be migrated on demand.
fprintf(stderr, "[managed_alloc] cudaMemPrefetchAsync warning: %s "
"(size=%.2f GiB, will use on-demand migration)\n",
cudaGetErrorString(err), (double)size / (1024.0*1024.0*1024.0));
}
}
return ptr;
}
@@ -42,7 +82,8 @@ void* managed_malloc(size_t size, int device, cudaStream_t stream) {
// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
if (ptr != nullptr) {
// Sync the stream before freeing to avoid use-after-free
// Sync the stream before freeing to avoid use-after-free with
// managed memory (in-flight page faults can race with deallocation).
if (stream != nullptr) {
cudaStreamSynchronize(stream);
}

114
vllm/sitecustomize.py Normal file
View File

@@ -0,0 +1,114 @@
"""
sitecustomize.py - Auto-loaded by Python before any other imports.
Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator.
Also swaps the allocator to cudaMallocManaged and patches MemorySnapshot.
This MUST run before any torch.cuda calls in any subprocess.
Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode).
Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
"""
import os
import sys
# Only activate in CMM mode
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if _MANAGED_TOTAL:
import torch
# Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops
_lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so')
if os.path.exists(_lib_path):
try:
import ctypes
lib = ctypes.CDLL(_lib_path)
if hasattr(lib, 'managed_malloc') and hasattr(lib, 'managed_free'):
alloc = torch.cuda.memory.CUDAPluggableAllocator(
_lib_path, 'managed_malloc', 'managed_free'
)
torch.cuda.memory.change_current_allocator(alloc)
print(f"[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr)
except Exception as e:
print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr)
else:
print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr)
# Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support
_original_reset_peak = torch.cuda.reset_peak_memory_stats
_original_memory_stats = torch.cuda.memory_stats
_original_max_memory_allocated = torch.cuda.max_memory_allocated
def _patched_reset_peak_memory_stats(device=None):
"""No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
pass
def _patched_memory_stats(device=None):
"""Return minimal stats dict to avoid crashes."""
try:
return _original_memory_stats(device)
except RuntimeError:
return {
"allocated_bytes.all.current": 0,
"allocated_bytes.all.peak": 0,
"reserved_bytes.all.current": 0,
"reserved_bytes.all.peak": 0,
"num_alloc_retries": 0,
"num_ooms": 0,
}
def _patched_max_memory_allocated(device=None):
"""Return 0 since we can't track peak with pluggable allocator."""
try:
return _original_max_memory_allocated(device)
except RuntimeError:
return 0
torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
torch.cuda.memory_stats = _patched_memory_stats
torch.cuda.max_memory_allocated = _patched_max_memory_allocated
# Patch accelerator aliases (PyTorch 2.11+)
if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
if hasattr(torch.accelerator, 'memory_stats'):
torch.accelerator.memory_stats = _patched_memory_stats
if hasattr(torch.accelerator, 'max_memory_allocated'):
torch.accelerator.max_memory_allocated = _patched_max_memory_allocated
# Step 3: Patch MemorySnapshot.measure() to report managed memory
try:
from vllm.utils.mem_utils import MemorySnapshot
_original_measure = MemorySnapshot.measure
def _patched_measure(self):
_original_measure(self)
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
managed_total = int(float(managed_total_gb) * (1024**3))
self.total_memory = managed_total
self.free_memory = managed_total - self.cuda_memory
MemorySnapshot.measure = _patched_measure
except ImportError:
pass # vllm not loaded yet, will be patched by vllm_managed_mem.py
# Step 4: Patch request_memory to skip free-memory check for managed
try:
import vllm.v1.worker.utils as worker_utils
_original_request_memory = worker_utils.request_memory
def _patched_request_memory(init_snapshot, cache_config):
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
total_bytes = float(managed_total_gb) * (1024**3)
gpu_util = cache_config.gpu_memory_utilization
requested = int(total_bytes * gpu_util)
return requested
else:
return _original_request_memory(init_snapshot, cache_config)
worker_utils.request_memory = _patched_request_memory
except ImportError:
pass # vllm not loaded yet
print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)", file=sys.stderr)

View File

@@ -19,11 +19,20 @@ import os
import sys
import ctypes
def get_total_managed_memory_gb():
"""
Calculate total memory available via managed allocations on GH200.
Parses /proc/iomem to find NVIDIA EGM regions + HBM.
First checks MANAGED_MEMORY_TOTAL_GB env var (explicit override).
Then tries /proc/iomem (requires --privileged container).
Falls back to HBM only if neither works.
"""
# Explicit override takes priority
env_val = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if env_val:
return float(env_val)
# Try /proc/iomem for EGM regions (requires --privileged or host mount)
egm_bytes = 0
try:
with open('/proc/iomem', 'r') as f:
@@ -34,19 +43,26 @@ def get_total_managed_memory_gb():
start = int(start_s, 16)
end = int(end_s, 16)
egm_bytes += (end - start + 1)
except (PermissionError, FileNotFoundError):
except (PermissionError, FileNotFoundError, ValueError):
pass
egm_gb = egm_bytes / (1024**3)
# HBM is always there via normal cudaMalloc path
# cudaMallocManaged can span both HBM + EGM
return egm_gb
if egm_gb > 0:
# EGM detected via iomem, add HBM
import torch
_, hbm_total = torch.cuda.mem_get_info(0)
hbm_gb = hbm_total / (1024**3)
return egm_gb + hbm_gb
# No EGM detected - return 0 (caller should handle)
return 0
def swap_allocator():
"""
Replace PyTorch's default CUDA allocator with our managed memory allocator.
This MUST happen before any CUDA tensors are created.
If sitecustomize.py already swapped it, this is a no-op.
"""
lib_path = os.environ.get(
'MANAGED_ALLOC_LIB',
@@ -64,43 +80,125 @@ def swap_allocator():
lib = ctypes.CDLL(lib_path)
assert hasattr(lib, 'managed_malloc'), "managed_malloc symbol not found"
assert hasattr(lib, 'managed_free'), "managed_free symbol not found"
print(f"[managed_mem] Loaded allocator from {lib_path}", file=sys.stderr)
except Exception as e:
print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
sys.exit(1)
import torch
# This must happen before ANY cuda operation
alloc = torch.cuda.memory.CUDAPluggableAllocator(
lib_path, 'managed_malloc', 'managed_free'
)
torch.cuda.memory.change_current_allocator(alloc)
try:
alloc = torch.cuda.memory.CUDAPluggableAllocator(
lib_path, 'managed_malloc', 'managed_free'
)
torch.cuda.memory.change_current_allocator(alloc)
print(f"[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
except RuntimeError as e:
if "already initialized" in str(e):
print(f"[managed_mem] Allocator already initialized (sitecustomize)",
file=sys.stderr)
else:
raise
egm_gb = get_total_managed_memory_gb()
print(f"[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
print(f"[managed_mem] Detected ~{egm_gb:.0f} GiB EGM memory", file=sys.stderr)
print(f"[managed_mem] Total addressable: ~{egm_gb + 96:.0f} GiB "
f"(EGM + HBM)", file=sys.stderr)
def patch_memory_snapshot():
"""
Patch MemorySnapshot.measure() to report the full managed memory
instead of just HBM.
This is the core fix: cudaMemGetInfo only reports HBM (~96 GiB),
but with cudaMallocManaged we can address HBM + EGM (~474 GiB).
We override the snapshot so all downstream code (request_memory,
KV cache sizing, etc.) sees the full managed memory capacity.
"""
from vllm.utils.mem_utils import MemorySnapshot
_original_measure = MemorySnapshot.measure
def patched_measure(self):
_original_measure(self)
# Override total_memory with managed memory capacity
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
managed_total = int(float(managed_total_gb) * (1024**3))
self.total_memory = managed_total
# free_memory = total - what CUDA is using
self.free_memory = managed_total - self.cuda_memory
print(f"[managed_mem] MemorySnapshot patched: "
f"total={self.total_memory / (1024**3):.0f} GiB, "
f"free={self.free_memory / (1024**3):.0f} GiB",
file=sys.stderr)
MemorySnapshot.measure = patched_measure
print(f"[managed_mem] Patched MemorySnapshot.measure()", file=sys.stderr)
def patch_torch_memory_tracking():
"""
Patch PyTorch memory tracking functions that CUDAPluggableAllocator
doesn't support. If sitecustomize already patched them, this is a no-op.
"""
import torch
# Check if already patched by sitecustomize
if torch.cuda.reset_peak_memory_stats.__name__ == '_patched_reset_peak_memory_stats':
print(f"[managed_mem] torch.cuda already patched (sitecustomize)", file=sys.stderr)
return
_original_reset_peak = torch.cuda.reset_peak_memory_stats
_original_memory_stats = torch.cuda.memory_stats
_original_max_memory_allocated = torch.cuda.max_memory_allocated
def _patched_reset_peak_memory_stats(device=None):
"""No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
pass
def _patched_memory_stats(device=None):
"""Return minimal stats dict to avoid crashes."""
try:
return _original_memory_stats(device)
except RuntimeError:
return {
"allocated_bytes.all.current": 0,
"allocated_bytes.all.peak": 0,
"reserved_bytes.all.current": 0,
"reserved_bytes.all.peak": 0,
"num_alloc_retries": 0,
"num_ooms": 0,
}
def _patched_max_memory_allocated(device=None):
"""Return 0 since we can't track peak with pluggable allocator."""
try:
return _original_max_memory_allocated(device)
except RuntimeError:
return 0
torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
torch.cuda.memory_stats = _patched_memory_stats
torch.cuda.max_memory_allocated = _patched_max_memory_allocated
if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
if hasattr(torch.accelerator, 'memory_stats'):
torch.accelerator.memory_stats = _patched_memory_stats
if hasattr(torch.accelerator, 'max_memory_allocated'):
torch.accelerator.max_memory_allocated = _patched_max_memory_allocated
print(f"[managed_mem] Patched torch.cuda memory tracking (no-op stubs)",
file=sys.stderr)
def patch_vllm_memory_check():
"""
Monkey-patch vLLM's memory validation to understand managed memory.
vLLM checks cudaMemGetInfo and rejects startup if free < requested.
With managed memory, the real capacity is much larger than what
cudaMemGetInfo reports (it only sees HBM + a tiny EGM sliver).
With managed memory, cudaMemGetInfo only reports HBM, so the free-memory
check would fail. When MANAGED_MEMORY_TOTAL_GB is set, we skip that check.
"""
import vllm.v1.worker.utils as worker_utils
_original_request_memory = worker_utils.request_memory
def patched_request_memory(init_snapshot, cache_config):
"""
Override memory validation:
- Read the MANAGED_MEMORY_TOTAL_GB env var (set by us)
- If set, skip the free-memory check and just return the requested size
"""
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
total_bytes = float(managed_total_gb) * (1024**3)
@@ -114,40 +212,7 @@ def patch_vllm_memory_check():
return _original_request_memory(init_snapshot, cache_config)
worker_utils.request_memory = patched_request_memory
print(f"[managed_mem] Patched vLLM memory check", file=sys.stderr)
def patch_vllm_memory_snapshot():
"""
Patch the memory snapshot to report managed-aware totals.
Without this, vLLM's worker sees ~95 GiB total and panics.
"""
import vllm.v1.worker.gpu_worker as gpu_worker
_original_init_device = gpu_worker.GpuWorker.init_device
def patched_init_device(self):
"""Patch init_device to override memory reporting."""
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
# Set a permissive gpu_memory_utilization on the cache_config
# so the original code doesn't reject us
original_util = self.cache_config.gpu_memory_utilization
self.cache_config.gpu_memory_utilization = 0.5 # always passes
_original_init_device(self)
# Now fix up the requested_memory with the real value
total_bytes = float(managed_total_gb) * (1024**3)
self.requested_memory = int(total_bytes * original_util)
self.cache_config.gpu_memory_utilization = original_util
print(f"[managed_mem] Worker memory overridden to "
f"{self.requested_memory / (1024**3):.1f} GiB", file=sys.stderr)
else:
_original_init_device(self)
gpu_worker.GpuWorker.init_device = patched_init_device
print(f"[managed_mem] Patched vLLM worker init_device", file=sys.stderr)
print(f"[managed_mem] Patched vLLM request_memory", file=sys.stderr)
def main():
@@ -155,17 +220,28 @@ def main():
swap_allocator()
# Step 2: Calculate total managed memory and export it
egm_gb = get_total_managed_memory_gb()
total_managed_gb = egm_gb + 96 # EGM + HBM
total_managed_gb = get_total_managed_memory_gb()
if total_managed_gb <= 0:
print(f"[managed_mem] WARNING: No managed memory detected! "
f"Set MANAGED_MEMORY_TOTAL_GB env var explicitly.",
file=sys.stderr)
sys.exit(1)
os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(total_managed_gb)
print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
file=sys.stderr)
# Step 3: Patch vLLM memory checks
patch_vllm_memory_check()
patch_vllm_memory_snapshot()
# Step 3: Patch PyTorch memory tracking (pluggable allocator doesn't support all ops)
patch_torch_memory_tracking()
# Step 4: Launch vLLM's API server with remaining args
# Step 4: Patch MemorySnapshot.measure() to report full managed memory
# This is critical - without it, all downstream code only sees HBM
patch_memory_snapshot()
# Step 5: Patch request_memory as a safety net
patch_vllm_memory_check()
# Step 6: Launch vLLM's API server with remaining args
sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
file=sys.stderr)