Fix managed KV cache: use __cuda_array_interface__ instead of UntypedStorage.from_blob

UntypedStorage.from_blob was removed in PyTorch 2.11+. Use the
standard __cuda_array_interface__ protocol to wrap cudaMallocManaged
pointers into PyTorch tensors — this works across all PyTorch versions.

Also removed cudaMemAdvise calls — ctypes struct passing for
cudaMemLocation is broken on ARM64 (returns EINVAL). The advise hints
are optional; pages will page-fault to GPU on-demand regardless.

CPU memset (ctypes.memset) is still used instead of cudaMemset to
avoid forcing all pages into HBM during zeroing.
This commit is contained in:
2026-04-12 06:56:52 +00:00
parent c77342da87
commit 013b73e9b2

View File

@@ -223,16 +223,17 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
cudaMalloc — only the KV cache uses managed memory.
Key design decisions for KV cache (different from model weights):
- Preferred location is CPU (LPDDR), NOT GPU. The KV cache is too
large to fit in HBM (often 50-100+ GiB). Setting preferred location
to GPU would cause the driver to try migrating the entire allocation
to HBM, resulting in OOM. Instead, let pages reside in LPDDR and
page-fault to GPU on-demand during attention operations.
- CPU access is set via AccessedBy to avoid page migration back to
system RAM on CPU reads (GH200 uses C2C NVLink for remote access).
- No cudaMemAdviseSetPreferredLocation(GPU). The KV cache is too
large for HBM (often 50-100+ GiB). Setting preferred location to
GPU would cause the driver to try migrating the entire allocation
to HBM, resulting in OOM. Pages page-fault to GPU on-demand during
attention operations and are evicted back to LPDDR when HBM is
needed.
- Zeroing is done via CPU memset, NOT cudaMemset. cudaMemset executes
on the device, which forces all pages to be migrated to GPU first —
on the device, which forces ALL pages to be migrated to GPU first —
exactly what we're trying to avoid. CPU memset leaves pages in LPDDR.
- Tensor wrapping uses __cuda_array_interface__ instead of the
deprecated UntypedStorage.from_blob (removed in PyTorch 2.11+).
"""
cuda = ctypes.CDLL("libcudart.so")
@@ -247,36 +248,6 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
f"cudaMallocManaged failed for KV cache ({size} bytes / "
f"{size / 1024**3:.2f} GiB): {err_str.decode()}")
dev_idx = device.index if device.index is not None else 0
# cudaMemAdvise: prefer CPU placement (cudaMemAdviseSetPreferredLocation=3).
# Unlike model weights (which are small and need to be on GPU for GEMM),
# the KV cache is far too large for HBM. Setting preferred location to
# CPU keeps the pages in LPDDR/EGM. They will page-fault to GPU only
# when the attention kernel accesses them, and the driver will evict
# them back to LPDDR when HBM is needed for other allocations.
# cudaMemLocation struct: {type=1(host), id=-1(cudaCpuDeviceId)}
cpu_loc = (ctypes.c_int * 2)(1, -1)
advise_err = cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size),
ctypes.c_int(3), cpu_loc)
if advise_err != 0:
logger.warning("cudaMemAdvise SetPreferredLocation(CPU) failed "
"(err=%d), KV cache pages may not stay in LPDDR",
advise_err)
# cudaMemAdvise: GPU access without migration (cudaMemAdviseSetAccessedBy=9).
# This tells the driver that the GPU will read these pages, but should
# NOT migrate them to GPU on access — instead, the GPU accesses them
# remotely over C2C NVLink. This prevents the KV cache from evicting
# model weights and compute intermediates from HBM.
# cudaMemLocation struct: {type=0(device), id=dev_idx}
gpu_loc = (ctypes.c_int * 2)(0, dev_idx)
advise_err = cuda.cudaMemAdvise(ptr, ctypes.c_size_t(size),
ctypes.c_int(9), gpu_loc)
if advise_err != 0:
logger.warning("cudaMemAdvise SetAccessedBy(GPU) failed "
"(err=%d), GPU may trigger page migrations", advise_err)
# Zero out the managed memory via CPU, NOT via cudaMemset.
# cudaMemset runs on the device, which forces ALL pages to be migrated
# to GPU before zeroing — defeating the entire purpose of keeping the
@@ -285,20 +256,33 @@ def _allocate_managed_kv_cache(size: int, device: torch.device) -> torch.Tensor:
# only when the attention kernel actually reads them.
ctypes.memset(ptr, 0, size)
# Wrap as a PyTorch tensor using UntypedStorage.from_blob.
# We use from_blob with a no-op allocator so PyTorch doesn't try to
# free the managed memory through its own allocator — we manage the
# lifetime ourselves (cudaFree happens when the tensor is garbage
# collected via the normal reference counting path).
storage = torch.UntypedStorage.from_blob(
ptr, size_bytes=size,
allocator=lambda: None # no-op allocator, we manage lifetime
)
tensor = torch.tensor([], dtype=torch.int8, device=device)
tensor.set_(storage, 0, [size])
# Wrap the managed memory pointer as a PyTorch tensor using
# __cuda_array_interface__. This is the standard mechanism for creating
# tensors from external CUDA pointers, and works across PyTorch versions
# (UntypedStorage.from_blob was removed in PyTorch 2.11+).
# The tensor does NOT own the memory — we handle cudaFree ourselves
# when the KV cache is destroyed (process lifetime).
class _ManagedMemoryWrapper:
"""Wrapper exposing __cuda_array_interface__ for a cudaMallocManaged pointer."""
def __init__(self, ptr_value, nbytes):
self._ptr = ptr_value
self._nbytes = nbytes
@property
def __cuda_array_interface__(self):
return {
"data": (self._ptr, False), # (ptr, readonly)
"shape": (self._nbytes,),
"strides": None,
"typestr": "|i1", # int8 bytes
"version": 3,
}
wrapper = _ManagedMemoryWrapper(ptr.value, size)
tensor = torch.as_tensor(wrapper, device=device)
logger.info("Allocated KV cache via cudaMallocManaged: %.2f GiB "
"(preferred=CPU, on-demand page-fault to GPU)",
"(CPU memset, on-demand page-fault to GPU)",
size / 1024**3)
return tensor