- managed_alloc.cu: PyTorch pluggable allocator using cudaMallocManaged
- vllm_managed_mem.py: Launcher that patches vLLM for managed memory
- Dockerfile: Build and install managed memory components

This enables vLLM to use cudaMallocManaged for transparent page-fault access to both HBM (~96 GiB) and LPDDR (EGM, up to 480 GiB additional) on GH200 systems with Extended GPU Memory (EGM) enabled.

Experimental branch: v0.19.0-cmm
191 lines
6.8 KiB
Python
191 lines
6.8 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
vllm_managed_mem.py - Launch vLLM with cudaMallocManaged allocator
|
||
|
||
This MUST be the very first thing that runs before any torch.cuda calls.
|
||
It swaps PyTorch's CUDA allocator to use cudaMallocManaged, enabling
|
||
transparent page-fault access to EGM memory on GH200.
|
||
|
||
Usage:
|
||
python vllm_managed_mem.py [all normal vllm serve arguments]
|
||
|
||
Example:
|
||
python vllm_managed_mem.py --model google/gemma-4-31B-it \
|
||
--host 0.0.0.0 --port 80 --gpu-memory-utilization 0.90 \
|
||
--enforce-eager --max-model-len 32768
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import ctypes
|
||
|
||
def get_total_managed_memory_gb(iomem_path='/proc/iomem'):
    """
    Return the size, in GiB, of the NVIDIA EGM regions listed in iomem.

    Sums every "System RAM (NVIDIA)" address range found in *iomem_path*.
    Despite the function name, HBM is NOT included in the result; callers
    add the HBM capacity themselves.

    Args:
        iomem_path: Path of the iomem listing to parse. Defaults to
            '/proc/iomem'; overridable mainly for testing.

    Returns:
        The EGM size in GiB as a float. Returns 0.0 if the file cannot be
        opened or read. Malformed entries are skipped.

    Note:
        On many kernels, non-root readers of /proc/iomem see all addresses
        zeroed, in which case this reports ~0 GiB.
    """
    egm_bytes = 0
    try:
        # errors='replace' keeps odd bytes in resource names from raising
        # UnicodeDecodeError and aborting detection.
        with open(iomem_path, 'r', encoding='utf-8', errors='replace') as f:
            for line in f:
                if 'System RAM (NVIDIA)' not in line:
                    continue
                # Entry format: "<start>-<end> : <name>"; child entries are
                # indented, which strip() removes.
                addr_range = line.split(':', 1)[0].strip()
                start_s, sep, end_s = addr_range.partition('-')
                if not sep:
                    continue
                try:
                    start = int(start_s, 16)
                    end = int(end_s, 16)
                except ValueError:
                    # Malformed entry: skip it rather than abort detection.
                    continue
                if end >= start:
                    # iomem ranges are inclusive at both ends.
                    egm_bytes += end - start + 1
    except OSError:
        # Best effort: a missing or unreadable file means "no EGM detected".
        pass

    return egm_bytes / (1024**3)
|
||
|
||
|
||
def swap_allocator():
    """
    Replace PyTorch's default CUDA allocator with our managed memory allocator.

    This MUST happen before any CUDA tensors are created.

    The allocator library is taken from $MANAGED_ALLOC_LIB (default
    /usr/local/lib/libmanaged_alloc.so) and must export the symbols
    `managed_malloc` and `managed_free`. If the file is missing, cannot be
    loaded, or lacks either symbol, an error is printed to stderr and the
    process exits with status 1.

    $MANAGED_HBM_GB (default 96) is the HBM size in GiB. It is used here only
    for the "Total addressable" log line.
    """
    lib_path = os.environ.get(
        'MANAGED_ALLOC_LIB',
        '/usr/local/lib/libmanaged_alloc.so'
    )

    if not os.path.exists(lib_path):
        print(f"[managed_mem] ERROR: {lib_path} not found!", file=sys.stderr)
        print(f"[managed_mem] Build it with: nvcc -shared -o {lib_path} "
              f"managed_alloc.cu -Xcompiler -fPIC", file=sys.stderr)
        sys.exit(1)

    # Verify the library loads and exports the expected symbols. These are
    # explicit checks, not `assert`, because `python -O` strips asserts.
    try:
        lib = ctypes.CDLL(lib_path)
    except OSError as e:
        print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
        sys.exit(1)
    for symbol in ('managed_malloc', 'managed_free'):
        # hasattr() on a CDLL performs the symbol lookup; it is False when
        # the symbol is not exported.
        if not hasattr(lib, symbol):
            print(f"[managed_mem] ERROR loading {lib_path}: "
                  f"{symbol} symbol not found", file=sys.stderr)
            sys.exit(1)
    print(f"[managed_mem] Loaded allocator from {lib_path}", file=sys.stderr)

    import torch

    # This must happen before ANY cuda operation
    alloc = torch.cuda.memory.CUDAPluggableAllocator(
        lib_path, 'managed_malloc', 'managed_free'
    )
    torch.cuda.memory.change_current_allocator(alloc)

    egm_gb = get_total_managed_memory_gb()
    hbm_gb = float(os.environ.get('MANAGED_HBM_GB', '96'))
    print("[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
    print(f"[managed_mem] Detected ~{egm_gb:.0f} GiB EGM memory", file=sys.stderr)
    print(f"[managed_mem] Total addressable: ~{egm_gb + hbm_gb:.0f} GiB "
          "(EGM + HBM)", file=sys.stderr)
|
||
|
||
|
||
def patch_vllm_memory_check():
    """
    Monkey-patch vLLM's memory validation to understand managed memory.

    vLLM checks cudaMemGetInfo and rejects startup if free < requested.
    With managed memory, the real capacity is much larger than what
    cudaMemGetInfo reports (it only sees HBM + a tiny EGM sliver).

    The patched `request_memory(init_snapshot, cache_config)` works as
    follows:
      * If $MANAGED_MEMORY_TOTAL_GB is set, it skips the free-memory check.
        It returns int(MANAGED_MEMORY_TOTAL_GB GiB *
        cache_config.gpu_memory_utilization) bytes.
      * Otherwise it calls vLLM's original implementation.
    """
    import functools

    import vllm.v1.worker.utils as worker_utils

    _original_request_memory = worker_utils.request_memory

    @functools.wraps(_original_request_memory)
    def patched_request_memory(init_snapshot, cache_config):
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if not managed_total_gb:
            return _original_request_memory(init_snapshot, cache_config)

        total_bytes = float(managed_total_gb) * (1024**3)
        gpu_util = cache_config.gpu_memory_utilization
        requested = int(total_bytes * gpu_util)
        print(f"[managed_mem] Overriding memory request: "
              f"{float(managed_total_gb):.0f} GiB × {gpu_util} = "
              f"{requested / (1024**3):.1f} GiB", file=sys.stderr)
        return requested

    worker_utils.request_memory = patched_request_memory

    # Rebinding the attribute on `worker_utils` does not affect modules that
    # already did `from vllm.v1.worker.utils import request_memory`. If the
    # GPU worker module holds such a binding, rebind it there too.
    # NOTE(review): confirm which binding the vLLM version in use calls.
    try:
        import vllm.v1.worker.gpu_worker as gpu_worker
    except ImportError:
        gpu_worker = None
    if (gpu_worker is not None and
            getattr(gpu_worker, 'request_memory', None)
            is _original_request_memory):
        gpu_worker.request_memory = patched_request_memory

    print("[managed_mem] Patched vLLM memory check", file=sys.stderr)
|
||
|
||
|
||
def patch_vllm_memory_snapshot():
    """
    Patch the memory snapshot to report managed-aware totals.

    Without this, vLLM's worker sees ~95 GiB total and panics.

    The patched `GpuWorker.init_device` behaves as follows:
      * If $MANAGED_MEMORY_TOTAL_GB is unset, it runs the original unchanged.
      * If it is set, the original runs with
        cache_config.gpu_memory_utilization temporarily lowered to at most
        0.5, so vLLM's HBM-only free-memory check passes.
      * The user's utilization is then restored, even if init_device raises.
      * self.requested_memory is set to int(MANAGED_MEMORY_TOTAL_GB GiB *
        that utilization).
    """
    import functools

    import vllm.v1.worker.gpu_worker as gpu_worker

    _original_init_device = gpu_worker.GpuWorker.init_device

    @functools.wraps(_original_init_device)
    def patched_init_device(self):
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if not managed_total_gb:
            return _original_init_device(self)

        original_util = self.cache_config.gpu_memory_utilization
        # A low utilization lets the original check pass; min() ensures we
        # never make the check stricter than the user's own setting.
        self.cache_config.gpu_memory_utilization = min(original_util, 0.5)
        try:
            _original_init_device(self)
        finally:
            # Restore even on failure so the shared cache_config isn't left
            # with the temporary value.
            self.cache_config.gpu_memory_utilization = original_util

        # Now fix up the requested_memory with the real value
        total_bytes = float(managed_total_gb) * (1024**3)
        self.requested_memory = int(total_bytes * original_util)
        print(f"[managed_mem] Worker memory overridden to "
              f"{self.requested_memory / (1024**3):.1f} GiB", file=sys.stderr)
        return None

    gpu_worker.GpuWorker.init_device = patched_init_device
    print("[managed_mem] Patched vLLM worker init_device", file=sys.stderr)
|
||
|
||
|
||
def main():
    """
    Entry point: swap the allocator, patch vLLM, then run its OpenAI server.

    Order matters:
      * The allocator swap must precede any CUDA operation.
      * The vLLM patches must be installed before the server starts workers.

    All command-line arguments after the script name are forwarded to vLLM's
    API-server argument parser.

    Environment:
        MANAGED_HBM_GB: HBM capacity in GiB added to the detected EGM size
            (default 96; set it for GH200 variants with more HBM, e.g.
            144 GB HBM3e parts).
        MANAGED_MEMORY_TOTAL_GB: written by this function (EGM + HBM, GiB)
            for the patched vLLM memory checks to read.
    """
    # Step 1: Swap allocator BEFORE any CUDA ops
    swap_allocator()

    # Step 2: Calculate total managed memory and export it
    egm_gb = get_total_managed_memory_gb()
    hbm_gb = float(os.environ.get('MANAGED_HBM_GB', '96'))
    total_managed_gb = egm_gb + hbm_gb  # EGM + HBM
    os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(total_managed_gb)
    print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
          file=sys.stderr)

    # Step 3: Patch vLLM memory checks
    # NOTE(review): the allocator swap and these monkey-patches live in this
    # process; they reach vLLM engine/worker processes only if those are
    # forked, not spawned. Confirm the multiprocessing method in use.
    patch_vllm_memory_check()
    patch_vllm_memory_snapshot()

    # Step 4: Launch vLLM's API server with remaining args
    sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
    print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
          file=sys.stderr)

    import uvloop
    from vllm.entrypoints.openai import api_server

    # Use vLLM's own argument parser
    parser = api_server.FlexibleArgumentParser(
        description="vLLM OpenAI-compatible API server (managed memory)"
    )
    parser = api_server.make_arg_parser(parser)
    args = parser.parse_args(sys.argv[1:])

    # Mirror upstream's entry point, which validates parsed serve args before
    # running; looked up dynamically in case this vLLM version lacks it.
    validate_args = getattr(api_server, 'validate_parsed_serve_args', None)
    if validate_args is not None:
        validate_args(args)

    uvloop.run(api_server.run_server(args))
|
||
|
||
|
||
if __name__ == "__main__":
    # Launch only when executed as a script, never on import.
    main()
|