Model weights (small tensors) must be resident in HBM for cuBLAS GEMM kernels, which cannot tolerate page faults into managed memory during execution. KV cache blocks are large and numerous — prefetching them all would fill HBM and cause OOM. The 2 GiB threshold separates compute data from cache data.
97 lines · 4.1 KiB · Plaintext
// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
// Compatible with CUDA 13+ (uses cudaMemLocation API)
//
// Key design decisions for GH200 EGM:
// 1. cudaMallocManaged → allocations can page-fault across HBM + EGM
// 2. cudaMemAdviseSetPreferredLocation(GPU) → driver prefers keeping pages on GPU
// 3. cudaMemAdviseSetAccessedBy(CPU) → CPU can access over C2C NVLink without
//    triggering page migration back to system RAM (critical: prevents OOM)
// 4. Selective prefetching — small allocations (model weights, <2 GiB)
//    are prefetched to GPU so cuBLAS/cuDNN kernels can access them
//    directly from HBM. Large allocations (KV cache blocks) stay in
//    managed memory and page-fault on demand, since they're too large
//    to fit in HBM and attention ops can tolerate page faults.
|
|
#include <cuda_runtime.h>
#include <stdio.h>

extern "C" {

// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
//
// Allocates `size` bytes of CUDA managed memory on `device`, advises the
// driver to prefer GPU residency while keeping the memory CPU-accessible,
// and prefetches small allocations (< 2 GiB) to the GPU on `stream`.
// Returns nullptr on failure (PyTorch treats that as an allocation failure).
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
    void* ptr = nullptr;

    // A zero-byte request has nothing to place; cudaMallocManaged would
    // reject it with cudaErrorInvalidValue, so short-circuit quietly
    // instead of logging a spurious error.
    if (size == 0) {
        return nullptr;
    }

    // Bind the calling thread to the target device before allocating.
    cudaError_t err = cudaSetDevice(device);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
                device, cudaGetErrorString(err));
        return nullptr;
    }

    // Use cudaMallocManaged - this is the key: allocations can page-fault
    // across HBM and LPDDR on GH200 with EGM enabled.
    err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
                "(size=%zu bytes / %.2f GiB)\n",
                cudaGetErrorString(err), size, (double)size / (1024.0*1024.0*1024.0));
        return nullptr;
    }

    // CUDA 13+ uses a cudaMemLocation struct instead of an int device id.
    // Zero-initialize so any fields added in future CUDA versions are 0.
    cudaMemLocation gpu_loc = {};
    gpu_loc.type = cudaMemLocationTypeDevice;
    gpu_loc.id = device;

    // Advise: prefer GPU placement. On GH200 with EGM, the hardware will
    // migrate pages as needed, but the driver tries to keep them on GPU.
    // Advice failures are non-fatal (the hint is an optimization), but
    // log them instead of silently discarding the status.
    err = cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, gpu_loc);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMemAdvise(SetPreferredLocation) warning: %s\n",
                cudaGetErrorString(err));
    }

    // Advise: CPU will access this memory too. On GH200, this sets up
    // remote mapping over C2C NVLink so CPU can read/write without
    // triggering page migration back to system RAM. This is CRITICAL
    // to prevent OOM on EGM systems where most system RAM was carved
    // out for the GPU.
    cudaMemLocation cpu_loc = {};
    cpu_loc.type = cudaMemLocationTypeHost;
    cpu_loc.id = cudaCpuDeviceId;
    err = cudaMemAdvise(ptr, size, cudaMemAdviseSetAccessedBy, cpu_loc);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMemAdvise(SetAccessedBy) warning: %s\n",
                cudaGetErrorString(err));
    }

    // Selective prefetch: migrate pages to GPU for small allocations only.
    // Model weights (individual tensors) are typically <2 GiB and MUST be
    // on GPU for cuBLAS GEMM operations — GPU compute kernels cannot
    // page-fault into managed memory during execution.
    // KV cache blocks are large and numerous; prefetching them all fills
    // HBM and causes subsequent allocations to fail.
    // The 2 GiB threshold separates "compute data" from "cache data".
    const size_t PREFETCH_THRESHOLD = 2ULL * 1024 * 1024 * 1024; // 2 GiB

    if (size < PREFETCH_THRESHOLD) {
        // CUDA 13 signature: (ptr, count, location, flags, stream).
        // Issue the prefetch on the allocator's stream so it is ordered
        // with the work PyTorch enqueues there, rather than passing 0 in
        // the flags slot and letting the stream argument default.
        err = cudaMemPrefetchAsync(ptr, size, gpu_loc, 0, stream);
        if (err != cudaSuccess) {
            // Non-fatal: prefetch failure shouldn't prevent allocation.
            // Pages will still be migrated on demand.
            fprintf(stderr, "[managed_alloc] cudaMemPrefetchAsync warning: %s "
                    "(size=%.2f GiB, will use on-demand migration)\n",
                    cudaGetErrorString(err), (double)size / (1024.0*1024.0*1024.0));
        }
    }

    return ptr;
}

// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
//
// Frees an allocation produced by managed_malloc. `size` and `device` are
// required by the PyTorch callback signature but are not needed by cudaFree.
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
    (void)size;   // unused: cudaFree tracks allocation sizes itself
    (void)device; // unused: cudaFree resolves the owning device from ptr
    if (ptr == nullptr) {
        return;
    }

    // Sync the stream before freeing to avoid use-after-free with
    // managed memory (in-flight page faults can race with deallocation).
    if (stream != nullptr) {
        cudaError_t sync_err = cudaStreamSynchronize(stream);
        if (sync_err != cudaSuccess) {
            fprintf(stderr, "[managed_alloc] cudaStreamSynchronize warning on free: %s\n",
                    cudaGetErrorString(sync_err));
        }
    }

    cudaError_t err = cudaFree(ptr);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaFree failed: %s\n",
                cudaGetErrorString(err));
    }
}

} // extern "C"