Targeted KV cache managed memory allocation
Instead of swapping the global CUDA allocator (which broke cuBLAS), allocate KV cache via cudaMallocManaged directly in _allocate_kv_cache_tensors(). Controlled by VLLM_KV_CACHE_USE_MANAGED_MEMORY env var. Model weights and compute intermediates stay in HBM via default cudaMalloc. Only KV cache spills into EGM/LPDDR.
This commit is contained in:
@@ -185,6 +185,7 @@ if TYPE_CHECKING:
|
||||
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
|
||||
VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS: int = 300
|
||||
VLLM_KV_CACHE_LAYOUT: Literal["NHD", "HND"] | None = None
|
||||
VLLM_KV_CACHE_USE_MANAGED_MEMORY: bool = False
|
||||
VLLM_COMPUTE_NANS_IN_LOGITS: bool = False
|
||||
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
|
||||
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: Literal[
|
||||
@@ -1378,6 +1379,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_KV_CACHE_LAYOUT": env_with_choices(
|
||||
"VLLM_KV_CACHE_LAYOUT", None, ["NHD", "HND"]
|
||||
),
|
||||
# On GH200 with EGM, allocate KV cache via cudaMallocManaged so it
|
||||
# can page-fault into LPDDR memory, enabling larger KV caches.
|
||||
"VLLM_KV_CACHE_USE_MANAGED_MEMORY": lambda: bool(
|
||||
int(os.getenv("VLLM_KV_CACHE_USE_MANAGED_MEMORY", "0"))
|
||||
),
|
||||
# Enable checking whether the generated logits contain NaNs,
|
||||
# indicating corrupted output. Useful for debugging low level bugs
|
||||
# or bad hardware but it may add compute overhead.
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ctypes
|
||||
import functools
|
||||
import gc
|
||||
import itertools
|
||||
@@ -212,6 +213,60 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
    """Allocate a KV cache buffer via ``cudaMallocManaged`` (GH200 EGM).

    Managed (unified) memory can transparently page-fault between HBM and
    the Grace LPDDR, enabling much larger KV caches than HBM alone would
    allow. Model weights and compute intermediates keep using the default
    ``cudaMalloc`` path — only the KV cache goes through this allocator.

    Args:
        size: Allocation size in bytes.
        device: CUDA device the returned tensor is associated with.

    Returns:
        A 1-D ``int8`` CUDA tensor of ``size`` elements viewing the managed
        allocation. Ownership transfers to the tensor: the memory is
        ``cudaFree``d when the tensor is garbage collected.

    Raises:
        RuntimeError: If ``cudaMallocManaged`` or ``cudaMemset`` fails.
    """
    cudart = ctypes.CDLL("libcudart.so")

    ptr = ctypes.c_void_p()
    # cudaMemAttachGlobal == 1: accessible from any stream on any device.
    err = cudart.cudaMallocManaged(
        ctypes.byref(ptr), ctypes.c_size_t(size), ctypes.c_uint(1))
    if err != 0:
        raise RuntimeError(
            f"cudaMallocManaged failed for KV cache ({size} bytes / "
            f"{size / 1024**3:.2f} GiB): {_cuda_error_str(cudart, err)}")

    dev_idx = device.index if device.index is not None else 0

    # Advise the driver to keep the pages resident on the GPU when possible,
    # and to establish a CPU mapping so host touches do not force migration.
    # NOTE: the legacy cudaMemAdvise signature takes a plain int device id
    # as its last argument (cudaMemAdviseSetPreferredLocation == 3,
    # cudaMemAdviseSetAccessedBy == 5, cudaCpuDeviceId == -1); the
    # cudaMemLocation struct is only used by cudaMemAdvise_v2 (CUDA 12.2+).
    # Advice is a performance hint, so failures here are non-fatal.
    cudart.cudaMemAdvise(
        ptr, ctypes.c_size_t(size), ctypes.c_int(3), ctypes.c_int(dev_idx))
    cudart.cudaMemAdvise(
        ptr, ctypes.c_size_t(size), ctypes.c_int(5), ctypes.c_int(-1))

    # Zero the buffer on device to match the torch.zeros() semantics of the
    # non-managed allocation path.
    err = cudart.cudaMemset(ptr, ctypes.c_int(0), ctypes.c_size_t(size))
    if err != 0:
        cudart.cudaFree(ptr)
        raise RuntimeError(
            "cudaMemset failed while zeroing managed KV cache: "
            f"{_cuda_error_str(cudart, err)}")

    try:
        tensor = _wrap_device_ptr_as_tensor(cudart, ptr.value, size, dev_idx)
    except Exception:
        _DLPACK_KEEPALIVE.pop(ptr.value, None)
        cudart.cudaFree(ptr)
        raise

    logger.info("Allocated KV cache via cudaMallocManaged: %.2f GiB",
                size / 1024**3)

    return tensor


def _cuda_error_str(cudart: ctypes.CDLL, err: int) -> str:
    """Return the human-readable message for a cudaError_t code."""
    cudart.cudaGetErrorString.restype = ctypes.c_char_p
    msg = cudart.cudaGetErrorString(ctypes.c_int(err))
    return msg.decode() if msg else f"cudaError_t({err})"


# --- Minimal DLPack scaffolding -------------------------------------------
# torch.UntypedStorage has no public ``from_blob`` in the Python API, so a
# raw device pointer is handed to PyTorch as a DLPack capsule instead, which
# is the supported zero-copy import path (torch.from_dlpack).


class _DLDevice(ctypes.Structure):
    # DLDevice: {device_type, device_id}; kDLCUDA == 2.
    _fields_ = [("device_type", ctypes.c_int32),
                ("device_id", ctypes.c_int32)]


class _DLDataType(ctypes.Structure):
    # DLDataType: {code, bits, lanes}; kDLInt == 0.
    _fields_ = [("code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]


class _DLTensor(ctypes.Structure):
    _fields_ = [("data", ctypes.c_void_p),
                ("device", _DLDevice),
                ("ndim", ctypes.c_int32),
                ("dtype", _DLDataType),
                ("shape", ctypes.POINTER(ctypes.c_int64)),
                ("strides", ctypes.POINTER(ctypes.c_int64)),
                ("byte_offset", ctypes.c_uint64)]


_DLPACK_DELETER = ctypes.CFUNCTYPE(None, ctypes.c_void_p)


class _DLManagedTensor(ctypes.Structure):
    _fields_ = [("dl_tensor", _DLTensor),
                ("manager_ctx", ctypes.c_void_p),
                ("deleter", _DLPACK_DELETER)]


# Keeps each allocation's ctypes structs and deleter callback alive until
# PyTorch invokes the deleter. Module-global on purpose: local reference
# cycles could otherwise be reclaimed by the gc while the tensor still
# points at the struct.
_DLPACK_KEEPALIVE: dict[int, tuple] = {}


def _wrap_device_ptr_as_tensor(cudart: ctypes.CDLL, ptr: int, size: int,
                               dev_idx: int) -> torch.Tensor:
    """Wrap a raw CUDA(-managed) pointer as a 1-D int8 CUDA tensor.

    Ownership transfers to the returned tensor: when PyTorch destroys it,
    the DLPack deleter runs and ``cudaFree``s the pointer.
    """
    shape = (ctypes.c_int64 * 1)(size)
    managed = _DLManagedTensor()
    managed.dl_tensor.data = ctypes.c_void_p(ptr)
    managed.dl_tensor.device = _DLDevice(2, dev_idx)  # kDLCUDA
    managed.dl_tensor.ndim = 1
    managed.dl_tensor.dtype = _DLDataType(0, 8, 1)  # signed int, 8 bits
    managed.dl_tensor.shape = ctypes.cast(
        shape, ctypes.POINTER(ctypes.c_int64))
    managed.dl_tensor.strides = None  # NULL => compact row-major
    managed.dl_tensor.byte_offset = 0

    def _deleter(_handle: int) -> None:
        # Called by PyTorch when the tensor dies; ctypes re-acquires the
        # GIL for us regardless of the calling thread.
        cudart.cudaFree(ctypes.c_void_p(ptr))
        _DLPACK_KEEPALIVE.pop(ptr, None)

    c_deleter = _DLPACK_DELETER(_deleter)
    managed.manager_ctx = None
    managed.deleter = c_deleter
    _DLPACK_KEEPALIVE[ptr] = (managed, shape, c_deleter)

    pyapi = ctypes.pythonapi
    pyapi.PyCapsule_New.restype = ctypes.py_object
    pyapi.PyCapsule_New.argtypes = [
        ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
    capsule = pyapi.PyCapsule_New(
        ctypes.addressof(managed), b"dltensor", None)
    # torch.from_dlpack accepts legacy "dltensor" capsules directly.
    return torch.from_dlpack(capsule)
|
||||
|
||||
|
||||
# Mapping from attention layer name to that layer's metadata for one
# forward pass.
AttnMetadataDict: TypeAlias = dict[str, AttentionMetadata]
# Per-layer attention metadata; becomes a list of such dicts (one per
# micro-batch) when ubatching is enabled, otherwise a single dict.
PerLayerAttnMetadata: TypeAlias = list[AttnMetadataDict] | AttnMetadataDict
|
||||
@@ -6526,8 +6581,16 @@ class GPUModelRunner(
|
||||
"""
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
# On GH200 with EGM, allocate KV cache via cudaMallocManaged
|
||||
# so it can spill into LPDDR memory. This is controlled by
|
||||
# the VLLM_KV_CACHE_USE_MANAGED_MEMORY env var.
|
||||
if envs.VLLM_KV_CACHE_USE_MANAGED_MEMORY:
|
||||
tensor = _allocate_managed_kv_cache(
|
||||
kv_cache_tensor.size, self.device)
|
||||
else:
|
||||
                tensor = torch.zeros(
                    kv_cache_tensor.size, dtype=torch.int8,
                    device=self.device
                )
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
Reference in New Issue
Block a user