Remove global allocator swap, use targeted KV cache managed allocation
sitecustomize.py: No longer swaps CUDAPluggableAllocator globally. Sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 instead. vllm_managed_mem.py: No global allocator swap, no torch.cuda patches.
This commit is contained in:
@@ -1,10 +1,14 @@
 """
 sitecustomize.py - Auto-loaded by Python before any other imports.
 
-Patches PyTorch's CUDA memory tracking to work with CUDAPluggableAllocator.
-Also swaps the allocator to cudaMallocManaged and patches MemorySnapshot.
-This MUST run before any torch.cuda calls in any subprocess.
-Only activates when MANAGED_MEMORY_TOTAL_GB is set (CMM mode).
+In CMM mode (MANAGED_MEMORY_TOTAL_GB set):
+- Patches MemorySnapshot.measure() to report managed memory capacity
+- Patches request_memory to calculate KV cache size based on managed memory
+- Sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 so KV cache uses cudaMallocManaged
+
+Does NOT swap the global CUDA allocator — model weights and compute
+intermediates use normal cudaMalloc in HBM. Only KV cache spills into
+EGM via cudaMallocManaged, called directly from gpu_model_runner.py.
 
 Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
 """
@@ -14,68 +18,11 @@ import sys
 # Only activate in CMM mode
 _MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
 if _MANAGED_TOTAL:
-    import torch
-
-    # Step 1: Swap allocator to cudaMallocManaged BEFORE any CUDA ops
-    _lib_path = os.environ.get('MANAGED_ALLOC_LIB', '/usr/local/lib/libmanaged_alloc.so')
-    if os.path.exists(_lib_path):
-        try:
-            import ctypes
-            lib = ctypes.CDLL(_lib_path)
-            if hasattr(lib, 'managed_malloc') and hasattr(lib, 'managed_free'):
-                alloc = torch.cuda.memory.CUDAPluggableAllocator(
-                    _lib_path, 'managed_malloc', 'managed_free'
-                )
-                torch.cuda.memory.change_current_allocator(alloc)
-                print(f"[sitecustomize] Allocator swapped to cudaMallocManaged", file=sys.stderr)
-        except Exception as e:
-            print(f"[sitecustomize] WARNING: Failed to swap allocator: {e}", file=sys.stderr)
-    else:
-        print(f"[sitecustomize] WARNING: {_lib_path} not found", file=sys.stderr)
-
-    # Step 2: Patch torch.cuda functions that CUDAPluggableAllocator doesn't support
-    _original_reset_peak = torch.cuda.reset_peak_memory_stats
-    _original_memory_stats = torch.cuda.memory_stats
-    _original_max_memory_allocated = torch.cuda.max_memory_allocated
-
-    def _patched_reset_peak_memory_stats(device=None):
-        """No-op: CUDAPluggableAllocator doesn't support resetPeakStats."""
-        pass
-
-    def _patched_memory_stats(device=None):
-        """Return minimal stats dict to avoid crashes."""
-        try:
-            return _original_memory_stats(device)
-        except RuntimeError:
-            return {
-                "allocated_bytes.all.current": 0,
-                "allocated_bytes.all.peak": 0,
-                "reserved_bytes.all.current": 0,
-                "reserved_bytes.all.peak": 0,
-                "num_alloc_retries": 0,
-                "num_ooms": 0,
-            }
-
-    def _patched_max_memory_allocated(device=None):
-        """Return 0 since we can't track peak with pluggable allocator."""
-        try:
-            return _original_max_memory_allocated(device)
-        except RuntimeError:
-            return 0
-
-    torch.cuda.reset_peak_memory_stats = _patched_reset_peak_memory_stats
-    torch.cuda.memory_stats = _patched_memory_stats
-    torch.cuda.max_memory_allocated = _patched_max_memory_allocated
-
-    # Patch accelerator aliases (PyTorch 2.11+)
-    if hasattr(torch.accelerator, 'reset_peak_memory_stats'):
-        torch.accelerator.reset_peak_memory_stats = _patched_reset_peak_memory_stats
-    if hasattr(torch.accelerator, 'memory_stats'):
-        torch.accelerator.memory_stats = _patched_memory_stats
-    if hasattr(torch.accelerator, 'max_memory_allocated'):
-        torch.accelerator.max_memory_allocated = _patched_max_memory_allocated
-
-    # Step 3: Patch MemorySnapshot.measure() to report managed memory
+    # Enable KV cache managed memory allocation in gpu_model_runner.py
+    os.environ['VLLM_KV_CACHE_USE_MANAGED_MEMORY'] = '1'
+
+    # Patch MemorySnapshot.measure() to report managed memory capacity
+    # This tells vLLM how much total memory is available for KV cache sizing
     try:
         from vllm.utils.mem_utils import MemorySnapshot
         _original_measure = MemorySnapshot.measure
@@ -90,9 +37,9 @@ if _MANAGED_TOTAL:
 
         MemorySnapshot.measure = _patched_measure
     except ImportError:
-        pass  # vllm not loaded yet, will be patched by vllm_managed_mem.py
+        pass  # vllm not loaded yet, will be patched later
 
-    # Step 4: Patch request_memory to skip free-memory check for managed
+    # Patch request_memory to calculate based on managed memory
     try:
         import vllm.v1.worker.utils as worker_utils
         _original_request_memory = worker_utils.request_memory
@@ -111,4 +58,5 @@
     except ImportError:
         pass  # vllm not loaded yet
 
-    print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB)", file=sys.stderr)
+    print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB, "
+          f"KV cache will use cudaMallocManaged)", file=sys.stderr)
@@ -216,8 +216,13 @@ def patch_vllm_memory_check():
 
 
 def main():
-    # Step 1: Swap allocator BEFORE any CUDA ops
-    swap_allocator()
+    # Step 1: NO global allocator swap — model weights stay in HBM.
+    # KV cache uses cudaMallocManaged directly via
+    # VLLM_KV_CACHE_USE_MANAGED_MEMORY env var (set by sitecustomize.py).
+    # The global allocator swap broke cuBLAS GEMM operations because
+    # intermediate compute tensors ended up in managed memory.
+    print(f"[managed_mem] Using targeted KV cache managed allocation "
+          f"(no global allocator swap)", file=sys.stderr)
 
     # Step 2: Calculate total managed memory and export it
     total_managed_gb = get_total_managed_memory_gb()
@@ -231,8 +236,8 @@ def main():
     print(f"[managed_mem] MANAGED_MEMORY_TOTAL_GB={total_managed_gb:.0f}",
           file=sys.stderr)
 
-    # Step 3: Patch PyTorch memory tracking (pluggable allocator doesn't support all ops)
-    patch_torch_memory_tracking()
+    # Step 3: No torch.cuda memory tracking patches needed —
+    # we're not using CUDAPluggableAllocator anymore.
 
     # Step 4: Patch MemorySnapshot.measure() to report full managed memory
     # This is critical - without it, all downstream code only sees HBM
|
|||||||
Reference in New Issue
Block a user