Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 013b73e9b2 | |||
| c77342da87 | |||
| 7f35bc4158 | |||
| 487dd34e04 | |||
| a15f86ecfa |
96
managed_alloc.cu
Normal file
96
managed_alloc.cu
Normal file
@@ -0,0 +1,96 @@
|
||||
// managed_alloc.cu - cudaMallocManaged allocator for PyTorch
|
||||
// Compile: nvcc -shared -o libmanaged_alloc.so managed_alloc.cu -Xcompiler -fPIC
|
||||
// Compatible with CUDA 13+ (uses cudaMemLocation API)
|
||||
//
|
||||
// Key design decisions for GH200 EGM:
|
||||
// 1. cudaMallocManaged → allocations can page-fault across HBM + EGM
|
||||
// 2. cudaMemAdviseSetPreferredLocation(GPU) → driver prefers keeping pages on GPU
|
||||
// 3. cudaMemAdviseSetAccessedBy(CPU) → CPU can access over C2C NVLink without
|
||||
// triggering page migration back to system RAM (critical: prevents OOM)
|
||||
// 4. Selective prefetching — small allocations (model weights, <2 GiB)
|
||||
// are prefetched to GPU so cuBLAS/cuDNN kernels can access them
|
||||
// directly from HBM. Large allocations (KV cache blocks) stay in
|
||||
// managed memory and page-fault on demand, since they're too large
|
||||
// to fit in HBM and attention ops can tolerate page faults.
|
||||
#include <cuda_runtime.h>
|
||||
#include <stdio.h>
|
||||
|
||||
extern "C" {
|
||||
|
||||
// PyTorch pluggable allocator signature: void*(size_t, int, cudaStream_t)
|
||||
void* managed_malloc(size_t size, int device, cudaStream_t stream) {
|
||||
void* ptr = nullptr;
|
||||
|
||||
// Set the device before allocating
|
||||
cudaError_t err = cudaSetDevice(device);
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "[managed_alloc] cudaSetDevice(%d) failed: %s\n",
|
||||
device, cudaGetErrorString(err));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Use cudaMallocManaged - this is the key: allocations can page-fault
|
||||
// across HBM and LPDDR on GH200 with EGM enabled
|
||||
err = cudaMallocManaged(&ptr, size, cudaMemAttachGlobal);
|
||||
if (err != cudaSuccess) {
|
||||
fprintf(stderr, "[managed_alloc] cudaMallocManaged failed: %s "
|
||||
"(size=%zu bytes / %.2f GiB)\n",
|
||||
cudaGetErrorString(err), size, (double)size / (1024.0*1024.0*1024.0));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// CUDA 13+ uses cudaMemLocation struct instead of int for device
|
||||
cudaMemLocation gpu_loc;
|
||||
gpu_loc.type = cudaMemLocationTypeDevice;
|
||||
gpu_loc.id = device;
|
||||
|
||||
// Advise: prefer GPU placement. On GH200 with EGM, the hardware will
|
||||
// migrate pages as needed, but the driver tries to keep them on GPU.
|
||||
cudaMemAdvise(ptr, size, cudaMemAdviseSetPreferredLocation, gpu_loc);
|
||||
|
||||
// Advise: CPU will access this memory too. On GH200, this sets up
|
||||
// remote mapping over C2C NVLink so CPU can read/write without
|
||||
// triggering page migration back to system RAM. This is CRITICAL
|
||||
// to prevent OOM on EGM systems where most system RAM was carved
|
||||
// out for the GPU.
|
||||
cudaMemLocation cpu_loc;
|
||||
cpu_loc.type = cudaMemLocationTypeHost;
|
||||
cpu_loc.id = cudaCpuDeviceId;
|
||||
cudaMemAdvise(ptr, size, cudaMemAdviseSetAccessedBy, cpu_loc);
|
||||
|
||||
// Selective prefetch: migrate pages to GPU for small allocations only.
|
||||
// Model weights (individual tensors) are typically <2 GiB and MUST be
|
||||
// on GPU for cuBLAS GEMM operations — GPU compute kernels cannot
|
||||
// page-fault into managed memory during execution.
|
||||
// KV cache blocks are large and numerous; prefetching them all fills
|
||||
// HBM and causes subsequent allocations to fail.
|
||||
// The 2 GiB threshold separates "compute data" from "cache data".
|
||||
const size_t PREFETCH_THRESHOLD = 2ULL * 1024 * 1024 * 1024; // 2 GiB
|
||||
|
||||
if (size > 0 && size < PREFETCH_THRESHOLD) {
|
||||
err = cudaMemPrefetchAsync(ptr, size, gpu_loc, 0);
|
||||
if (err != cudaSuccess) {
|
||||
// Non-fatal: prefetch failure shouldn't prevent allocation.
|
||||
// Pages will still be migrated on demand.
|
||||
fprintf(stderr, "[managed_alloc] cudaMemPrefetchAsync warning: %s "
|
||||
"(size=%.2f GiB, will use on-demand migration)\n",
|
||||
cudaGetErrorString(err), (double)size / (1024.0*1024.0*1024.0));
|
||||
}
|
||||
}
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
// PyTorch pluggable allocator signature: void(void*, size_t, int, cudaStream_t)
|
||||
void managed_free(void* ptr, size_t size, int device, cudaStream_t stream) {
|
||||
if (ptr != nullptr) {
|
||||
// Sync the stream before freeing to avoid use-after-free with
|
||||
// managed memory (in-flight page faults can race with deallocation).
|
||||
if (stream != nullptr) {
|
||||
cudaStreamSynchronize(stream);
|
||||
}
|
||||
cudaFree(ptr);
|
||||
}
|
||||
}
|
||||
|
||||
} // extern "C"
|
||||
@@ -185,6 +185,7 @@ if TYPE_CHECKING:
|
||||
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
|
||||
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
|
||||
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
|
||||
VLLM_KV_CACHE_USE_MANAGED_MEMORY: bool = False
|
||||
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
|
||||
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
|
||||
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[
|
||||
@@ -1378,6 +1379,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_KV_CACHE_LAYOUT": env_with_choices(
|
||||
"VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]
|
||||
),
|
||||
# On GH200 with EGM, allocate KV cache via cudaMallocManaged so it
|
||||
# can page-fault into LPDDR memory, enabling larger KV caches.
|
||||
"VLLM_KV_CACHE_USE_MANAGED_MEMORY": lambda: bool(
|
||||
int(os.getenv("VLLM_KV_CACHE_USE_MANAGED_MEMORY", "0"))
|
||||
),
|
||||
# Enable checking whether the generated logits contain NaNs,
|
||||
# indicating corrupted output. Useful for debugging low level bugs
|
||||
# or bad hardware but it may add compute overhead.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
import functools
|
||||
import gc
|
||||
import itertools
|
||||
@@ -212,6 +213,81 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
|
||||
"""Allocate KV cache memory using cudaMallocManaged for GH200 EGM.
|
||||
|
||||
This allows the KV cache to transparently page-fault between HBM and
|
||||
LPDDR, enabling much larger KV caches than HBM alone would allow.
|
||||
Model weights and compute intermediates remain in HBM via the default
|
||||
cudaMalloc — only the KV cache uses managed memory.
|
||||
|
||||
Key design decisions for KV cache (different from model weights):
|
||||
- No cudaMemAdviseSetPreferredLocation(GPU). The KV cache is too
|
||||
large for HBM (often 50-100+ GiB). Setting preferred location to
|
||||
GPU would cause the driver to try migrating the entire allocation
|
||||
to HBM, resulting in OOM. Pages page-fault to GPU on-demand during
|
||||
attention operations and are evicted back to LPDDR when HBM is
|
||||
needed.
|
||||
- Zeroing is done via CPU memset, NOT cudaMemset. cudaMemset executes
|
||||
on the device, which forces ALL pages to be migrated to GPU first —
|
||||
exactly what we're trying to avoid. CPU memset leaves pages in LPDDR.
|
||||
- Tensor wrapping uses __cuda_array_interface__ instead of the
|
||||
deprecated UntypedStorage.from_blob (removed in PyTorch 2.11+).
|
||||
"""
|
||||
cuda = ctypes.CDLL("libcudart.so")
|
||||
|
||||
# Allocate managed memory
|
||||
ptr = ctypes.c_void_p()
|
||||
err = cuda.cudaMallocManaged(ctypes.byref(ptr), ctypes.c_size_t(size),
|
||||
ctypes.c_uint(1)) # cudaMemAttachGlobal
|
||||
if err != 0:
|
||||
cuda.cudaGetErrorString.restype = ctypes.c_char_p
|
||||
err_str = cuda.cudaGetErrorString(err)
|
||||
raise RuntimeError(
|
||||
f"cudaMallocManaged failed for KV cache ({size} bytes / "
|
||||
f"{size / 1024**3:.2f} GiB): {err_str.decode()}")
|
||||
|
||||
# Zero out the managed memory via CPU, NOT via cudaMemset.
|
||||
# cudaMemset runs on the device, which forces ALL pages to be migrated
|
||||
# to GPU before zeroing — defeating the entire purpose of keeping the
|
||||
# KV cache in LPDDR. Using ctypes.memset (CPU) zeroes the pages while
|
||||
# they remain in LPDDR/EGM. The pages will be lazily migrated to GPU
|
||||
# only when the attention kernel actually reads them.
|
||||
ctypes.memset(ptr, 0, size)
|
||||
|
||||
# Wrap the managed memory pointer as a PyTorch tensor using
|
||||
# __cuda_array_interface__. This is the standard mechanism for creating
|
||||
# tensors from external CUDA pointers, and works across PyTorch versions
|
||||
# (UntypedStorage.from_blob was removed in PyTorch 2.11+).
|
||||
# The tensor does NOT own the memory — we handle cudaFree ourselves
|
||||
# when the KV cache is destroyed (process lifetime).
|
||||
class _ManagedMemoryWrapper:
|
||||
"""Wrapper exposing __cuda_array_interface__ for a cudaMallocManaged pointer."""
|
||||
def __init__(self, ptr_value, nbytes):
|
||||
self._ptr = ptr_value
|
||||
self._nbytes = nbytes
|
||||
|
||||
@property
|
||||
def __cuda_array_interface__(self):
|
||||
return {
|
||||
"data": (self._ptr, False), # (ptr, readonly)
|
||||
"shape": (self._nbytes,),
|
||||
"strides": None,
|
||||
"typestr": "|i1", # int8 bytes
|
||||
"version": 3,
|
||||
}
|
||||
|
||||
wrapper = _ManagedMemoryWrapper(ptr.value, size)
|
||||
tensor = torch.as_tensor(wrapper, device=device)
|
||||
|
||||
logger.info("Allocated KV cache via cudaMallocManaged: %.2f GiB "
|
||||
"(CPU memset, on-demand page-fault to GPU)",
|
||||
size / 1024**3)
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
|
||||
# list when ubatching is enabled
|
||||
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
|
||||
@@ -6526,9 +6602,17 @@ class GPUModelRunner(
|
||||
"""
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(
|
||||
kv_cache_tensor.size, dtype=torch.int8, device=self.device
|
||||
)
|
||||
# On GH200 with EGM, allocate KV cache via cudaMallocManaged
|
||||
# so it can spill into LPDDR memory. This is controlled by
|
||||
# the VLLM_KV_CACHE_USE_MANAGED_MEMORY env var.
|
||||
if envs.VLLM_KV_CACHE_USE_MANAGED_MEMORY:
|
||||
tensor = _allocate_managed_kv_cache(
|
||||
kv_cache_tensor.size, self.device)
|
||||
else:
|
||||
tensor = torch.zeros(
|
||||
kv_cache_tensor.size, dtype=torch.int8,
|
||||
device=self.device
|
||||
)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
|
||||
Reference in New Issue
Block a user