#!/usr/bin/env python3
"""
vllm_managed_mem.py - Launch vLLM with cudaMallocManaged allocator

This MUST be the very first thing that runs before any torch.cuda calls.
It swaps PyTorch's CUDA allocator to use cudaMallocManaged, enabling
transparent page-fault access to EGM memory on GH200.

Usage:
    python vllm_managed_mem.py [all normal vllm serve arguments]

Example:
    python vllm_managed_mem.py --model google/gemma-4-31B-it \
        --host 0.0.0.0 --port 80 --gpu-memory-utilization 0.90 \
        --enforce-eager --max-model-len 32768
"""

import os
import sys
import ctypes
def get_total_managed_memory_gb():
    """
    Calculate total memory available via managed allocations on GH200.

    First checks MANAGED_MEMORY_TOTAL_GB env var (explicit override).
    Then tries /proc/iomem (requires --privileged container).
    Falls back to HBM only if neither works.

    Returns:
        float: total managed capacity in GiB, or 0.0 when no EGM region
        is detected and no override is set (caller should handle).
    """
    # Explicit override takes priority.
    env_val = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
    if env_val:
        return float(env_val)

    # Try /proc/iomem for EGM regions (requires --privileged or host mount).
    # Matching lines look like: "10000000000-1ffffffffff : System RAM (NVIDIA)".
    egm_bytes = 0
    try:
        with open('/proc/iomem', 'r') as f:
            for line in f:
                if 'System RAM (NVIDIA)' not in line:
                    continue
                try:
                    addr_range = line.strip().split(':')[0].strip()
                    start_s, end_s = addr_range.split('-')
                    egm_bytes += int(end_s, 16) - int(start_s, 16) + 1
                except ValueError:
                    # One malformed line must not discard regions already
                    # accumulated from well-formed lines.
                    continue
    except OSError:
        # Unreadable (unprivileged container) or absent (non-Linux).
        # OSError covers PermissionError and FileNotFoundError.
        pass

    egm_gb = egm_bytes / (1024**3)
    if egm_gb > 0:
        # EGM detected via iomem: add the GPU's HBM capacity on top.
        # torch is imported lazily so merely importing this module never
        # initializes CUDA.
        import torch
        _, hbm_total = torch.cuda.mem_get_info(0)
        hbm_gb = hbm_total / (1024**3)
        return egm_gb + hbm_gb

    # No EGM detected - return 0.0 (caller should handle).
    return 0.0
def swap_allocator():
    """
    Replace PyTorch's default CUDA allocator with our managed memory allocator.

    This MUST happen before any CUDA tensors are created.
    If sitecustomize.py already swapped it, this is a no-op.

    Exits the process with status 1 when the allocator library is missing
    or fails to load / expose the expected symbols.
    """
    lib_path = os.environ.get(
        'MANAGED_ALLOC_LIB',
        '/usr/local/lib/libmanaged_alloc.so'
    )

    if not os.path.exists(lib_path):
        print(f"[managed_mem] ERROR: {lib_path} not found!", file=sys.stderr)
        print(f"[managed_mem] Build it with: nvcc -shared -o {lib_path} "
              f"managed_alloc.cu -Xcompiler -fPIC", file=sys.stderr)
        sys.exit(1)

    # Verify the library loads and exposes the malloc/free entry points
    # before handing it to PyTorch.
    try:
        lib = ctypes.CDLL(lib_path)
        assert hasattr(lib, 'managed_malloc'), "managed_malloc symbol not found"
        assert hasattr(lib, 'managed_free'), "managed_free symbol not found"
    except Exception as e:
        print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
        sys.exit(1)

    # torch is imported lazily: this module must not touch CUDA before
    # the swap happens.
    import torch

    try:
        alloc = torch.cuda.memory.CUDAPluggableAllocator(
            lib_path, 'managed_malloc', 'managed_free'
        )
        torch.cuda.memory.change_current_allocator(alloc)
        print("[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
    except RuntimeError as e:
        if "already initialized" in str(e):
            # sitecustomize.py ran first and already performed the swap.
            print("[managed_mem] Allocator already initialized (sitecustomize)",
                  file=sys.stderr)
        else:
            raise
def patch_memory_snapshot():
    """
    Patch MemorySnapshot.measure() to report the full managed memory
    instead of just HBM.

    This is the core fix: cudaMemGetInfo only reports HBM (~96 GiB),
    but with cudaMallocManaged we can address HBM + EGM (~474 GiB).
    We override the snapshot so all downstream code (request_memory,
    KV cache sizing, etc.) sees the full managed memory capacity.
    """
    from vllm.utils.mem_utils import MemorySnapshot

    real_measure = MemorySnapshot.measure

    def measure_with_managed_total(self):
        # Take the genuine measurement first, then widen the totals.
        real_measure(self)
        override_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if override_gb:
            # Override total_memory with the managed memory capacity.
            widened_total = int(float(override_gb) * (1024**3))
            self.total_memory = widened_total
            # free_memory = total - what CUDA is using
            self.free_memory = widened_total - self.cuda_memory
            print(f"[managed_mem] MemorySnapshot patched: "
                  f"total={self.total_memory / (1024**3):.0f} GiB, "
                  f"free={self.free_memory / (1024**3):.0f} GiB",
                  file=sys.stderr)

    MemorySnapshot.measure = measure_with_managed_total
    print("[managed_mem] Patched MemorySnapshot.measure()", file=sys.stderr)
def patch_torch_memory_tracking():
    """
    Patch PyTorch memory tracking functions that CUDAPluggableAllocator
    doesn't support. If sitecustomize already patched them, this is a no-op.
    """
    import torch

    # Check if already patched by sitecustomize: its stubs are recognized
    # by their function __name__.
    if torch.cuda.reset_peak_memory_stats.__name__ == '_patched_reset_peak_memory_stats':
        print("[managed_mem] torch.cuda already patched (sitecustomize)", file=sys.stderr)
        return

    # Keep the originals so the stubs can delegate when tracking works.
    _original_memory_stats = torch.cuda.memory_stats
    _original_max_memory_allocated = torch.cuda.max_memory_allocated

    def _patched_reset_peak_memory_stats(device=None):
        """No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
        pass

    def _patched_memory_stats(device=None):
        """Return minimal stats dict to avoid crashes."""
        try:
            return _original_memory_stats(device)
        except RuntimeError:
            return {
                "allocated_bytes.all.current": 0,
                "allocated_bytes.all.peak": 0,
                "reserved_bytes.all.current": 0,
                "reserved_bytes.all.peak": 0,
                "num_alloc_retries": 0,
                "num_ooms": 0,
            }

    def _patched_max_memory_allocated(device=None):
        """Return 0 since we can't track peak with pluggable allocator."""
        try:
            return _original_max_memory_allocated(device)
        except RuntimeError:
            return 0

    torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
    torch.cuda.memory_stats = _patched_memory_stats
    torch.cuda.max_memory_allocated = _patched_max_memory_allocated

    # torch.accelerator only exists on newer PyTorch; guard the attribute
    # itself so this doesn't raise AttributeError on older builds.
    accelerator = getattr(torch, 'accelerator', None)
    if accelerator is not None:
        if hasattr(accelerator, 'reset_peak_memory_stats'):
            accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
        if hasattr(accelerator, 'memory_stats'):
            accelerator.memory_stats = _patched_memory_stats
        if hasattr(accelerator, 'max_memory_allocated'):
            accelerator.max_memory_allocated = _patched_max_memory_allocated

    print("[managed_mem] Patched torch.cuda memory tracking (no-op stubs)",
          file=sys.stderr)
def patch_vllm_memory_check():
    """
    Monkey-patch vLLM's memory validation to understand managed memory.

    With managed memory, cudaMemGetInfo only reports HBM, so the free-memory
    check would fail. When MANAGED_MEMORY_TOTAL_GB is set, we skip that check.
    """
    import vllm.v1.worker.utils as worker_utils

    fallback_request_memory = worker_utils.request_memory

    def request_memory_managed(init_snapshot, cache_config):
        override_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if not override_gb:
            # No override configured: defer to vLLM's stock check.
            return fallback_request_memory(init_snapshot, cache_config)
        total_bytes = float(override_gb) * (1024**3)
        gpu_util = cache_config.gpu_memory_utilization
        requested = int(total_bytes * gpu_util)
        print(f"[managed_mem] Overriding memory request: "
              f"{float(override_gb):.0f} GiB × {gpu_util} = "
              f"{requested / (1024**3):.1f} GiB", file=sys.stderr)
        return requested

    worker_utils.request_memory = request_memory_managed
    print("[managed_mem] Patched vLLM request_memory", file=sys.stderr)
def main():
    """
    Configure managed-memory reporting, patch vLLM, and launch the
    OpenAI-compatible API server with the remaining CLI arguments.
    """
    # Step 1: NO global allocator swap — model weights stay in HBM.
    # KV cache uses cudaMallocManaged directly via
    # VLLM_KV_CACHE_USE_MANAGED_MEMORY env var (set by sitecustomize.py).
    # The global allocator swap broke cuBLAS GEMM operations because
    # intermediate compute tensors ended up in managed memory.
    print("[managed_mem] Using targeted KV cache managed allocation "
          "(no global allocator swap)", file=sys.stderr)

    # Step 2: Calculate total managed memory and export it.
    managed_gb = get_total_managed_memory_gb()
    if managed_gb <= 0:
        print("[managed_mem] WARNING: No managed memory detected! "
              "Set MANAGED_MEMORY_TOTAL_GB env var explicitly.",
              file=sys.stderr)
        sys.exit(1)

    os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(managed_gb)
    print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={managed_gb:.0f}",
          file=sys.stderr)

    # Step 3: No torch.cuda memory tracking patches needed —
    # we're not using CUDAPluggableAllocator anymore.

    # Step 4: Patch MemorySnapshot.measure() to report full managed memory.
    # This is critical - without it, all downstream code only sees HBM.
    patch_memory_snapshot()

    # Step 5: Patch request_memory as a safety net.
    patch_vllm_memory_check()

    # Step 6: Launch vLLM's API server with the remaining args.
    sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
    print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
          file=sys.stderr)

    # Import and run, reusing vLLM's own argument parser so every normal
    # `vllm serve` flag is accepted unchanged.
    from vllm.entrypoints.openai.api_server import (
        run_server,
        FlexibleArgumentParser,
        make_arg_parser,
    )
    import uvloop

    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-compatible API server (managed memory)"
    )
    parser = make_arg_parser(parser)
    args = parser.parse_args(sys.argv[1:])

    uvloop.run(run_server(args))
# Script entry point: apply patches and launch the vLLM API server.
if __name__ == '__main__':
    main()