File: grace-gpu-containers/vllm/sitecustomize.py — Python, 63 lines, 2.5 KiB.

"""
sitecustomize.py - Auto-loaded by Python before any other imports.
In CMM mode (MANAGED_MEMORY_TOTAL_GB set):
- Patches MemorySnapshot.measure() to report managed memory capacity
- Patches request_memory to calculate KV cache size based on managed memory
- Sets VLLM_KV_CACHE_USE_MANAGED_MEMORY=1 so KV cache uses cudaMallocManaged
Does NOT swap the global CUDA allocator model weights and compute
intermediates use normal cudaMalloc in HBM. Only KV cache spills into
EGM via cudaMallocManaged, called directly from gpu_model_runner.py.
Installed at: /usr/local/lib/python3.12/dist-packages/sitecustomize.py
"""
import os
import sys
# Only activate in CMM mode
_MANAGED_TOTAL = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if _MANAGED_TOTAL:
# Enable KV cache managed memory allocation in gpu_model_runner.py
os.environ['VLLM_KV_CACHE_USE_MANAGED_MEMORY'] = '1'
# Patch MemorySnapshot.measure() to report managed memory capacity
# This tells vLLM how much total memory is available for KV cache sizing
try:
from vllm.utils.mem_utils import MemorySnapshot
_original_measure = MemorySnapshot.measure
def _patched_measure(self):
_original_measure(self)
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
managed_total = int(float(managed_total_gb) * (1024**3))
self.total_memory = managed_total
self.free_memory = managed_total - self.cuda_memory
MemorySnapshot.measure = _patched_measure
except ImportError:
pass # vllm not loaded yet, will be patched later
# Patch request_memory to calculate based on managed memory
try:
import vllm.v1.worker.utils as worker_utils
_original_request_memory = worker_utils.request_memory
def _patched_request_memory(init_snapshot, cache_config):
managed_total_gb = os.environ.get('MANAGED_MEMORY_TOTAL_GB')
if managed_total_gb:
total_bytes = float(managed_total_gb) * (1024**3)
gpu_util = cache_config.gpu_memory_utilization
requested = int(total_bytes * gpu_util)
return requested
else:
return _original_request_memory(init_snapshot, cache_config)
worker_utils.request_memory = _patched_request_memory
except ImportError:
pass # vllm not loaded yet
print(f"[sitecustomize] CMM patches applied (managed={_MANAGED_TOTAL} GiB, "
f"KV cache will use cudaMallocManaged)", file=sys.stderr)