- managed_alloc.cu: PyTorch pluggable allocator using cudaMallocManaged
- vllm_managed_mem.py: Launcher that patches vLLM for managed memory
- Dockerfile: Build and install managed memory components

This enables vLLM to use cudaMallocManaged for transparent page-fault access to both HBM (~96 GiB) and LPDDR (EGM, up to 480 GiB additional) on GH200 systems with Extended GPU Memory enabled.

Experimental branch: v0.19.0-cmm
49 lines · 1.6 KiB
// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
|
|
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
|
|
#include <cuda_runtime.h>
|
|
#include <stdio.h>
|
|
|
|
extern "C" {
|
|
|
|
// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
//
// Allocates `size` bytes of CUDA managed (unified) memory on `device`.
// Returns nullptr on failure; PyTorch's pluggable allocator treats a null
// return as out-of-memory. The `stream` argument is part of the required
// callback signature but is unused here: cudaMallocManaged is not a
// stream-ordered allocation.
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
    (void)stream;  // unused: managed allocations are not stream-ordered

    // Guard the zero-size probe explicitly: cudaMallocManaged returns
    // cudaErrorInvalidValue for size 0, which would print a spurious error
    // and leave a stale last-error in the CUDA runtime.
    if (size == 0) {
        return nullptr;
    }

    void* ptr = nullptr;

    // Bind to the requested device before allocating.
    cudaError_t err = cudaSetDevice(device);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
                device, cudaGetErrorString(err));
        return nullptr;
    }

    // Use cudaMallocManaged - this is the key: allocations can page-fault
    // across HBM and LPDDR on GH200 with EGM enabled
    err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
                "(size=%zu bytes / %.2f GiB)\n",
                cudaGetErrorString(err), size, (double)size / (1024.0*1024.0*1024.0));
        return nullptr;
    }

    // Advise the driver to prefer GPU placement initially.
    // On GH200 with EGM, the hardware will migrate pages as needed.
    // This is a best-effort performance hint: log a failure (e.g. on a
    // device without the needed managed-memory support) but do not fail
    // the allocation over it.
    err = cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, device);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaMemAdvise(preferred=%d) failed "
                "(non-fatal): %s\n", device, cudaGetErrorString(err));
    }

    return ptr;
}
|
|
|
|
// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
//
// Frees an allocation previously returned by managed_malloc. `size` and
// `device` are required by the PyTorch callback signature but are not
// needed by cudaFree. A null `ptr` is a no-op.
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
    (void)size;    // unused: cudaFree does not take a size
    (void)device;  // unused: cudaFree frees on the allocation's device

    if (ptr == nullptr) {
        return;
    }

    cudaError_t err;

    // Sync the stream before freeing so in-flight work on it cannot touch
    // the buffer after release (use-after-free guard). Log failures rather
    // than swallowing them silently — a free path cannot propagate errors.
    if (stream != nullptr) {
        err = cudaStreamSynchronize(stream);
        if (err != cudaSuccess) {
            fprintf(stderr, "[managed_alloc] cudaStreamSynchronize failed "
                    "before free: %s\n", cudaGetErrorString(err));
        }
    }

    err = cudaFree(ptr);
    if (err != cudaSuccess) {
        fprintf(stderr, "[managed_alloc] cudaFree(%p) failed: %s\n",
                ptr, cudaGetErrorString(err));
    }
}
|
|
|
|
} // extern "C"
|