grace-gpu-containers/vllm/vllm_managed_mem.py
biondizzle 2757bffcb6 Add cudaMallocManaged allocator for GH200 EGM support
- managed_alloc.cu: PyTorch pluggable allocator using cudaMallocManaged
- vllm_managed_mem.py: Launcher that patches vLLM for managed memory
- Dockerfile: Build and install managed memory components

This enables vLLM to use cudaMallocManaged for transparent page-fault
access to both HBM (~96 GiB) and LPDDR (EGM, up to 480 GiB additional)
on GH200 systems with Extended GPU Memory enabled.

Experimental branch: v0.19.0-cmm
2026-04-07 21:19:39 +00:00
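
managed_alloc.cu itself is not shown on this page. As a rough sketch of what the launcher below expects from it (the symbol names come from the ctypes checks in this file; the signatures are assumed from PyTorch's CUDAPluggableAllocator documentation, not taken from the committed source):

// managed_alloc.cu (illustrative sketch, not the committed file)
// Build (matches the hint printed by swap_allocator below):
//   nvcc -shared -o /usr/local/lib/libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
#include <sys/types.h>
#include <cuda_runtime_api.h>

extern "C" {

// Allocation entry point resolved by torch.cuda.memory.CUDAPluggableAllocator.
void* managed_malloc(ssize_t size, int device, cudaStream_t stream) {
    void* ptr = nullptr;
    // Managed allocations migrate on page fault between HBM and EGM.
    cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
    return ptr;
}

// Matching free entry point.
void managed_free(void* ptr, ssize_t size, int device, cudaStream_t stream) {
    cudaFree(ptr);
}

}  // extern "C"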


#!/usr/bin/env python3
"""
vllm_managed_mem.py - Launch vLLM with cudaMallocManaged allocator

This MUST be the very first thing that runs before any torch.cuda calls.
It swaps PyTorch's CUDA allocator to use cudaMallocManaged, enabling
transparent page-fault access to EGM memory on GH200.

Usage:
    python vllm_managed_mem.py [all normal vllm serve arguments]

Example:
    python vllm_managed_mem.py --model google/gemma-4-31B-it \
        --host 0.0.0.0 --port 80 --gpu-memory-utilization 0.90 \
        --enforce-eager --max-model-len 32768
"""
import os
import sys
import ctypes


def get_total_managed_memory_gb():
    """
    Calculate total memory available via managed allocations on GH200.

    Parses /proc/iomem to find NVIDIA EGM regions and returns their total
    size in GiB; HBM is added separately by the caller.
    """
    egm_bytes = 0
    try:
        with open('/proc/iomem', 'r') as f:
            for line in f:
                if 'System RAM (NVIDIA)' in line:
                    parts = line.strip().split(':')[0].strip()
                    start_s, end_s = parts.split('-')
                    start = int(start_s, 16)
                    end = int(end_s, 16)
                    egm_bytes += (end - start + 1)
    except (PermissionError, FileNotFoundError):
        pass
    egm_gb = egm_bytes / (1024**3)
    # HBM is always there via the normal cudaMalloc path;
    # cudaMallocManaged can span both HBM + EGM.
    return egm_gb


def swap_allocator():
    """
    Replace PyTorch's default CUDA allocator with our managed memory allocator.
    This MUST happen before any CUDA tensors are created.
    """
    lib_path = os.environ.get(
        'MANAGED_ALLOC_LIB',
        '/usr/local/lib/libmanaged_alloc.so'
    )
    if not os.path.exists(lib_path):
        print(f"[managed_mem] ERROR: {lib_path} not found!", file=sys.stderr)
        print(f"[managed_mem] Build it with: nvcc -shared -o {lib_path} "
              f"managed_alloc.cu -Xcompiler -fPIC", file=sys.stderr)
        sys.exit(1)
    # Verify the library loads
    try:
        lib = ctypes.CDLL(lib_path)
        assert hasattr(lib, 'managed_malloc'), "managed_malloc symbol not found"
        assert hasattr(lib, 'managed_free'), "managed_free symbol not found"
        print(f"[managed_mem] Loaded allocator from {lib_path}", file=sys.stderr)
    except Exception as e:
        print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
        sys.exit(1)
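    # CUDAPluggableAllocator resolves these two symbols by name. Per PyTorch's
    # pluggable-allocator contract the expected C signatures are (assumed from
    # the PyTorch docs, not read from managed_alloc.cu):
    #   void* managed_malloc(ssize_t size, int device, cudaStream_t stream)
    #   void  managed_free(void* ptr, ssize_t size, int device, cudaStream_t stream)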
    import torch
    # This must happen before ANY cuda operation
    alloc = torch.cuda.memory.CUDAPluggableAllocator(
        lib_path, 'managed_malloc', 'managed_free'
    )
    torch.cuda.memory.change_current_allocator(alloc)
    egm_gb = get_total_managed_memory_gb()
    print(f"[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
    print(f"[managed_mem] Detected ~{egm_gb:.0f} GiB EGM memory", file=sys.stderr)
    print(f"[managed_mem] Total addressable: ~{egm_gb + 96:.0f} GiB "
          f"(EGM + HBM)", file=sys.stderr)


def patch_vllm_memory_check():
    """
    Monkey-patch vLLM's memory validation to understand managed memory.

    vLLM checks cudaMemGetInfo and rejects startup if free < requested.
    With managed memory, the real capacity is much larger than what
    cudaMemGetInfo reports (it only sees HBM + a tiny EGM sliver).
    """
    import vllm.v1.worker.utils as worker_utils
    _original_request_memory = worker_utils.request_memory

    def patched_request_memory(init_snapshot, cache_config):
        """
        Override memory validation:
        - Read the MANAGED_MEMORY_TOTAL_GB env var (set by us)
        - If set, skip the free-memory check and just return the requested size
        """
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            total_bytes = float(managed_total_gb) * (1024**3)
            gpu_util = cache_config.gpu_memory_utilization
            requested = int(total_bytes * gpu_util)
            print(f"[managed_mem] Overriding memory request: "
                  f"{float(managed_total_gb):.0f} GiB × {gpu_util} = "
                  f"{requested / (1024**3):.1f} GiB", file=sys.stderr)
            return requested
        else:
            return _original_request_memory(init_snapshot, cache_config)

    worker_utils.request_memory = patched_request_memory
    print(f"[managed_mem] Patched vLLM memory check", file=sys.stderr)


def patch_vllm_memory_snapshot():
    """
    Patch the memory snapshot to report managed-aware totals.
    Without this, vLLM's worker sees ~95 GiB total and panics.
    """
    import vllm.v1.worker.gpu_worker as gpu_worker
    _original_init_device = gpu_worker.GpuWorker.init_device

    def patched_init_device(self):
        """Patch init_device to override memory reporting."""
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            # Set a permissive gpu_memory_utilization on the cache_config
            # so the original code doesn't reject us
            original_util = self.cache_config.gpu_memory_utilization
            self.cache_config.gpu_memory_utilization = 0.5  # always passes
            _original_init_device(self)
            # Now fix up the requested_memory with the real value
            total_bytes = float(managed_total_gb) * (1024**3)
            self.requested_memory = int(total_bytes * original_util)
            self.cache_config.gpu_memory_utilization = original_util
            print(f"[managed_mem] Worker memory overridden to "
                  f"{self.requested_memory / (1024**3):.1f} GiB", file=sys.stderr)
        else:
            _original_init_device(self)

    gpu_worker.GpuWorker.init_device = patched_init_device
    print(f"[managed_mem] Patched vLLM worker init_device", file=sys.stderr)


def main():
    # Step 1: Swap allocator BEFORE any CUDA ops
    swap_allocator()
    # Step 2: Calculate total managed memory and export it
    egm_gb = get_total_managed_memory_gb()
    total_managed_gb = egm_gb + 96  # EGM + HBM
    os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(total_managed_gb)
    print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
          file=sys.stderr)
    # Step 3: Patch vLLM memory checks
    patch_vllm_memory_check()
    patch_vllm_memory_snapshot()
    # Step 4: Launch vLLM's API server with remaining args
    sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
    print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
          file=sys.stderr)
    # Import and run
    from vllm.entrypoints.openai.api_server import run_server, FlexibleArgumentParser
    import uvloop
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-compatible API server (managed memory)"
    )
    # Use vLLM's own argument parser
    from vllm.entrypoints.openai.api_server import make_arg_parser
    parser = make_arg_parser(parser)
    args = parser.parse_args(sys.argv[1:])
    uvloop.run(run_server(args))


if __name__ == '__main__':
    main()