Compare commits: `main` ... `cuda-mallo` (12 commits)
# Building vLLM containers for GH200 leveraging cudaMallocManaged when EGM (Extended GPU Memory) is enabled

## The Problem

The GH200 has 96 GiB of HBM (VRAM) and 480 GiB of LPDDR (system memory). The only way for the GPU to access system memory over the C2C NVLink at the full 900 GB/s, without going through the IOMMU, is to enable EGM in the BIOS.

This creates three issues:

1. **EGM requires a reserved memory value of 8192**: this is the only value that works
2. **The server loses most of its system memory**: it no longer sees 480 GiB; instead it sees ~102 GiB. The rest has been handed over to the GPU
3. **vLLM still only sees 96 GiB of VRAM**: the GPU now has access to the full memory space, but the only way to leverage it is to convert all `cudaMalloc` calls to `cudaMallocManaged`. Without this, vLLM's allocator only touches HBM

## The Goal

Force vLLM to use `cudaMallocManaged` so it can address the full memory space (HBM + EGM). Before we can do that, we have to make sure vLLM's preflight checks of available VRAM report the new fully allocated amount (~97 GiB HBM + ~378 GiB EGM = ~475 GiB total managed memory).
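As a quick cross-check of those numbers (plain arithmetic, no vLLM involved; the MiB figure comes from the `nvidia-smi` output shown below):

```python
# Totals quoted above, cross-checked (GiB)
hbm_gib = 97
egm_gib = 378
print(hbm_gib + egm_gib)        # 475

# nvidia-smi reports HBM in MiB; same figure:
print(round(97871 / 1024, 1))   # 95.6
```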
## GH200 System State with EGM Enabled

```
$ nvidia-smi
+-----------------------------------------------------------------------------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|=========================================================================================|
|   0  NVIDIA GH200 480GB             On  |   00000009:01:00.0 Off |                  Off |
| N/A   31C    P0             83W /  700W |   14920MiB /  97871MiB |      0%      Default |
+-----------------------------------------------------------------------------------------+

$ free -h
               total        used        free      shared  buff/cache   available
Mem:           102Gi       4.2Gi        83Gi        10Mi        15Gi        93Gi

$ dmidecode -t memory | grep -i size
        Size: 480 GB

$ numactl --hardware
available: 10 nodes (0-9)
node 0 cpus: 0-71
node 0 size: 8080 MB
node 2 size: 97280 MB
```
Key observations:

- **nvidia-smi** reports 97,871 MiB (~96 GiB): this is HBM only, NOT the EGM
- **The OS** sees only ~102 GiB of system RAM: EGM carved out 378 GiB for the GPU
- **dmidecode** confirms the physical memory is 480 GB
- **NUMA** node 2 has 97 GiB (the LPDDR that wasn't handed to EGM); node 0 has 8 GiB (local)
- The EGM memory appears as `System RAM (NVIDIA)` in `/proc/iomem` at addresses `0x400000000000+`
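That last observation can be checked mechanically: EGM capacity is just the sum of the `System RAM (NVIDIA)` ranges in `/proc/iomem`. A minimal sketch of that parsing (the sample line is illustrative, not captured from real hardware):

```python
# Illustrative /proc/iomem excerpt; a real GH200 may list several ranges.
sample = "400000000000-405e7fffffff : System RAM (NVIDIA)\n"

def egm_bytes(iomem_text):
    total = 0
    for line in iomem_text.splitlines():
        if 'System RAM (NVIDIA)' in line:
            span = line.split(':')[0].strip()        # "start-end" in hex
            start_s, end_s = span.split('-')
            total += int(end_s, 16) - int(start_s, 16) + 1
    return total

print(egm_bytes(sample) // 1024**3)  # 378
```

The launcher below (`vllm_managed_mem.py`) uses this same arithmetic, which is why it needs `--privileged` or a host mount to read `/proc/iomem`.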
## Architecture

### Managed Memory Allocator

The approach uses a PyTorch pluggable allocator (`managed_alloc.cu`) that replaces `cudaMalloc` with `cudaMallocManaged`, enabling transparent page-fault access to both HBM and EGM.
```
┌─────────────────────────────────────────────────┐
│                      vLLM                       │
│           (sees full managed memory)            │
├─────────────────────────────────────────────────┤
│           PyTorch Pluggable Allocator           │
│            (managed_alloc.cu / .so)             │
│     cudaMallocManaged → unified memory pool     │
├─────────────────────────────────────────────────┤
│               CUDA Unified Memory               │
├──────────────────┬──────────────────────────────┤
│  HBM (~96 GiB)   │  EGM (~378 GiB)              │
│  Fast / local    │  Page-fault over C2C NVLink  │
│                  │  900 GB/s bandwidth          │
└──────────────────┴──────────────────────────────┘
```
### Launcher

`vllm_managed_mem.py` is the entry point that:

1. Loads the managed allocator `.so` before any CUDA operations
2. Swaps PyTorch's default allocator to `cudaMallocManaged`
3. Patches vLLM's memory validation to understand the larger managed memory space
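Step 3 has to reach every worker process, which is why the repo also ships a `sitecustomize.py`: Python's `site` module imports a module by that name automatically at interpreter startup, in every subprocess. A self-contained demonstration (the temp dir stands in for `dist-packages`; the printed message is a stand-in, not the real patch):

```python
import os
import subprocess
import sys
import tempfile

with tempfile.TemporaryDirectory() as d:
    # Stand-in sitecustomize; the real one patches vLLM's memory checks.
    with open(os.path.join(d, "sitecustomize.py"), "w") as f:
        f.write("print('sitecustomize ran')\n")
    env = dict(os.environ, PYTHONPATH=d)
    # Even a do-nothing child interpreter runs sitecustomize first.
    out = subprocess.run([sys.executable, "-c", "pass"],
                         env=env, capture_output=True, text=True).stdout
    print(out.strip())  # sitecustomize ran
```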
## Source Repos

| Repo | URL | Branch | Purpose |
|------|-----|--------|---------|
| vLLM (our fork) | `https://sweetapi.com/biondizzle/vllm.git` | `cmm` | vLLM with cudaMallocManaged patches |
| grace-gpu-containers | `https://sweetapi.com/biondizzle/grace-gpu-containers.git` | `cuda-malloc-managed` | Dockerfile + build pipeline |

The `cmm` branch in our vLLM fork is based on tag `v0.19.0` (commit `2a69949bd`).
## Build Pipeline

### Jenkins: `gh200-vllm-build-cmm`

The build chain:

```
Jenkins → grace-gpu-containers (cuda-malloc-managed branch)
  → Dockerfile clones vLLM from Gitea fork (cmm branch)
  → Builds ARM64 image via buildx on remote GH200
  → Pushes to atl.vultrcr.com/vllm
```

**Default parameters:**

- `VLLM_VERSION`: `cmm` (our fork branch)
- `IMAGE_TAG`: `gh200-vllm-cmm`
- `IMAGE_SUFFIX`: `-cmm`

**Image:** `atl.vultrcr.com/vllm/gh200-vllm-cmm:cmm-cmm`
### Dockerfile Build Stages

| Stage | What it builds | Source |
|-------|----------------|--------|
| build-triton | Triton 3.6.0 | PyPI wheel (aarch64) |
| build-triton-kernels | triton_kernels v3.6.0 | Triton repo |
| build-flashinfer | flashinfer v0.6.6 | Source (apache-tvm-ffi required) |
| build-lmcache | LMCache dev | Source |
| build-flash-attention | FlashAttention hopper | Source |
| build-vllm | vLLM cmm branch | Our Gitea fork |
| build-infinistore | InfiniStore | Source |

**Base image:** `nvcr.io/nvidia/pytorch:26.03-py3`

- PyTorch 2.11.0a0, CUDA 13.2.0
- Multi-arch: x86 + ARM SBSA (GH200)
- Target: `9.0a` (Hopper)
## Key Files

| File | Description |
|------|-------------|
| `vllm/Dockerfile` | Multi-stage build for the CMM container |
| `vllm/managed_alloc.cu` | PyTorch pluggable allocator using `cudaMallocManaged` |
| `vllm/vllm_managed_mem.py` | Launcher that patches vLLM for managed memory |
| `lmcache/Dockerfile` | Standalone LMCache build |
## Local Development

```bash
# vLLM fork (working directory)
cd /home/openclaw/dev/vllm
git checkout cmm                  # our working branch, based on v0.19.0

# grace-gpu-containers
cd /home/openclaw/dev/grace-gpu-containers
git checkout cuda-malloc-managed  # CMM Dockerfile lives here
```

## Hard Rules

1. **No downgrades**: CUDA 13+, PyTorch 2.9+, vLLM 0.18.1+
2. **No skipping compilation**: build from source
3. **Clear all changes with Mike before making them**: no autonomous commits
4. **One build at a time**: Mike reports failure → assess → report → Mike decides
## vllm/Dockerfile (updated)

```dockerfile
@@ -1,10 +1,16 @@
# ==============================================================================
# Managed Memory Build (CMM) - vLLM with cudaMallocManaged for GH200 EGM
# ==============================================================================
# This branch adds cudaMallocManaged allocator support for GH200 systems with
# Extended GPU Memory (EGM). This enables transparent page-fault access to
# both HBM (~96 GiB) and LPDDR (EGM, up to 480 GiB additional).
#
# Key components:
# - managed_alloc.cu: PyTorch pluggable allocator using cudaMallocManaged
# - vllm_managed_mem.py: Launcher that patches vLLM for managed memory
#
# Based on working Build #48 (v0.19.0):
# - vLLM: v0.19.0 (forked to sweetapi.com/biondizzle/vllm, cmm branch)
# - flashinfer: v0.6.6
# - flash-attention: hopper branch
# - lmcache: dev branch

@@ -19,7 +25,7 @@
# 3. CLEAR ALL CHANGES WITH MIKE BEFORE MAKING THEM
# 4. ONE BUILD AT A TIME - Mike reports failure → I assess → I report
#
# Image tag: gh200-vllm-cmm:v0.19.0-cmm
# ==============================================================================

# ---------- Builder Base ----------

@@ -150,11 +156,13 @@ RUN apt-get update && apt-get install -y build-essential cmake gcc && \
# Build vLLM from source
# ==============================================================================
FROM build-base AS build-vllm
# vLLM version/branch to build
# Using our Gitea fork (sweetapi.com/biondizzle/vllm) on the cmm branch
ARG VLLM_REF=cmm
ARG VLLM_COMMIT=latest
# Install ccache for faster compilation
RUN apt-get update && apt-get install -y ccache
RUN echo "VLLM_COMMIT=${VLLM_COMMIT}" && git clone https://sweetapi.com/biondizzle/vllm.git
RUN cd vllm && \
    git checkout ${VLLM_REF} && \
    echo "\n\n========================================" && \

@@ -235,6 +243,38 @@ RUN apt install -y --no-install-recommends tmux cmake
# Deprecated cleanup
RUN pip uninstall -y pynvml && pip install nvidia-ml-py

# ==============================================================================
# Managed Memory Allocator (cudaMallocManaged for GH200 EGM)
# ==============================================================================
# This enables vLLM to use cudaMallocManaged for transparent page-fault
# access to both HBM and LPDDR (EGM) memory on GH200 systems.
#
# The managed_alloc.cu provides a PyTorch pluggable allocator that uses
# cudaMallocManaged instead of cudaMalloc. Key features:
# - cudaMemAdviseSetPreferredLocation(GPU): keep pages on GPU
# - cudaMemAdviseSetAccessedBy(CPU): CPU reads over C2C NVLink without
#   migrating pages back to system RAM (prevents OOM on EGM systems)
# - cudaMemPrefetchAsync(GPU): actively migrates pages to GPU immediately,
#   so model weight writes go to HBM/EGM, not system RAM
#
# vllm_managed_mem.py is the launcher that patches vLLM's memory validation
# to understand the larger managed memory space. No global allocator swap.
#
# sitecustomize.py is auto-loaded by Python in ALL subprocesses (including
# vLLM's EngineCore). It sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 and patches
# MemorySnapshot/request_memory — does NOT swap the global allocator.
# ==============================================================================
ARG CMM_BUILD_DATE=default
RUN echo "CMM build: ${CMM_BUILD_DATE}" > /dev/null
COPY managed_alloc.cu /tmp/managed_alloc.cu
RUN nvcc -shared -o /usr/local/lib/libmanaged_alloc.so \
    /tmp/managed_alloc.cu -Xcompiler -fPIC && rm /tmp/managed_alloc.cu
COPY vllm_managed_mem.py /usr/local/bin/vllm_managed_mem.py
RUN chmod +x /usr/local/bin/vllm_managed_mem.py
# sitecustomize.py is auto-loaded by Python before any other imports.
# This ensures CMM patches apply in ALL subprocesses (EngineCore, etc.)
COPY sitecustomize.py /usr/local/lib/python3.12/dist-packages/sitecustomize.py

# API server entrypoint
# ENTRYPOINT ["vllm", "serve"]
CMD ["/bin/bash"]
```
## vllm/managed_alloc.cu (new file, 96 lines)

```cpp
// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
// Compatible with CUDA 13+ (uses cudaMemLocation API)
//
// Key design decisions for GH200 EGM:
// 1. cudaMallocManaged → allocations can page-fault across HBM + EGM
// 2. cudaMemAdviseSetPreferredLocation(GPU) → driver prefers keeping pages on GPU
// 3. cudaMemAdviseSetAccessedBy(CPU) → CPU can access over C2C NVLink without
//    triggering page migration back to system RAM (critical: prevents OOM)
// 4. Selective prefetching — small allocations (model weights, <2 GiB)
//    are prefetched to GPU so cuBLAS/cuDNN kernels can access them
//    directly from HBM. Large allocations (KV cache blocks) stay in
//    managed memory and page-fault on demand, since they're too large
//    to fit in HBM and attention ops can tolerate page faults.
#include <cuda_runtime.h>
#include <stdio.h>

extern "C" {

// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
    void* ptr = nullptr;

    // Set the device before allocating
    cudaError_t err = cudaSetDevice(device);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
                device, cudaGetErrorString(err));
        return nullptr;
    }

    // Use cudaMallocManaged - this is the key: allocations can page-fault
    // across HBM and LPDDR on GH200 with EGM enabled
    err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
                "(size=%zu bytes / %.2f GiB)\n",
                cudaGetErrorString(err), size, (double)size / (1024.0*1024.0*1024.0));
        return nullptr;
    }

    // CUDA 13+ uses cudaMemLocation struct instead of int for device
    cudaMemLocation gpu_loc;
    gpu_loc.type = cudaMemLocationTypeDevice;
    gpu_loc.id = device;

    // Advise: prefer GPU placement. On GH200 with EGM, the hardware will
    // migrate pages as needed, but the driver tries to keep them on GPU.
    cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, gpu_loc);

    // Advise: CPU will access this memory too. On GH200, this sets up
    // remote mapping over C2C NVLink so CPU can read/write without
    // triggering page migration back to system RAM. This is CRITICAL
    // to prevent OOM on EGM systems where most system RAM was carved
    // out for the GPU.
    cudaMemLocation cpu_loc;
    cpu_loc.type = cudaMemLocationTypeHost;
    cpu_loc.id = cudaCpuDeviceId;
    cudaMemAdvise(ptr, size, cudaMemAdviseSetAccessedBy, cpu_loc);

    // Selective prefetch: migrate pages to GPU for small allocations only.
    // Model weights (individual tensors) are typically <2 GiB and MUST be
    // on GPU for cuBLAS GEMM operations — GPU compute kernels cannot
    // page-fault into managed memory during execution.
    // KV cache blocks are large and numerous; prefetching them all fills
    // HBM and causes subsequent allocations to fail.
    // The 2 GiB threshold separates "compute data" from "cache data".
    const size_t PREFETCH_THRESHOLD = 2ULL * 1024 * 1024 * 1024; // 2 GiB

    if (size > 0 && size < PREFETCH_THRESHOLD) {
        err = cudaMemPrefetchAsync(ptr, size, gpu_loc, 0);
        if (err != cudaSuccess) {
            // Non-fatal: prefetch failure shouldn't prevent allocation.
            // Pages will still be migrated on demand.
            fprintf(stderr, "[managed_alloc] cudaMemPrefetchAsync warning: %s "
                    "(size=%.2f GiB, will use on-demand migration)\n",
                    cudaGetErrorString(err), (double)size / (1024.0*1024.0*1024.0));
        }
    }

    return ptr;
}

// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
    if (ptr != nullptr) {
        // Sync the stream before freeing to avoid use-after-free with
        // managed memory (in-flight page faults can race with deallocation).
        if (stream != nullptr) {
            cudaStreamSynchronize(stream);
        }
        cudaFree(ptr);
    }
}

} // extern "C"
```
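The selective-prefetch rule above boils down to a single size test; in Python terms (a sketch mirroring the C constant, not code from the repo):

```python
PREFETCH_THRESHOLD = 2 * 1024**3  # 2 GiB, same constant as managed_alloc.cu

def should_prefetch(size_bytes):
    # Weight-sized allocations get migrated to HBM eagerly; KV-cache-sized
    # ones stay in managed memory and page-fault on demand.
    return 0 < size_bytes < PREFETCH_THRESHOLD

print(should_prefetch(500 * 1024**2))  # True  (a 500 MiB weight shard)
print(should_prefetch(8 * 1024**3))    # False (an 8 GiB KV cache block)
```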
## vllm/sitecustomize.py (new file, 62 lines)

```python
"""
sitecustomize.py - Auto-loaded by Python before any other imports.

In CMM mode (MANAGED_MEMORY_TOTAL_GB set):
- Patches MemorySnapshot.measure() to report managed memory capacity
- Patches request_memory to calculate KV cache size based on managed memory
- Sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 so KV cache uses cudaMallocManaged

Does NOT swap the global CUDA allocator — model weights and compute
intermediates use normal cudaMalloc in HBM. Only KV cache spills into
EGM via cudaMallocManaged, called directly from gpu_model_runner.py.

Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
"""
import os
import sys

# Only activate in CMM mode
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if _MANAGED_TOTAL:
    # Enable KV cache managed memory allocation in gpu_model_runner.py
    os.environ['VLLM_KV_CACHE_USE_MANAGED_MEMORY'] = '1'

    # Patch MemorySnapshot.measure() to report managed memory capacity
    # This tells vLLM how much total memory is available for KV cache sizing
    try:
        from vllm.utils.mem_utils import MemorySnapshot
        _original_measure = MemorySnapshot.measure

        def _patched_measure(self):
            _original_measure(self)
            managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
            if managed_total_gb:
                managed_total = int(float(managed_total_gb) * (1024**3))
                self.total_memory = managed_total
                self.free_memory = managed_total - self.cuda_memory

        MemorySnapshot.measure = _patched_measure
    except ImportError:
        pass  # vllm not loaded yet, will be patched later

    # Patch request_memory to calculate based on managed memory
    try:
        import vllm.v1.worker.utils as worker_utils
        _original_request_memory = worker_utils.request_memory

        def _patched_request_memory(init_snapshot, cache_config):
            managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
            if managed_total_gb:
                total_bytes = float(managed_total_gb) * (1024**3)
                gpu_util = cache_config.gpu_memory_utilization
                requested = int(total_bytes * gpu_util)
                return requested
            else:
                return _original_request_memory(init_snapshot, cache_config)

        worker_utils.request_memory = _patched_request_memory
    except ImportError:
        pass  # vllm not loaded yet

    print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB, "
          f"KV cache will use cudaMallocManaged)", file=sys.stderr)
```
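The wrap-and-override pattern used by both patches can be demonstrated without vLLM installed; the class and numbers below are stand-ins for `MemorySnapshot` on a GH200:

```python
class MemorySnapshot:
    """Stand-in for vllm.utils.mem_utils.MemorySnapshot."""
    def measure(self):
        self.total_memory = 96 * 1024**3   # what cudaMemGetInfo reports (HBM)
        self.cuda_memory = 10 * 1024**3    # bytes CUDA is currently using

_original_measure = MemorySnapshot.measure

def _patched_measure(self):
    _original_measure(self)                # take the real measurement first
    managed_total = 475 * 1024**3          # then report managed capacity
    self.total_memory = managed_total
    self.free_memory = managed_total - self.cuda_memory

MemorySnapshot.measure = _patched_measure

snap = MemorySnapshot()
snap.measure()
print(snap.total_memory // 1024**3, snap.free_memory // 1024**3)  # 475 465
```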
## vllm/vllm_managed_mem.py (new file, 271 lines)

```python
#!/usr/bin/env python3
"""
vllm_managed_mem.py - Launch vLLM with cudaMallocManaged allocator

This MUST be the very first thing that runs before any torch.cuda calls.
It swaps PyTorch's CUDA allocator to use cudaMallocManaged, enabling
transparent page-fault access to EGM memory on GH200.

Usage:
    python vllm_managed_mem.py [all normal vllm serve arguments]

Example:
    python vllm_managed_mem.py --model google/gemma-4-31B-it \
        --host 0.0.0.0 --port 80 --gpu-memory-utilization 0.90 \
        --enforce-eager --max-model-len 32768
"""

import os
import sys
import ctypes


def get_total_managed_memory_gb():
    """
    Calculate total memory available via managed allocations on GH200.
    First checks MANAGED_MEMORY_TOTAL_GB env var (explicit override).
    Then tries /proc/iomem (requires --privileged container).
    Falls back to HBM only if neither works.
    """
    # Explicit override takes priority
    env_val = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
    if env_val:
        return float(env_val)

    # Try /proc/iomem for EGM regions (requires --privileged or host mount)
    egm_bytes = 0
    try:
        with open('/proc/iomem', 'r') as f:
            for line in f:
                if 'System RAM (NVIDIA)' in line:
                    parts = line.strip().split(':')[0].strip()
                    start_s, end_s = parts.split('-')
                    start = int(start_s, 16)
                    end = int(end_s, 16)
                    egm_bytes += (end - start + 1)
    except (PermissionError, FileNotFoundError, ValueError):
        pass

    egm_gb = egm_bytes / (1024**3)
    if egm_gb > 0:
        # EGM detected via iomem, add HBM
        import torch
        _, hbm_total = torch.cuda.mem_get_info(0)
        hbm_gb = hbm_total / (1024**3)
        return egm_gb + hbm_gb

    # No EGM detected - return 0 (caller should handle)
    return 0


def swap_allocator():
    """
    Replace PyTorch's default CUDA allocator with our managed memory allocator.
    This MUST happen before any CUDA tensors are created.
    If sitecustomize.py already swapped it, this is a no-op.
    """
    lib_path = os.environ.get(
        'MANAGED_ALLOC_LIB',
        '/usr/local/lib/libmanaged_alloc.so'
    )

    if not os.path.exists(lib_path):
        print(f"[managed_mem] ERROR: {lib_path} not found!", file=sys.stderr)
        print(f"[managed_mem] Build it with: nvcc -shared -o {lib_path} "
              f"managed_alloc.cu -Xcompiler -fPIC", file=sys.stderr)
        sys.exit(1)

    # Verify the library loads
    try:
        lib = ctypes.CDLL(lib_path)
        assert hasattr(lib, 'managed_malloc'), "managed_malloc symbol not found"
        assert hasattr(lib, 'managed_free'), "managed_free symbol not found"
    except Exception as e:
        print(f"[managed_mem] ERROR loading {lib_path}: {e}", file=sys.stderr)
        sys.exit(1)

    import torch
    try:
        alloc = torch.cuda.memory.CUDAPluggableAllocator(
            lib_path, 'managed_malloc', 'managed_free'
        )
        torch.cuda.memory.change_current_allocator(alloc)
        print(f"[managed_mem] Allocator swapped to cudaMallocManaged", file=sys.stderr)
    except RuntimeError as e:
        if "already initialized" in str(e):
            print(f"[managed_mem] Allocator already initialized (sitecustomize)",
                  file=sys.stderr)
        else:
            raise


def patch_memory_snapshot():
    """
    Patch MemorySnapshot.measure() to report the full managed memory
    instead of just HBM.

    This is the core fix: cudaMemGetInfo only reports HBM (~96 GiB),
    but with cudaMallocManaged we can address HBM + EGM (~474 GiB).
    We override the snapshot so all downstream code (request_memory,
    KV cache sizing, etc.) sees the full managed memory capacity.
    """
    from vllm.utils.mem_utils import MemorySnapshot

    _original_measure = MemorySnapshot.measure

    def patched_measure(self):
        _original_measure(self)
        # Override total_memory with managed memory capacity
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            managed_total = int(float(managed_total_gb) * (1024**3))
            self.total_memory = managed_total
            # free_memory = total - what CUDA is using
            self.free_memory = managed_total - self.cuda_memory
            print(f"[managed_mem] MemorySnapshot patched: "
                  f"total={self.total_memory / (1024**3):.0f} GiB, "
                  f"free={self.free_memory / (1024**3):.0f} GiB",
                  file=sys.stderr)

    MemorySnapshot.measure = patched_measure
    print(f"[managed_mem] Patched MemorySnapshot.measure()", file=sys.stderr)


def patch_torch_memory_tracking():
    """
    Patch PyTorch memory tracking functions that CUDAPluggableAllocator
    doesn't support. If sitecustomize already patched them, this is a no-op.
    """
    import torch

    # Check if already patched by sitecustomize
    if torch.cuda.reset_peak_memory_stats.__name__ == '_patched_reset_peak_memory_stats':
        print(f"[managed_mem] torch.cuda already patched (sitecustomize)", file=sys.stderr)
        return

    _original_reset_peak = torch.cuda.reset_peak_memory_stats
    _original_memory_stats = torch.cuda.memory_stats
    _original_max_memory_allocated = torch.cuda.max_memory_allocated

    def _patched_reset_peak_memory_stats(device=None):
        """No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
        pass

    def _patched_memory_stats(device=None):
        """Return minimal stats dict to avoid crashes."""
        try:
            return _original_memory_stats(device)
        except RuntimeError:
            return {
                "allocated_bytes.all.current": 0,
                "allocated_bytes.all.peak": 0,
                "reserved_bytes.all.current": 0,
                "reserved_bytes.all.peak": 0,
                "num_alloc_retries": 0,
                "num_ooms": 0,
            }

    def _patched_max_memory_allocated(device=None):
        """Return 0 since we can't track peak with pluggable allocator."""
        try:
            return _original_max_memory_allocated(device)
        except RuntimeError:
            return 0

    torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
    torch.cuda.memory_stats = _patched_memory_stats
    torch.cuda.max_memory_allocated = _patched_max_memory_allocated

    if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
        torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
    if hasattr(torch.accelerator, 'memory_stats'):
        torch.accelerator.memory_stats = _patched_memory_stats
    if hasattr(torch.accelerator, 'max_memory_allocated'):
        torch.accelerator.max_memory_allocated = _patched_max_memory_allocated

    print(f"[managed_mem] Patched torch.cuda memory tracking (no-op stubs)",
          file=sys.stderr)


def patch_vllm_memory_check():
    """
    Monkey-patch vLLM's memory validation to understand managed memory.

    With managed memory, cudaMemGetInfo only reports HBM, so the free-memory
    check would fail. When MANAGED_MEMORY_TOTAL_GB is set, we skip that check.
    """
    import vllm.v1.worker.utils as worker_utils

    _original_request_memory = worker_utils.request_memory

    def patched_request_memory(init_snapshot, cache_config):
        managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
        if managed_total_gb:
            total_bytes = float(managed_total_gb) * (1024**3)
            gpu_util = cache_config.gpu_memory_utilization
            requested = int(total_bytes * gpu_util)
            print(f"[managed_mem] Overriding memory request: "
                  f"{float(managed_total_gb):.0f} GiB × {gpu_util} = "
                  f"{requested / (1024**3):.1f} GiB", file=sys.stderr)
            return requested
        else:
            return _original_request_memory(init_snapshot, cache_config)

    worker_utils.request_memory = patched_request_memory
    print(f"[managed_mem] Patched vLLM request_memory", file=sys.stderr)


def main():
    # Step 1: NO global allocator swap — model weights stay in HBM.
    # KV cache uses cudaMallocManaged directly via
    # VLLM_KV_CACHE_USE_MANAGED_MEMORY env var (set by sitecustomize.py).
    # The global allocator swap broke cuBLAS GEMM operations because
    # intermediate compute tensors ended up in managed memory.
    print(f"[managed_mem] Using targeted KV cache managed allocation "
          f"(no global allocator swap)", file=sys.stderr)

    # Step 2: Calculate total managed memory and export it
    total_managed_gb = get_total_managed_memory_gb()
    if total_managed_gb <= 0:
        print(f"[managed_mem] WARNING: No managed memory detected! "
              f"Set MANAGED_MEMORY_TOTAL_GB env var explicitly.",
              file=sys.stderr)
        sys.exit(1)

    os.environ['MANAGED_MEMORY_TOTAL_GB'] = str(total_managed_gb)
    print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
          file=sys.stderr)

    # Step 3: No torch.cuda memory tracking patches needed —
    # we're not using CUDAPluggableAllocator anymore.

    # Step 4: Patch MemorySnapshot.measure() to report full managed memory
    # This is critical - without it, all downstream code only sees HBM
    patch_memory_snapshot()

    # Step 5: Patch request_memory as a safety net
    patch_vllm_memory_check()

    # Step 6: Launch vLLM's API server with remaining args
    sys.argv = ['vllm.entrypoints.openai.api_server'] + sys.argv[1:]
    print(f"[managed_mem] Launching vLLM with args: {sys.argv[1:]}",
          file=sys.stderr)

    # Import and run
    from vllm.entrypoints.openai.api_server import run_server, FlexibleArgumentParser
    import uvloop

    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-compatible API server (managed memory)"
    )

    # Use vLLM's own argument parser
    from vllm.entrypoints.openai.api_server import make_arg_parser
    parser = make_arg_parser(parser)
    args = parser.parse_args(sys.argv[1:])

    uvloop.run(run_server(args))


if __name__ == '__main__':
    main()
```
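The `request_memory` override installed above (and in `sitecustomize.py`) reduces to one line of arithmetic; a standalone sketch, with the utilization value purely illustrative:

```python
def managed_request_memory(managed_total_gb, gpu_memory_utilization):
    # Budget = full managed capacity × utilization, ignoring cudaMemGetInfo
    return int(managed_total_gb * 1024**3 * gpu_memory_utilization)

req = managed_request_memory(475, 0.90)
print(round(req / 1024**3, 1))  # 427.5
```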