#!/usr/bin/env python3
"""
vllm_managed_mem.py - Launch vLLM with cudaMallocManaged allocator

This MUST be the very first thing that runs before any torch.cuda calls.
It swaps PyTorch's CUDA allocator to use cudaMallocManaged, enabling
transparent page-fault access to EGM memory on GH200.

Usage:
    python vllm_managed_mem.py [all normal vllm serve arguments]

Example:
    python vllm_managed_mem.py --model google/gemma-4-31B-it \
        --host 0.0.0.0 --port 80 --gpu-memory-utilization 0.90 \
        --enforce-eager --max-model-len 32768
"""
import os
import sys
import ctypes


def parse_iomem_egm_bytes(lines):
    """
    Sum the sizes of 'System RAM (NVIDIA)' regions in /proc/iomem-style lines.

    Each line looks like ``START-END : LABEL``. START and END are inclusive
    hex addresses, and the line may be indented for nested resources.
    Only lines whose text contains 'System RAM (NVIDIA)' are counted.

    Without sufficient privileges the kernel reports every range as
    ``00000000-00000000``. Such zeroed / zero-length ranges carry no size
    information and are skipped. Counting them as 1 byte each would make
    the caller believe EGM was detected. Malformed lines are skipped
    instead of aborting the whole parse.

    Args:
        lines: Iterable of text lines (e.g. an open /proc/iomem file).

    Returns:
        Total size in bytes of the matching regions (0 if none).
    """
    total = 0
    for line in lines:
        if 'System RAM (NVIDIA)' not in line:
            continue
        span = line.strip().split(':', 1)[0].strip()
        try:
            start_s, end_s = span.split('-')
            start = int(start_s, 16)
            end = int(end_s, 16)
        except ValueError:
            continue
        if end > start:
            total += end - start + 1
    return total


def get_total_managed_memory_gb():
    """
    Calculate total memory available via managed allocations on GH200.

    First checks the MANAGED_MEMORY_TOTAL_GB env var (explicit override).
    Then tries /proc/iomem (real addresses require a --privileged container).
    If EGM regions are found there, returns EGM + HBM (HBM from
    torch.cuda.mem_get_info(0)).

    Returns:
        Total managed memory in GiB, or 0 when nothing could be detected
        (the caller is expected to handle 0).

    Raises:
        ValueError: if MANAGED_MEMORY_TOTAL_GB is set but is not a number.
    """
    # Explicit override takes priority
    env_val = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
    if env_val:
        return float(env_val)

    # Try /proc/iomem for EGM regions (requires --privileged or host mount)
    try:
        with open('/proc/iomem', 'r') as f:
            egm_bytes = parse_iomem_egm_bytes(f)
    except OSError:
        egm_bytes = 0

    if egm_bytes <= 0:
        # No EGM detected - return 0 (caller should handle)
        return 0

    # EGM detected via iomem, add HBM. Imported lazily: this initializes
    # CUDA in the current process.
    import torch
    _, hbm_total = torch.cuda.mem_get_info(0)
    egm_gb = egm_bytes / (1024**3)
    hbm_gb = hbm_total / (1024**3)
    return egm_gb + hbm_gb


def swap_allocator():
    """
    Replace PyTorch's default CUDA allocator with our managed memory allocator.

    This MUST happen before any CUDA tensors are created.
    If sitecustomize.py already swapped it, this is a no-op.

    The shared library path comes from MANAGED_ALLOC_LIB (default
    /usr/local/lib/libmanaged_alloc.so). The library must export
    managed_malloc and managed_free. Exits the process with status 1 if
    the library is missing, fails to load, or lacks a required symbol.
    """
    lib_path = os.environ.get(
        'MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so'
    )

    if not os.path.exists(lib_path):
        print(f"[managed_mem] ERROR: {lib_path} not found!", file=sys.stderr)
        print(f"[managed_mem] Build it with: nvcc -shared -o {lib_path} "
              f"managed_alloc.cu -Xcompiler -fPIC", file=sys.stderr)
        sys.exit(1)

    # Verify the library loads. Use explicit checks rather than `assert`,
    # which is stripped when Python runs with -O.
    try:
        lib = ctypes.CDLL(lib_path)
        for symbol in ('managed_malloc', 'managed_free'):
            if not hasattr(lib, symbol):
                raise RuntimeError(f"{symbol} symbol not found")
    except Exception as e:
        print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
        sys.exit(1)

    import torch

    try:
        alloc = torch.cuda.memory.CUDAPluggableAllocator(
            lib_path, 'managed_malloc', 'managed_free'
        )
        torch.cuda.memory.change_current_allocator(alloc)
        print("[managed_mem] Allocator swapped to cudaMallocManaged",
              file=sys.stderr)
    except RuntimeError as e:
        if "already initialized" in str(e):
            print("[managed_mem] Allocator already initialized (sitecustomize)",
                  file=sys.stderr)
        else:
            raise


def patch_memory_snapshot():
    """
    Patch MemorySnapshot.measure() to report the full managed memory
    instead of just HBM.

    This is the core fix: cudaMemGetInfo only reports HBM (~96 GiB),
    but with cudaMallocManaged we can address HBM + EGM (~474 GiB).
    We override the snapshot so all downstream code (request_memory,
    KV cache sizing, etc.) sees the full managed memory capacity.

    The override is applied only while MANAGED_MEMORY_TOTAL_GB is set.
    Otherwise the original measurement is left untouched.

    NOTE(review): a monkey-patch only affects the current interpreter. If
    vLLM runs its engine/workers in processes started with the 'spawn'
    method, they re-import vLLM unpatched unless something (e.g.
    sitecustomize.py) re-applies it. Confirm.
    """
    from vllm.utils.mem_utils import MemorySnapshot

    _original_measure = MemorySnapshot.measure

    def patched_measure(self):
        _original_measure(self)
        # Override total_memory with managed memory capacity
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            managed_total = int(float(managed_total_gb) * (1024**3))
            self.total_memory = managed_total
            # free_memory = total - what CUDA is using
            self.free_memory = managed_total - self.cuda_memory
            print(f"[managed_mem] MemorySnapshot patched: "
                  f"total={self.total_memory / (1024**3):.0f} GiB, "
                  f"free={self.free_memory / (1024**3):.0f} GiB",
                  file=sys.stderr)

    MemorySnapshot.measure = patched_measure
    print("[managed_mem] Patched MemorySnapshot.measure()", file=sys.stderr)


def patch_torch_memory_tracking():
    """
    Patch PyTorch memory tracking functions that CUDAPluggableAllocator
    doesn't support.

    Replaces torch.cuda.reset_peak_memory_stats (no-op), memory_stats and
    max_memory_allocated (fall back to zeroed values on RuntimeError). It
    mirrors them onto torch.accelerator when that module exists in the
    installed torch. If sitecustomize already patched them, this is a no-op.
    """
    import torch

    # Check if already patched by sitecustomize
    if torch.cuda.reset_peak_memory_stats.__name__ == '_patched_reset_peak_memory_stats':
        print("[managed_mem] torch.cuda already patched (sitecustomize)",
              file=sys.stderr)
        return

    _original_memory_stats = torch.cuda.memory_stats
    _original_max_memory_allocated = torch.cuda.max_memory_allocated

    def _patched_reset_peak_memory_stats(device=None):
        """No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
        pass

    def _patched_memory_stats(device=None):
        """Return minimal stats dict to avoid crashes."""
        try:
            return _original_memory_stats(device)
        except RuntimeError:
            return {
                "allocated_bytes.all.current": 0,
                "allocated_bytes.all.peak": 0,
                "reserved_bytes.all.current": 0,
                "reserved_bytes.all.peak": 0,
                "num_alloc_retries": 0,
                "num_ooms": 0,
            }

    def _patched_max_memory_allocated(device=None):
        """Return 0 since we can't track peak with pluggable allocator."""
        try:
            return _original_max_memory_allocated(device)
        except RuntimeError:
            return 0

    torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
    torch.cuda.memory_stats = _patched_memory_stats
    torch.cuda.max_memory_allocated = _patched_max_memory_allocated

    # torch.accelerator does not exist in older torch releases; the original
    # `hasattr(torch.accelerator, ...)` raised AttributeError there.
    accelerator = getattr(torch, 'accelerator', None)
    if accelerator is not None:
        for attr_name, replacement in (
            ('reset_peak_memory_stats', _patched_reset_peak_memory_stats),
            ('memory_stats', _patched_memory_stats),
            ('max_memory_allocated', _patched_max_memory_allocated),
        ):
            if hasattr(accelerator, attr_name):
                setattr(accelerator, attr_name, replacement)

    print("[managed_mem] Patched torch.cuda memory tracking (no-op stubs)",
          file=sys.stderr)


def patch_vllm_memory_check():
    """
    Monkey-patch vLLM's memory validation to understand managed memory.

    With managed memory, cudaMemGetInfo only reports HBM, so the
    free-memory check would fail. When MANAGED_MEMORY_TOTAL_GB is set,
    we skip that check and request
    int(MANAGED_MEMORY_TOTAL_GB GiB * gpu_memory_utilization) bytes.
    Otherwise the original request_memory is called.
    """
    import vllm.v1.worker.utils as worker_utils

    _original_request_memory = worker_utils.request_memory

    def patched_request_memory(init_snapshot, cache_config):
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            total_bytes = float(managed_total_gb) * (1024**3)
            gpu_util = cache_config.gpu_memory_utilization
            requested = int(total_bytes * gpu_util)
            print(f"[managed_mem] Overriding memory request: "
                  f"{float(managed_total_gb):.0f} GiB × {gpu_util} = "
                  f"{requested / (1024**3):.1f} GiB", file=sys.stderr)
            return requested
        else:
            return _original_request_memory(init_snapshot, cache_config)

    worker_utils.request_memory = patched_request_memory
    print("[managed_mem] Patched vLLM request_memory", file=sys.stderr)


def main():
    """Detect managed memory, apply the vLLM patches, and run the API server."""
    # Step 1: NO global allocator swap — model weights stay in HBM.
    # KV cache uses cudaMallocManaged directly via
    # VLLM_KV_CACHE_USE_MANAGED_MEMORY env var (set by sitecustomize.py).
    # The global allocator swap broke cuBLAS GEMM operations because
    # intermediate compute tensors ended up in managed memory.
    print("[managed_mem] Using targeted KV cache managed allocation "
          "(no global allocator swap)", file=sys.stderr)

    # Step 2: Calculate total managed memory and export it
    total_managed_gb = get_total_managed_memory_gb()
    if total_managed_gb <= 0:
        print("[managed_mem] WARNING: No managed memory detected! "
              "Set MANAGED_MEMORY_TOTAL_GB env var explicitly.",
              file=sys.stderr)
        sys.exit(1)

    # Exported so the patched functions (and any child process inheriting
    # the environment) read the same capacity.
    os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(total_managed_gb)
    print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
          file=sys.stderr)

    # Step 3: No torch.cuda memory tracking patches needed —
    # we're not using CUDAPluggableAllocator anymore.

    # Step 4: Patch MemorySnapshot.measure() to report full managed memory
    # This is critical - without it, all downstream code only sees HBM
    patch_memory_snapshot()

    # Step 5: Patch request_memory as a safety net
    patch_vllm_memory_check()

    # Step 6: Launch vLLM's API server with remaining args
    sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
    print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
          file=sys.stderr)

    # Import and run.
    # NOTE(review): FlexibleArgumentParser is re-exported from the api_server
    # module here; newer vLLM versions define it elsewhere — confirm for the
    # pinned vLLM version.
    from vllm.entrypoints.openai.api_server import run_server, FlexibleArgumentParser
    import uvloop

    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-compatible API server (managed memory)"
    )
    # Use vLLM's own argument parser
    from vllm.entrypoints.openai.api_server import make_arg_parser
    parser = make_arg_parser(parser)
    args = parser.parse_args(sys.argv[1:])

    uvloop.run(run_server(args))


if __name__ == '__main__':
    main()