Remove global allocator swap, use targeted KV cache managed allocation
sitecustomize.py: No longer replaces the global CUDA allocator with a CUDAPluggableAllocator; sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 instead. vllm_managed_mem.py: No global allocator swap, no torch.cuda patches.
This commit is contained in:
@@ -1,10 +1,14 @@
|
||||
"""
|
||||
sitecustomize.py - Auto-loaded by Python before any other imports.
|
||||
Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator.
|
||||
Also swaps the allocator to cudaMallocManaged and patches MemorySnapshot.
|
||||
|
||||
This MUST run before any torch.cuda calls in any subprocess.
|
||||
Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode).
|
||||
In CMM mode (MANAGED_MEMORY_TOTAL_GB set):
|
||||
- Patches MemorySnapshot.measure() to report managed memory capacity
|
||||
- Patches request_memory to calculate KV cache size based on managed memory
|
||||
- Sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 so KV cache uses cudaMallocManaged
|
||||
|
||||
Does NOT swap the global CUDA allocator — model weights and compute
|
||||
intermediates use normal cudaMalloc in HBM. Only KV cache spills into
|
||||
EGM via cudaMallocManaged, called directly from gpu_model_runner.py.
|
||||
|
||||
Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
|
||||
"""
|
||||
@@ -14,68 +18,11 @@ import sys
|
||||
# Only activate in CMM mode
|
||||
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
|
||||
if _MANAGED_TOTAL:
|
||||
import torch
|
||||
# Enable KV cache managed memory allocation in gpu_model_runner.py
|
||||
os.environ['VLLM_KV_CACHE_USE_MANAGED_MEMORY'] = '1'
|
||||
|
||||
# Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops
|
||||
_lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so')
|
||||
if os.path.exists(_lib_path):
|
||||
try:
|
||||
import ctypes
|
||||
lib = ctypes.CDLL(_lib_path)
|
||||
if hasattr(lib, 'managed_malloc') and hasattr(lib, 'managed_free'):
|
||||
alloc = torch.cuda.memory.CUDAPluggableAllocator(
|
||||
_lib_path, 'managed_malloc', 'managed_free'
|
||||
)
|
||||
torch.cuda.memory.change_current_allocator(alloc)
|
||||
print(f"[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr)
|
||||
except Exception as e:
|
||||
print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr)
|
||||
else:
|
||||
print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr)
|
||||
|
||||
# Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support
|
||||
_original_reset_peak = torch.cuda.reset_peak_memory_stats
|
||||
_original_memory_stats = torch.cuda.memory_stats
|
||||
_original_max_memory_allocated = torch.cuda.max_memory_allocated
|
||||
|
||||
def _patched_reset_peak_memory_stats(device=None):
|
||||
"""No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
|
||||
pass
|
||||
|
||||
def _patched_memory_stats(device=None):
|
||||
"""Return minimal stats dict to avoid crashes."""
|
||||
try:
|
||||
return _original_memory_stats(device)
|
||||
except RuntimeError:
|
||||
return {
|
||||
"allocated_bytes.all.current": 0,
|
||||
"allocated_bytes.all.peak": 0,
|
||||
"reserved_bytes.all.current": 0,
|
||||
"reserved_bytes.all.peak": 0,
|
||||
"num_alloc_retries": 0,
|
||||
"num_ooms": 0,
|
||||
}
|
||||
|
||||
def _patched_max_memory_allocated(device=None):
|
||||
"""Return 0 since we can't track peak with pluggable allocator."""
|
||||
try:
|
||||
return _original_max_memory_allocated(device)
|
||||
except RuntimeError:
|
||||
return 0
|
||||
|
||||
torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
|
||||
torch.cuda.memory_stats = _patched_memory_stats
|
||||
torch.cuda.max_memory_allocated = _patched_max_memory_allocated
|
||||
|
||||
# Patch accelerator aliases (PyTorch 2.11+)
|
||||
if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
|
||||
torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
|
||||
if hasattr(torch.accelerator, 'memory_stats'):
|
||||
torch.accelerator.memory_stats = _patched_memory_stats
|
||||
if hasattr(torch.accelerator, 'max_memory_allocated'):
|
||||
torch.accelerator.max_memory_allocated = _patched_max_memory_allocated
|
||||
|
||||
# Step 3: Patch MemorySnapshot.measure() to report managed memory
|
||||
# Patch MemorySnapshot.measure() to report managed memory capacity
|
||||
# This tells vLLM how much total memory is available for KV cache sizing
|
||||
try:
|
||||
from vllm.utils.mem_utils import MemorySnapshot
|
||||
_original_measure = MemorySnapshot.measure
|
||||
@@ -90,9 +37,9 @@ if _MANAGED_TOTAL:
|
||||
|
||||
MemorySnapshot.measure = _patched_measure
|
||||
except ImportError:
|
||||
pass # vllm not loaded yet, will be patched by vllm_managed_mem.py
|
||||
pass # vllm not loaded yet, will be patched later
|
||||
|
||||
# Step 4: Patch request_memory to skip free-memory check for managed
|
||||
# Patch request_memory to calculate based on managed memory
|
||||
try:
|
||||
import vllm.v1.worker.utils as worker_utils
|
||||
_original_request_memory = worker_utils.request_memory
|
||||
@@ -111,4 +58,5 @@ if _MANAGED_TOTAL:
|
||||
except ImportError:
|
||||
pass # vllm not loaded yet
|
||||
|
||||
print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)", file=sys.stderr)
|
||||
print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB, "
|
||||
f"KV cache will use cudaMallocManaged)", file=sys.stderr)
|
||||
|
||||
@@ -216,8 +216,13 @@ def patch_vllm_memory_check():
|
||||
|
||||
|
||||
def main():
|
||||
# Step 1: Swap allocator BEFORE any CUDA ops
|
||||
swap_allocator()
|
||||
# Step 1: NO global allocator swap — model weights stay in HBM.
|
||||
# KV cache uses cudaMallocManaged directly via
|
||||
# VLLM_KV_CACHE_USE_MANAGED_MEMORY env var (set by sitecustomize.py).
|
||||
# The global allocator swap broke cuBLAS GEMM operations because
|
||||
# intermediate compute tensors ended up in managed memory.
|
||||
print(f"[managed_mem] Using targeted KV cache managed allocation "
|
||||
f"(no global allocator swap)", file=sys.stderr)
|
||||
|
||||
# Step 2: Calculate total managed memory and export it
|
||||
total_managed_gb = get_total_managed_memory_gb()
|
||||
@@ -231,8 +236,8 @@ def main():
|
||||
print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
|
||||
file=sys.stderr)
|
||||
|
||||
# Step 3: Patch PyTorch memory tracking (pluggable allocator doesn't support all ops)
|
||||
patch_torch_memory_tracking()
|
||||
# Step 3: No torch.cuda memory tracking patches needed —
|
||||
# we're not using CUDAPluggableAllocator anymore.
|
||||
|
||||
# Step 4: Patch MemorySnapshot.measure() to report full managed memory
|
||||
# This is critical - without it, all downstream code only sees HBM
|
||||
|
||||
Reference in New Issue
Block a user