// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
// Compatible with CUDA 13+ (uses cudaMemLocation API)
#include <cstdio>
#include <cuda_runtime.h>

extern "C" {

// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
//
// Allocates `size` bytes of CUDA managed (unified) memory with the target
// device selected first. Returns nullptr on any CUDA error, which PyTorch's
// pluggable allocator treats as an allocation failure. `stream` is part of
// the required signature but is not used for allocation.
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
    void* ptr = nullptr;

    // Select the device before allocating so the allocation is associated
    // with the requested GPU.
    cudaError_t err = cudaSetDevice(device);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
                device, cudaGetErrorString(err));
        return nullptr;
    }

    // Use cudaMallocManaged - this is the key: allocations can page-fault
    // across HBM and LPDDR on GH200 with EGM enabled.
    err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
    if (err != cudaSuccess) {
        fprintf(stderr,
                "[managed_alloc] cudaMallocManaged failed: %s "
                "(size=%zu bytes / %.2f GiB)\n",
                cudaGetErrorString(err), size,
                (double)size / (1024.0 * 1024.0 * 1024.0));
        return nullptr;
    }

    // Advise the driver to prefer GPU placement initially.
    // On GH200 with EGM, the hardware will migrate pages as needed.
    // CUDA 13+ uses cudaMemLocation struct instead of int for device.
    cudaMemLocation location;
    location.type = cudaMemLocationTypeDevice;
    location.id = device;
    err = cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, location);
    if (err != cudaSuccess) {
        // The advice is a performance hint only; the allocation remains
        // valid, so warn and hand the pointer back rather than failing.
        fprintf(stderr,
                "[managed_alloc] cudaMemAdvise failed (non-fatal): %s\n",
                cudaGetErrorString(err));
    }

    return ptr;
}

// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
//
// Frees a pointer previously returned by managed_malloc. `size` and `device`
// are part of the required signature but are not needed by cudaFree.
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
    if (ptr == nullptr) {
        return;
    }
    // Drain the stream first so we never free memory that in-flight work
    // on that stream may still be touching (use-after-free).
    if (stream != nullptr) {
        cudaStreamSynchronize(stream);
    }
    cudaFree(ptr);
}

}  // extern "C"