Files
vllm/managed_alloc.cu
biondizzle 487dd34e04 Selective prefetch: only prefetch allocations <2 GiB to GPU
Model weights (small tensors) must be in HBM for cuBLAS GEMM ops
which can't page-fault into managed memory. KV cache blocks are
large and numerous — prefetching them all fills HBM and causes
OOM. The 2 GiB threshold separates compute data from cache data.
2026-04-10 14:58:57 +00:00

97 lines
4.1 KiB
Plaintext

// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
// Compatible with CUDA 13+ (uses cudaMemLocation API)
//
// Key design decisions for GH200 EGM:
// 1. cudaMallocManaged → allocations can page-fault across HBM + EGM
// 2. cudaMemAdviseSetPreferredLocation(GPU) → driver prefers keeping pages on GPU
// 3. cudaMemAdviseSetAccessedBy(CPU) → CPU can access over C2C NVLink without
// triggering page migration back to system RAM (critical: prevents OOM)
// 4. Selective prefetching — small allocations (model weights, <2 GiB)
// are prefetched to GPU so cuBLAS/cuDNN kernels can access them
// directly from HBM. Large allocations (KV cache blocks) stay in
// managed memory and page-fault on demand, since they're too large
// to fit in HBM and attention ops can tolerate page faults.
#include <cuda_runtime.h>
#include <stdio.h>
extern "C" {
// Reports a non-fatal CUDA runtime failure on stderr and resets the runtime's
// last-error state.
//
// A failed runtime call stays recorded until cudaGetLastError() is invoked.
// Without the reset, the next unrelated cudaGetLastError() check made by the
// host application (for example after a kernel launch) would report this
// already-handled failure as if it were its own.
//
//   call        - name of the CUDA call that failed (for the log line)
//   err         - error code that call returned
//   size        - allocation size in bytes (logged in GiB)
//   consequence - short description of how the allocator carries on
static void managed_alloc_soft_failure(const char* call, cudaError_t err,
                                       size_t size, const char* consequence) {
    fprintf(stderr, "[managed_alloc] %s warning: %s "
            "(size=%.2f GiB, %s)\n",
            call, cudaGetErrorString(err),
            (double)size / (1024.0*1024.0*1024.0), consequence);
    (void)cudaGetLastError();
}

// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
//
// Allocates `size` bytes of CUDA managed memory on `device`, applies the
// placement hints described in the file header, and (for allocations below
// 2 GiB) enqueues a prefetch to the GPU on `stream`.
//
// Parameters:
//   size   - number of bytes requested
//   device - CUDA device ordinal; made current before allocating
//   stream - stream the prefetch is enqueued on
//
// Returns the managed pointer, or nullptr if cudaSetDevice or
// cudaMallocManaged fails (the failure is logged to stderr). Failures of the
// advise/prefetch hints are logged and cleared, and do not fail the allocation.
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
    void* ptr = nullptr;

    // Set the device before allocating
    cudaError_t err = cudaSetDevice(device);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
                device, cudaGetErrorString(err));
        return nullptr;
    }

    // Use cudaMallocManaged - this is the key: allocations can page-fault
    // across HBM and LPDDR on GH200 with EGM enabled
    err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
                "(size=%zu bytes / %.2f GiB)\n",
                cudaGetErrorString(err), size, (double)size / (1024.0*1024.0*1024.0));
        return nullptr;
    }

    // CUDA 13+ uses cudaMemLocation struct instead of int for device
    cudaMemLocation gpu_loc;
    gpu_loc.type = cudaMemLocationTypeDevice;
    gpu_loc.id = device;

    // Advise: prefer GPU placement. On GH200 with EGM, the hardware will
    // migrate pages as needed, but the driver tries to keep them on GPU.
    // The hint is best-effort: a failure is reported but the allocation is
    // still returned.
    err = cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, gpu_loc);
    if (err != cudaSuccess) {
        managed_alloc_soft_failure("cudaMemAdvise(SetPreferredLocation)", err,
                                   size, "continuing without this hint");
    }

    // Advise: CPU will access this memory too. On GH200, this sets up
    // remote mapping over C2C NVLink so CPU can read/write without
    // triggering page migration back to system RAM. This is CRITICAL
    // to prevent OOM on EGM systems where most system RAM was carved
    // out for the GPU.
    cudaMemLocation cpu_loc;
    cpu_loc.type = cudaMemLocationTypeHost;
    cpu_loc.id = cudaCpuDeviceId;
    err = cudaMemAdvise(ptr, size, cudaMemAdviseSetAccessedBy, cpu_loc);
    if (err != cudaSuccess) {
        managed_alloc_soft_failure("cudaMemAdvise(SetAccessedBy)", err,
                                   size, "continuing without this hint");
    }

    // Selective prefetch: migrate pages to GPU for small allocations only.
    // Per this file's design notes, model weights (individual tensors) are
    // typically <2 GiB and are wanted in HBM for cuBLAS GEMM operations,
    // while KV cache blocks are large and numerous; prefetching them all
    // fills HBM and causes subsequent allocations to fail.
    // The 2 GiB threshold separates "compute data" from "cache data".
    const size_t PREFETCH_THRESHOLD = 2ULL * 1024 * 1024 * 1024;  // 2 GiB
    if (size > 0 && size < PREFETCH_THRESHOLD) {
        // Fourth argument is the flags word (0); the fifth is the stream.
        // Enqueue on the caller's stream so the prefetch is ordered with the
        // work the allocation was requested for, rather than on the legacy
        // default stream.
        err = cudaMemPrefetchAsync(ptr, size, gpu_loc, 0, stream);
        if (err != cudaSuccess) {
            // Non-fatal: prefetch failure shouldn't prevent allocation.
            // Pages will still be migrated on demand.
            managed_alloc_soft_failure("cudaMemPrefetchAsync", err, size,
                                       "will use on-demand migration");
        }
    }

    return ptr;
}
// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
//
// Frees `ptr` with cudaFree after synchronizing `stream` (when a non-null
// stream is supplied). A null `ptr` is a no-op. `size` and `device` are part
// of the allocator callback signature but are not needed here.
//
// Failures from the CUDA calls are logged to stderr; the function has no way
// to report them to the caller.
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
    (void)size;
    (void)device;

    if (ptr == nullptr) {
        return;
    }

    // Sync the stream before freeing to avoid use-after-free with
    // managed memory (in-flight page faults can race with deallocation).
    if (stream != nullptr) {
        cudaError_t sync_err = cudaStreamSynchronize(stream);
        if (sync_err != cudaSuccess) {
            fprintf(stderr, "[managed_alloc] cudaStreamSynchronize before free failed: %s\n",
                    cudaGetErrorString(sync_err));
        }
    }

    // Still attempt the free after a failed sync so the memory is not leaked.
    cudaError_t free_err = cudaFree(ptr);
    if (free_err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaFree(%p) failed: %s\n",
                ptr, cudaGetErrorString(free_err));
    }
}
} // extern "C"