// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
// Compatible with CUDA 13+ (uses cudaMemLocation API)
//
// Key design decisions for GH200 EGM:
// 1. cudaMallocManaged -> allocations can page-fault across HBM + EGM
// 2. cudaMemAdviseSetPreferredLocation(GPU) -> driver prefers keeping pages on GPU
// 3. cudaMemAdviseSetAccessedBy(CPU) -> CPU can access over C2C NVLink without
//    triggering page migration back to system RAM (critical: prevents OOM)
// 4. NO prefetching - pages migrate on-demand via hardware page faults.
//    Eager prefetching fills HBM+EGM and causes subsequent allocations
//    to fail. On-demand migration is the correct behavior for unified
//    memory with HBM + LPDDR EGM.
// 5. Every CUDA error this file handles is also cleared from the calling
//    thread's "last error" slot (cudaGetLastError). Otherwise an error we
//    already dealt with (e.g. a failed cudaMallocManaged) would resurface
//    later in an unrelated check, such as PyTorch's post-kernel-launch
//    cudaGetLastError(), and be misreported as a kernel failure.

#include <cstdio>
#include <cuda_runtime.h>

// Reset the per-thread CUDA "last error" after a failure has been handled
// here. Sticky (context-corrupting) errors cannot be cleared this way and
// will still be reported by later calls, which is intended.
static void managed_alloc_clear_last_error() {
    (void)cudaGetLastError();
}

// Apply one cudaMemAdvise hint. A failure is logged and cleared but is not
// fatal: the original allocator always returned the allocation regardless
// of the hint outcome, and that contract is preserved.
static void managed_alloc_advise(void* ptr, size_t size,
                                 cudaMemoryAdvise advice,
                                 cudaMemLocation location,
                                 const char* label) {
    cudaError_t err = cudaMemAdvise(ptr, size, advice, location);
    if (err != cudaSuccess) {
        fprintf(stderr,
                "[managed_alloc] warning: cudaMemAdvise(%s) failed: %s "
                "(allocation kept, hint not applied)\n",
                label, cudaGetErrorString(err));
        managed_alloc_clear_last_error();
    }
}

extern "C" {

// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
//
// Allocates `size` bytes with cudaMallocManaged after making `device` the
// calling thread's current device, then applies the placement hints
// described in the file header.
//
// Returns the managed pointer, or nullptr when size == 0 or on failure
// (failures are logged to stderr).
// `stream` is unused: cudaMallocManaged is not stream-ordered.
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
    (void)stream;

    // cudaMallocManaged rejects size 0 with cudaErrorInvalidValue. Empty
    // tensors legitimately request 0 bytes, so hand back nullptr without
    // touching the runtime (no bogus error log, no stale last-error).
    if (size == 0) {
        return nullptr;
    }

    // Set the device before allocating
    cudaError_t err = cudaSetDevice(device);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
                device, cudaGetErrorString(err));
        managed_alloc_clear_last_error();
        return nullptr;
    }

    // Use cudaMallocManaged - this is the key: allocations can page-fault
    // across HBM and LPDDR on GH200 with EGM enabled
    void* ptr = nullptr;
    err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
                        "(size=%zu bytes / %.2f GiB)\n",
                cudaGetErrorString(err), size,
                (double)size / (1024.0 * 1024.0 * 1024.0));
        // The caller sees nullptr; do not let this OOM linger in the
        // last-error slot and get blamed on a later kernel launch.
        managed_alloc_clear_last_error();
        return nullptr;
    }

    // CUDA 13+ uses cudaMemLocation struct instead of int for device.
    // Advise: prefer GPU placement. On GH200 with EGM, the hardware will
    // migrate pages as needed, but the driver tries to keep them on GPU.
    cudaMemLocation gpu_loc{};
    gpu_loc.type = cudaMemLocationTypeDevice;
    gpu_loc.id = device;
    managed_alloc_advise(ptr, size, cudaMemAdviseSetPreferredLocation,
                         gpu_loc, "SetPreferredLocation(GPU)");

    // Advise: CPU will access this memory too. On GH200, this sets up
    // remote mapping over C2C NVLink so CPU can read/write without
    // triggering page migration back to system RAM. This is intended to
    // prevent OOM on EGM systems where most system RAM was carved out for
    // the GPU.
    cudaMemLocation cpu_loc{};
    cpu_loc.type = cudaMemLocationTypeHost;
    cpu_loc.id = cudaCpuDeviceId;  // id is ignored for host locations
    managed_alloc_advise(ptr, size, cudaMemAdviseSetAccessedBy,
                         cpu_loc, "SetAccessedBy(CPU)");

    // Deliberately NO cudaMemPrefetchAsync. Prefetching was observed to
    // cause allocation failures after model loading: it eagerly migrates
    // all pages to the GPU, filling HBM+EGM, after which further
    // cudaMallocManaged calls failed.
    //
    // The cudaMemAdviseSetPreferredLocation(GPU) hint above tells the
    // driver to prefer GPU placement while still allowing pages to live
    // elsewhere when HBM is full. Pages migrate on demand as they are
    // accessed, which is the desired behavior for a unified memory system
    // with 96 GiB HBM + 128+ GiB EGM.
    return ptr;
}

// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
//
// Frees memory returned by managed_malloc. A nullptr is ignored.
// `size` and `device` are unused.
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
    (void)size;
    (void)device;

    if (ptr == nullptr) {
        return;
    }

    // Sync the stream before freeing, so that no work still queued on it
    // can touch this memory after it is released.
    if (stream != nullptr) {
        cudaError_t sync_err = cudaStreamSynchronize(stream);
        if (sync_err != cudaSuccess) {
            fprintf(stderr,
                    "[managed_alloc] cudaStreamSynchronize before free failed: %s\n",
                    cudaGetErrorString(sync_err));
            managed_alloc_clear_last_error();
        }
    }

    cudaError_t free_err = cudaFree(ptr);
    if (free_err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaFree(%p) failed: %s\n",
                ptr, cudaGetErrorString(free_err));
        managed_alloc_clear_last_error();
    }
}

}  // extern "C"