diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 29303ace5..45bb7da5e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -221,9 +221,19 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor: LPDDR, enabling much larger KV caches than HBM alone would allow. Model weights and compute intermediates remain in HBM via the default cudaMalloc — only the KV cache uses managed memory. - """ - import struct + Key design decisions for KV cache (different from model weights): + - Preferred location is CPU (LPDDR), NOT GPU. The KV cache is too + large to fit in HBM (often 50-100+ GiB). Setting preferred location + to GPU would cause the driver to try migrating the entire allocation + to HBM, resulting in OOM. Instead, let pages reside in LPDDR and + page-fault to GPU on-demand during attention operations. + - CPU access is set via AccessedBy to avoid page migration back to + system RAM on CPU reads (GH200 uses C2C NVLink for remote access). + - Zeroing is done via CPU memset, NOT cudaMemset. cudaMemset executes + on the device, which forces all pages to be migrated to GPU first — + exactly what we're trying to avoid. CPU memset leaves pages in LPDDR. + """ cuda = ctypes.CDLL("libcudart.so") # Allocate managed memory @@ -239,21 +249,47 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor: dev_idx = device.index if device.index is not None else 0 - # cudaMemAdvise: prefer GPU placement (cudaMemAdviseSetPreferredLocation=3) - # cudaMemLocation struct: {type=0(device), id=dev_idx} - gpu_loc = (ctypes.c_int * 2)(0, dev_idx) - cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size), ctypes.c_int(3), gpu_loc) - - # cudaMemAdvise: CPU access without migration (cudaMemAdviseSetAccessedBy=9) + # cudaMemAdvise: prefer CPU placement (cudaMemAdviseSetPreferredLocation=3). 
+ # Unlike model weights (which are small and need to be on GPU for GEMM), + # the KV cache is far too large for HBM. Setting preferred location to + # CPU keeps the pages in LPDDR/EGM. They will page-fault to GPU only + # when the attention kernel accesses them, and the driver will evict + # them back to LPDDR when HBM is needed for other allocations. # cudaMemLocation struct: {type=1(host), id=-1(cudaCpuDeviceId)} cpu_loc = (ctypes.c_int * 2)(1, -1) - cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size), ctypes.c_int(9), cpu_loc) + advise_err = cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size), + ctypes.c_int(3), cpu_loc) + if advise_err != 0: + logger.warning("cudaMemAdvise SetPreferredLocation(CPU) failed " + "(err=%d), KV cache pages may not stay in LPDDR", + advise_err) - # Zero out the managed memory (memset on device) - cuda.cudaMemset(ptr, ctypes.c_int(0), ctypes.c_size_t(size)) + # cudaMemAdvise: GPU access without migration (cudaMemAdviseSetAccessedBy=9). + # This tells the driver that the GPU will read these pages, but should + # NOT migrate them to GPU on access — instead, the GPU accesses them + # remotely over C2C NVLink. This prevents the KV cache from evicting + # model weights and compute intermediates from HBM. + # cudaMemLocation struct: {type=0(device), id=dev_idx} + gpu_loc = (ctypes.c_int * 2)(0, dev_idx) + advise_err = cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size), + ctypes.c_int(9), gpu_loc) + if advise_err != 0: + logger.warning("cudaMemAdvise SetAccessedBy(GPU) failed " + "(err=%d), GPU may trigger page migrations", advise_err) - # Wrap as a PyTorch tensor using UntypedStorage.from_blob - # This creates a storage that points to our managed memory + # Zero out the managed memory via CPU, NOT via cudaMemset. + # cudaMemset runs on the device, which forces ALL pages to be migrated + # to GPU before zeroing — defeating the entire purpose of keeping the + # KV cache in LPDDR. Using ctypes.memset (CPU) zeroes the pages while + # they remain in LPDDR/EGM. 
The pages will be lazily migrated to GPU + # only when the attention kernel actually reads them. + ctypes.memset(ptr, 0, size) + + # Wrap as a PyTorch tensor using UntypedStorage.from_blob. + # We use from_blob with a no-op allocator so PyTorch doesn't try to + # free the managed memory through its own allocator — we manage the + # lifetime ourselves. NOTE: because the allocator is a no-op, nothing here + # calls cudaFree when the tensor is released; free the buffer explicitly. storage = torch.UntypedStorage.from_blob( ptr, size_bytes=size, allocator=lambda: None # no-op allocator, we manage lifetime @@ -261,7 +297,8 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor: tensor = torch.tensor([], dtype=torch.int8, device=device) tensor.set_(storage, 0, [size]) - logger.info("Allocated KV cache via cudaMallocManaged: %.2f GiB", + logger.info("Allocated KV cache via cudaMallocManaged: %.2f GiB " + "(preferred=CPU, on-demand page-fault to GPU)", size / 1024**3) return tensor