Add cudaMallocManaged allocator for GH200 EGM support

- managed_alloc.cu: PyTorch pluggable allocator using cudaMallocManaged
- vllm_managed_mem.py: Launcher that patches vLLM for managed memory
- Dockerfile: Build and install managed memory components

This enables vLLM to use cudaMallocManaged for transparent page-fault
access to both HBM (~96 GiB) and LPDDR (EGM, up to 480 GiB additional)
on GH200 systems with Extended GPU Memory enabled.

Experimental branch: v0.19.0-cmm
This commit is contained in:
2026-04-07 21:19:39 +00:00
parent edf12f7996
commit 2757bffcb6
3 changed files with 265 additions and 4 deletions

View File

@@ -1,9 +1,15 @@
# ==============================================================================
# Triton Kernels Build (TFA) - vLLM v0.19.0 + triton_kernels
# Managed Memory Build (CMM) - vLLM with cudaMallocManaged for GH200 EGM
# ==============================================================================
# This branch adds triton_kernels from Triton v3.6.0 for MoE support.
# This branch adds cudaMallocManaged allocator support for GH200 systems with
# Extended GPU Memory (EGM). This enables transparent page-fault access to
# both HBM (~96 GiB) and LPDDR (EGM, up to 480 GiB additional).
#
# Based on working Build #43 (v0.18.2rc0) with vLLM upgraded to v0.19.0:
# Key components:
# - managed_alloc.cu: PyTorch pluggable allocator using cudaMallocManaged
# - vllm_managed_mem.py: Launcher that patches vLLM for managed memory
#
# Based on working Build #48 (v0.19.0):
# - vLLM: v0.19.0
# - flashinfer: v0.6.6
# - flash-attention: hopper branch
@@ -19,7 +25,7 @@
# 3. CLEAR ALL CHANGES WITH MIKE BEFORE MAKING THEM
# 4. ONE BUILD AT A TIME - Mike reports failure → I assess → I report
#
# Image tag: gh200-vllm-tfa:v0.19.0-tfa
# Image tag: gh200-vllm-cmm:v0.19.0-cmm
# ==============================================================================
# ---------- Builder Base ----------
@@ -235,6 +241,23 @@ RUN apt install -y --no-install-recommends tmux cmake
# Deprecated cleanup
RUN pip uninstall -y pynvml && pip install nvidia-ml-py
# ==============================================================================
# Managed Memory Allocator (cudaMallocManaged for GH200 EGM)
# ==============================================================================
# This enables vLLM to use cudaMallocManaged for transparent page-fault
# access to both HBM and LPDDR (EGM) memory on GH200 systems.
#
# The managed_alloc.cu provides a PyTorch pluggable allocator that uses
# cudaMallocManaged instead of cudaMalloc. vllm_managed_mem.py is a
# launcher that swaps the allocator before any CUDA operations and patches
# vLLM's memory validation to understand the larger managed memory space.
# ==============================================================================
COPY managed_alloc.cu /tmp/managed_alloc.cu
RUN nvcc -shared -o /usr/local/lib/libmanaged_alloc.so \
/tmp/managed_alloc.cu -Xcompiler -fPIC && rm /tmp/managed_alloc.cu
COPY vllm_managed_mem.py /usr/local/bin/vllm_managed_mem.py
RUN chmod +x /usr/local/bin/vllm_managed_mem.py
# API server entrypoint
# ENTRYPOINT ["vllm", "serve"]
CMD ["/bin/bash"]

48
vllm/managed_alloc.cu Normal file
View File

@@ -0,0 +1,48 @@
// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
#include <cuda_runtime.h>
#include <stdio.h>
extern "C" {
// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
void* ptr = nullptr;
// Set the device before allocating
cudaError_t err = cudaSetDevice(device);
if (err != cudaSuccess) {
fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
device, cudaGetErrorString(err));
return nullptr;
}
// Use cudaMallocManaged - this is the key: allocations can page-fault
// across HBM and LPDDR on GH200 with EGM enabled
err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
if (err != cudaSuccess) {
fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
"(size=%zu bytes / %.2f GiB)\n",
cudaGetErrorString(err), size, (double)size / (1024.0*1024.0*1024.0));
return nullptr;
}
// Advise the driver to prefer GPU placement initially.
// On GH200 with EGM, the hardware will migrate pages as needed.
cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, device);
return ptr;
}
// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
if (ptr != nullptr) {
// Sync the stream before freeing to avoid use-after-free
if (stream != nullptr) {
cudaStreamSynchronize(stream);
}
cudaFree(ptr);
}
}
} // extern "C"

190
vllm/vllm_managed_mem.py Normal file
View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
"""
vllm_managed_mem.py - Launch vLLM with cudaMallocManaged allocator
This MUST be the very first thing that runs before any torch.cuda calls.
It swaps PyTorch's CUDA allocator to use cudaMallocManaged, enabling
transparent page-fault access to EGM memory on GH200.
Usage:
python vllm_managed_mem.py [all normal vllm serve arguments]
Example:
python vllm_managed_mem.py --model google/gemma-4-31B-it \
--host 0.0.0.0 --port 80 --gpu-memory-utilization 0.90 \
--enforce-eager --max-model-len 32768
"""
import os
import sys
import ctypes
def get_total_managed_memory_gb():
"""
Calculate total memory available via managed allocations on GH200.
Parses /proc/iomem to find NVIDIA EGM regions + HBM.
"""
egm_bytes = 0
try:
with open('/proc/iomem', 'r') as f:
for line in f:
if 'System RAM (NVIDIA)' in line:
parts = line.strip().split(':')[0].strip()
start_s, end_s = parts.split('-')
start = int(start_s, 16)
end = int(end_s, 16)
egm_bytes += (end - start + 1)
except (PermissionError, FileNotFoundError):
pass
egm_gb = egm_bytes / (1024**3)
# HBM is always there via normal cudaMalloc path
# cudaMallocManaged can span both HBM + EGM
return egm_gb
def swap_allocator():
"""
Replace PyTorch's default CUDA allocator with our managed memory allocator.
This MUST happen before any CUDA tensors are created.
"""
lib_path = os.environ.get(
'MANAGED_ALLOC_LIB',
'/usr/local/lib/libmanaged_alloc.so'
)
if not os.path.exists(lib_path):
print(f"[managed_mem] ERROR: {lib_path} not found!", file=sys.stderr)
print(f"[managed_mem] Build it with: nvcc -shared -o {lib_path} "
f"managed_alloc.cu -Xcompiler -fPIC", file=sys.stderr)
sys.exit(1)
# Verify the library loads
try:
lib = ctypes.CDLL(lib_path)
assert hasattr(lib, 'managed_malloc'), "managed_malloc symbol not found"
assert hasattr(lib, 'managed_free'), "managed_free symbol not found"
print(f"[managed_mem] Loaded allocator from {lib_path}", file=sys.stderr)
except Exception as e:
print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
sys.exit(1)
import torch
# This must happen before ANY cuda operation
alloc = torch.cuda.memory.CUDAPluggableAllocator(
lib_path, 'managed_malloc', 'managed_free'
)
torch.cuda.memory.change_current_allocator(alloc)
egm_gb = get_total_managed_memory_gb()
print(f"[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
print(f"[managed_mem] Detected ~{egm_gb:.0f} GiB EGM memory", file=sys.stderr)
print(f"[managed_mem] Total addressable: ~{egm_gb + 96:.0f} GiB "
f"(EGM + HBM)", file=sys.stderr)
def patch_vllm_memory_check():
"""
Monkey-patch vLLM's memory validation to understand managed memory.
vLLM checks cudaMemGetInfo and rejects startup if free < requested.
With managed memory, the real capacity is much larger than what
cudaMemGetInfo reports (it only sees HBM + a tiny EGM sliver).
"""
import vllm.v1.worker.utils as worker_utils
_original_request_memory = worker_utils.request_memory
def patched_request_memory(init_snapshot, cache_config):
"""
Override memory validation:
- Read the MANAGED_MEMORY_TOTAL_GB env var (set by us)
- If set, skip the free-memory check and just return the requested size
"""
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
total_bytes = float(managed_total_gb) * (1024**3)
gpu_util = cache_config.gpu_memory_utilization
requested = int(total_bytes * gpu_util)
print(f"[managed_mem] Overriding memory request: "
f"{float(managed_total_gb):.0f} GiB × {gpu_util} = "
f"{requested / (1024**3):.1f} GiB", file=sys.stderr)
return requested
else:
return _original_request_memory(init_snapshot, cache_config)
worker_utils.request_memory = patched_request_memory
print(f"[managed_mem] Patched vLLM memory check", file=sys.stderr)
def patch_vllm_memory_snapshot():
"""
Patch the memory snapshot to report managed-aware totals.
Without this, vLLM's worker sees ~95 GiB total and panics.
"""
import vllm.v1.worker.gpu_worker as gpu_worker
_original_init_device = gpu_worker.GpuWorker.init_device
def patched_init_device(self):
"""Patch init_device to override memory reporting."""
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
# Set a permissive gpu_memory_utilization on the cache_config
# so the original code doesn't reject us
original_util = self.cache_config.gpu_memory_utilization
self.cache_config.gpu_memory_utilization = 0.5 # always passes
_original_init_device(self)
# Now fix up the requested_memory with the real value
total_bytes = float(managed_total_gb) * (1024**3)
self.requested_memory = int(total_bytes * original_util)
self.cache_config.gpu_memory_utilization = original_util
print(f"[managed_mem] Worker memory overridden to "
f"{self.requested_memory / (1024**3):.1f} GiB", file=sys.stderr)
else:
_original_init_device(self)
gpu_worker.GpuWorker.init_device = patched_init_device
print(f"[managed_mem] Patched vLLM worker init_device", file=sys.stderr)
def main():
# Step 1: Swap allocator BEFORE any CUDA ops
swap_allocator()
# Step 2: Calculate total managed memory and export it
egm_gb = get_total_managed_memory_gb()
total_managed_gb = egm_gb + 96 # EGM + HBM
os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(total_managed_gb)
print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
file=sys.stderr)
# Step 3: Patch vLLM memory checks
patch_vllm_memory_check()
patch_vllm_memory_snapshot()
# Step 4: Launch vLLM's API server with remaining args
sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
file=sys.stderr)
# Import and run
from vllm.entrypoints.openai.api_server import run_server, FlexibleArgumentParser
import uvloop
parser = FlexibleArgumentParser(
description="vLLM OpenAI-compatible API server (managed memory)"
)
# Use vLLM's own argument parser
from vllm.entrypoints.openai.api_server import make_arg_parser
parser = make_arg_parser(parser)
args = parser.parse_args(sys.argv[1:])
uvloop.run(run_server(args))
if __name__ == '__main__':
main()