[Benchmark] Fix OOM during MoE kernel tuning for large models (#31604)
Signed-off-by: Alfred <massif0601@gmail.com>
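During exhaustive tuning, every benchmarked configuration leaves a
freshly JIT-compiled kernel behind; for large models with many experts,
the accumulated Triton cache entries and cached CUDA allocations can
exhaust GPU memory mid-sweep. This change frees both periodically,
every 50 configurations by default. The interval is read once at import
time from VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL, so it must be set in the
environment before the benchmark is launched; 0 disables periodic
clearing. For example (the script path and --tune flag follow the
existing vLLM MoE tuning benchmark; adjust to your checkout):

    VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL=25 python benchmarks/kernels/benchmark_moe.py --tune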
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import argparse
+import gc
 import json
 import os
 import time
@@ -26,6 +27,46 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
 
 FP8_DTYPE = current_platform.fp8_dtype()
 
+# Default interval for clearing Triton JIT cache during tuning.
+# Set to 0 to disable automatic cache clearing.
+_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
+TRITON_CACHE_CLEAR_INTERVAL = int(os.environ.get(_CACHE_CLEAR_INTERVAL_ENV, "50"))
+
+
+def clear_triton_cache():
+    """Clear Triton JIT compilation cache and Python/CUDA memory.
+
+    This helps prevent OOM during tuning with large models (many experts).
+    """
+    # Force Python garbage collection
+    gc.collect()
+
+    # Clear CUDA memory cache
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Try to clear Triton's runtime cache
+    try:
+        import triton
+
+        if (
+            hasattr(triton, "runtime")
+            and hasattr(triton.runtime, "cache")
+            and hasattr(triton.runtime.cache, "clear")
+        ):
+            triton.runtime.cache.clear()
+    except ImportError:
+        # Triton not installed, skip cache clearing
+        pass
+    except AttributeError:
+        # Triton version doesn't have expected cache API
+        pass
+    except Exception as e:
+        print(f"Warning: Failed to clear Triton cache: {e}")
+
+    # Additional garbage collection after clearing caches
+    gc.collect()
+
 
 def ensure_divisibility(numerator, denominator, text):
     """Ensure that numerator is divisible by the denominator."""
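The try/except ladder above is deliberately defensive: a module-level
triton.runtime.cache.clear() may or may not exist in a given Triton
release, so the hasattr probes keep the benchmark running on any
version (or with Triton absent entirely). A minimal sketch of checking
that the helper actually releases cached device memory, assuming a CUDA
device and the clear_triton_cache() defined above
(torch.cuda.memory_reserved is the standard PyTorch query for the
caching allocator's footprint):

    import torch

    # Report the CUDA caching allocator's footprint before/after clearing.
    if torch.cuda.is_available():
        before = torch.cuda.memory_reserved()
        clear_triton_cache()
        after = torch.cuda.memory_reserved()
        print(f"reserved: {before / 2**20:.1f} MiB -> {after / 2**20:.1f} MiB")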
@@ -483,7 +524,7 @@ class BenchmarkWorker:
             need_device_guard = True
 
         with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
-            for config in tqdm(search_space):
+            for idx, config in enumerate(tqdm(search_space)):
                 try:
                     kernel_time = benchmark_config(
                         config,
@@ -506,6 +547,19 @@ class BenchmarkWorker:
                 if kernel_time < best_time:
                     best_time = kernel_time
                     best_config = config
+
+                # Periodically clear Triton JIT cache to prevent OOM.
+                # This is especially important for large models with many experts.
+                if (
+                    TRITON_CACHE_CLEAR_INTERVAL > 0
+                    and idx > 0
+                    and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
+                ):
+                    clear_triton_cache()
+
+        # Final cleanup after tuning completes
+        clear_triton_cache()
+
         now = datetime.now()
         print(f"[{now.ctime()}] Completed tuning for batch_size={num_tokens}")
         assert best_config is not None
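For reference, the pattern the diff implements, reduced to a
self-contained sketch (the names and interval value here are
illustrative, not from this commit):

    import gc

    import torch

    CLEAR_EVERY = 50  # illustrative interval; 0 disables clearing

    def clear_caches():
        """Drop Python garbage and cached CUDA blocks between sweep chunks."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    def sweep(candidates, measure):
        """Time every candidate, clearing caches every CLEAR_EVERY steps."""
        best, best_time = None, float("inf")
        for idx, cand in enumerate(candidates):
            elapsed = measure(cand)
            if elapsed < best_time:
                best, best_time = cand, elapsed
            if CLEAR_EVERY > 0 and idx > 0 and idx % CLEAR_EVERY == 0:
                clear_caches()
        clear_caches()  # final cleanup once the sweep completes
        return best

The idx > 0 guard merely skips a pointless clear on the first
iteration; the trailing clear mirrors the final cleanup after the
with block in the diff.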