Remove global allocator swap, use targeted KV cache managed allocation

sitecustomize.py: No longer installs a CUDAPluggableAllocator as the global CUDA allocator.
Sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 instead, so only the KV cache uses cudaMallocManaged.
vllm_managed_mem.py: No global allocator swap, no torch.cuda patches.
This commit is contained in:
2026-04-11 02:15:09 +00:00
parent 07468031db
commit bcc872c2c3
2 changed files with 25 additions and 72 deletions

View File

@@ -1,10 +1,14 @@
"""
sitecustomize.py - Auto-loaded by Python before any other imports.
Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator.
Also swaps the allocator to cudaMallocManaged and patches MemorySnapshot.
This MUST run before any torch.cuda calls in any subprocess.
Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode).
In CMM mode (MANAGED_MEMORY_TOTAL_GB set):
- Patches MemorySnapshot.measure() to report managed memory capacity
- Patches request_memory to calculate KV cache size based on managed memory
- Sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 so KV cache uses cudaMallocManaged
Does NOT swap the global CUDA allocator — model weights and compute
intermediates use normal cudaMalloc in HBM. Only KV cache spills into
EGM via cudaMallocManaged, called directly from gpu_model_runner.py.
Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
"""
@@ -14,68 +18,11 @@ import sys
# Only activate in CMM mode
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if _MANAGED_TOTAL:
import torch
# Enable KV cache managed memory allocation in gpu_model_runner.py
os.environ['VLLM_KV_CACHE_USE_MANAGED_MEMORY'] = '1'
# Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops
_lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so')
if os.path.exists(_lib_path):
try:
import ctypes
lib = ctypes.CDLL(_lib_path)
if hasattr(lib, 'managed_malloc') and hasattr(lib, 'managed_free'):
alloc = torch.cuda.memory.CUDAPluggableAllocator(
_lib_path, 'managed_malloc', 'managed_free'
)
torch.cuda.memory.change_current_allocator(alloc)
print(f"[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr)
except Exception as e:
print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr)
else:
print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr)
# Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support
_original_reset_peak = torch.cuda.reset_peak_memory_stats
_original_memory_stats = torch.cuda.memory_stats
_original_max_memory_allocated = torch.cuda.max_memory_allocated
def _patched_reset_peak_memory_stats(device=None):
    """No-op: CUDAPluggableAllocator doesn't support resetPeakStats.

    Installed in place of ``torch.cuda.reset_peak_memory_stats`` so that
    callers which reset peak statistics do not fail under the pluggable
    allocator.

    Args:
        device: Accepted for signature compatibility; ignored.

    Returns:
        None.
    """
def _patched_memory_stats(device=None):
    """Return memory stats, or a minimal zeroed dict if the lookup fails.

    Delegates to the saved original ``memory_stats`` function. If that call
    raises ``RuntimeError``, a small dict of zero-valued counters is
    returned instead so callers reading these keys do not crash.

    Args:
        device: Forwarded unchanged to the original function.

    Returns:
        The original stats result, or the zeroed fallback dict.
    """
    fallback_keys = (
        "allocated_bytes.all.current",
        "allocated_bytes.all.peak",
        "reserved_bytes.all.current",
        "reserved_bytes.all.peak",
        "num_alloc_retries",
        "num_ooms",
    )
    try:
        stats = _original_memory_stats(device)
    except RuntimeError:
        # The original call failed; report every counter as zero.
        stats = dict.fromkeys(fallback_keys, 0)
    return stats
def _patched_max_memory_allocated(device=None):
    """Return the peak allocated memory, or 0 if it cannot be queried.

    Delegates to the saved original ``max_memory_allocated`` function and
    falls back to 0 when that call raises ``RuntimeError`` (peak usage
    can't be tracked with the pluggable allocator).

    Args:
        device: Forwarded unchanged to the original function.

    Returns:
        The original function's result, or 0 on ``RuntimeError``.
    """
    try:
        peak = _original_max_memory_allocated(device)
    except RuntimeError:
        peak = 0
    return peak
torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
torch.cuda.memory_stats = _patched_memory_stats
torch.cuda.max_memory_allocated = _patched_max_memory_allocated
# Patch accelerator aliases (PyTorch 2.11+)
if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
if hasattr(torch.accelerator, 'memory_stats'):
torch.accelerator.memory_stats = _patched_memory_stats
if hasattr(torch.accelerator, 'max_memory_allocated'):
torch.accelerator.max_memory_allocated = _patched_max_memory_allocated
# Step 3: Patch MemorySnapshot.measure() to report managed memory
# Patch MemorySnapshot.measure() to report managed memory capacity
# This tells vLLM how much total memory is available for KV cache sizing
try:
from vllm.utils.mem_utils import MemorySnapshot
_original_measure = MemorySnapshot.measure
@@ -90,9 +37,9 @@ if _MANAGED_TOTAL:
MemorySnapshot.measure = _patched_measure
except ImportError:
pass # vllm not loaded yet, will be patched by vllm_managed_mem.py
pass # vllm not loaded yet, will be patched later
# Step 4: Patch request_memory to skip free-memory check for managed
# Patch request_memory to calculate the KV cache size based on managed memory
try:
import vllm.v1.worker.utils as worker_utils
_original_request_memory = worker_utils.request_memory
@@ -111,4 +58,5 @@ if _MANAGED_TOTAL:
except ImportError:
pass # vllm not loaded yet
print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)", file=sys.stderr)
print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB, "
f"KV cache will use cudaMallocManaged)", file=sys.stderr)

View File

@@ -216,8 +216,13 @@ def patch_vllm_memory_check():
def main():
# Step 1: Swap allocator BEFORE any CUDA ops
swap_allocator()
# Step 1: NO global allocator swap — model weights stay in HBM.
# KV cache uses cudaMallocManaged directly via
# VLLM_KV_CACHE_USE_MANAGED_MEMORY env var (set by sitecustomize.py).
# The global allocator swap broke cuBLAS GEMM operations because
# intermediate compute tensors ended up in managed memory.
print(f"[managed_mem] Using targeted KV cache managed allocation "
f"(no global allocator swap)", file=sys.stderr)
# Step 2: Calculate total managed memory and export it
total_managed_gb = get_total_managed_memory_gb()
@@ -231,8 +236,8 @@ def main():
print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
file=sys.stderr)
# Step 3: Patch PyTorch memory tracking (pluggable allocator doesn't support all ops)
patch_torch_memory_tracking()
# Step 3: No torch.cuda memory tracking patches needed —
# we're not using CUDAPluggableAllocator anymore.
# Step 4: Patch MemorySnapshot.measure() to report full managed memory
# This is critical - without it, all downstream code only sees HBM