Fix managed KV cache: use __cuda_array_interface__ instead of UntypedStorage.from_blob
UntypedStorage.from_blob was removed in PyTorch 2.11+. Use the standard __cuda_array_interface__ protocol to wrap cudaMallocManaged pointers into PyTorch tensors — this works across all PyTorch versions. Also removed cudaMemAdvise calls — ctypes struct passing for cudaMemLocation is broken on ARM64 (returns EINVAL). The advise hints are optional; pages will page-fault to GPU on-demand regardless. CPU memset (ctypes.memset) is still used instead of cudaMemset to avoid forcing all pages into HBM during zeroing.
This commit is contained in:
@@ -223,16 +223,17 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
|
||||
cudaMalloc — only the KV cache uses managed memory.
|
||||
|
||||
Key design decisions for KV cache (different from model weights):
|
||||
- Preferred location is CPU (LPDDR), NOT GPU. The KV cache is too
|
||||
large to fit in HBM (often 50-100+ GiB). Setting preferred location
|
||||
to GPU would cause the driver to try migrating the entire allocation
|
||||
to HBM, resulting in OOM. Instead, let pages reside in LPDDR and
|
||||
page-fault to GPU on-demand during attention operations.
|
||||
- CPU access is set via AccessedBy to avoid page migration back to
|
||||
system RAM on CPU reads (GH200 uses C2C NVLink for remote access).
|
||||
- No cudaMemAdviseSetPreferredLocation(GPU). The KV cache is too
|
||||
large for HBM (often 50-100+ GiB). Setting preferred location to
|
||||
GPU would cause the driver to try migrating the entire allocation
|
||||
to HBM, resulting in OOM. Pages page-fault to GPU on-demand during
|
||||
attention operations and are evicted back to LPDDR when HBM is
|
||||
needed.
|
||||
- Zeroing is done via CPU memset, NOT cudaMemset. cudaMemset executes
on the device, which forces ALL pages to be migrated to GPU first —
exactly what we're trying to avoid. CPU memset leaves pages in LPDDR.
|
||||
- Tensor wrapping uses __cuda_array_interface__ instead of the
|
||||
deprecated UntypedStorage.from_blob (removed in PyTorch 2.11+).
|
||||
"""
|
||||
cuda = ctypes.CDLL("libcudart.so")
|
||||
|
||||
@@ -247,36 +248,6 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
|
||||
f"cudaMallocManaged failed for KV cache ({size} bytes / "
|
||||
f"{size / 1024**3:.2f} GiB): {err_str.decode()}")
|
||||
|
||||
dev_idx = device.index if device.index is not None else 0
|
||||
|
||||
# cudaMemAdvise: prefer CPU placement (cudaMemAdviseSetPreferredLocation=3).
|
||||
# Unlike model weights (which are small and need to be on GPU for GEMM),
|
||||
# the KV cache is far too large for HBM. Setting preferred location to
|
||||
# CPU keeps the pages in LPDDR/EGM. They will page-fault to GPU only
|
||||
# when the attention kernel accesses them, and the driver will evict
|
||||
# them back to LPDDR when HBM is needed for other allocations.
|
||||
# cudaMemLocation struct: {type=1(host), id=-1(cudaCpuDeviceId)}
|
||||
cpu_loc = (ctypes.c_int * 2)(1, -1)
|
||||
advise_err = cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size),
|
||||
ctypes.c_int(3), cpu_loc)
|
||||
if advise_err != 0:
|
||||
logger.warning("cudaMemAdvise SetPreferredLocation(CPU) failed "
|
||||
"(err=%d), KV cache pages may not stay in LPDDR",
|
||||
advise_err)
|
||||
|
||||
# cudaMemAdvise: GPU access without migration (cudaMemAdviseSetAccessedBy=9).
|
||||
# This tells the driver that the GPU will read these pages, but should
|
||||
# NOT migrate them to GPU on access — instead, the GPU accesses them
|
||||
# remotely over C2C NVLink. This prevents the KV cache from evicting
|
||||
# model weights and compute intermediates from HBM.
|
||||
# cudaMemLocation struct: {type=0(device), id=dev_idx}
|
||||
gpu_loc = (ctypes.c_int * 2)(0, dev_idx)
|
||||
advise_err = cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size),
|
||||
ctypes.c_int(9), gpu_loc)
|
||||
if advise_err != 0:
|
||||
logger.warning("cudaMemAdvise SetAccessedBy(GPU) failed "
|
||||
"(err=%d), GPU may trigger page migrations", advise_err)
|
||||
|
||||
# Zero out the managed memory via CPU, NOT via cudaMemset.
|
||||
# cudaMemset runs on the device, which forces ALL pages to be migrated
|
||||
# to GPU before zeroing — defeating the entire purpose of keeping the
|
||||
@@ -285,20 +256,33 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
|
||||
# only when the attention kernel actually reads them.
|
||||
ctypes.memset(ptr, 0, size)
|
||||
|
||||
# Wrap as a PyTorch tensor using UntypedStorage.from_blob.
|
||||
# We use from_blob with a no-op allocator so PyTorch doesn't try to
|
||||
# free the managed memory through its own allocator — we manage the
|
||||
# lifetime ourselves (cudaFree happens when the tensor is garbage
|
||||
# collected via the normal reference counting path).
|
||||
storage = torch.UntypedStorage.from_blob(
|
||||
ptr, size_bytes=size,
|
||||
allocator=lambda: None # no-op allocator, we manage lifetime
|
||||
)
|
||||
tensor = torch.tensor([], dtype=torch.int8, device=device)
|
||||
tensor.set_(storage, 0, [size])
|
||||
# Wrap the managed memory pointer as a PyTorch tensor using
|
||||
# __cuda_array_interface__. This is the standard mechanism for creating
|
||||
# tensors from external CUDA pointers, and works across PyTorch versions
|
||||
# (UntypedStorage.from_blob was removed in PyTorch 2.11+).
|
||||
# The tensor does NOT own the memory — we handle cudaFree ourselves
|
||||
# when the KV cache is destroyed (process lifetime).
|
||||
class _ManagedMemoryWrapper:
|
||||
"""Wrapper exposing __cuda_array_interface__ for a cudaMallocManaged pointer."""
|
||||
def __init__(self, ptr_value, nbytes):
|
||||
self._ptr = ptr_value
|
||||
self._nbytes = nbytes
|
||||
|
||||
@property
|
||||
def __cuda_array_interface__(self):
|
||||
return {
|
||||
"data": (self._ptr, False), # (ptr, readonly)
|
||||
"shape": (self._nbytes,),
|
||||
"strides": None,
|
||||
"typestr": "|i1", # int8 bytes
|
||||
"version": 3,
|
||||
}
|
||||
|
||||
wrapper = _ManagedMemoryWrapper(ptr.value, size)
|
||||
tensor = torch.as_tensor(wrapper, device=device)
|
||||
|
||||
logger.info("Allocated KV cache via cudaMallocManaged: %.2f GiB "
|
||||
"(preferred=CPU, on-demand page-fault to GPU)",
|
||||
"(CPU memset, on-demand page-fault to GPU)",
|
||||
size / 1024**3)
|
||||
|
||||
return tensor
|
||||
|
||||
Reference in New Issue
Block a user