diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a1af0b8ae..f82ec5d2b 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse +import gc import json import os import time @@ -26,6 +27,46 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser FP8_DTYPE = current_platform.fp8_dtype() +# Default interval for clearing Triton JIT cache during tuning +# Set to 0 to disable automatic cache clearing +_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL" +TRITON_CACHE_CLEAR_INTERVAL = int(os.environ.get(_CACHE_CLEAR_INTERVAL_ENV, "50")) + + +def clear_triton_cache(): + """Clear Triton JIT compilation cache and Python/CUDA memory. + + This helps prevent OOM during tuning with large models (many experts). + """ + # Force Python garbage collection + gc.collect() + + # Clear CUDA memory cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Try to clear Triton's runtime cache + try: + import triton + + if ( + hasattr(triton, "runtime") + and hasattr(triton.runtime, "cache") + and hasattr(triton.runtime.cache, "clear") + ): + triton.runtime.cache.clear() + except ImportError: + # Triton not installed, skip cache clearing + pass + except AttributeError: + # Triton version doesn't have expected cache API + pass + except Exception as e: + print(f"Warning: Failed to clear Triton cache: {e}") + + # Additional garbage collection after clearing caches + gc.collect() + def ensure_divisibility(numerator, denominator, text): """Ensure that numerator is divisible by the denominator.""" @@ -483,7 +524,7 @@ class BenchmarkWorker: need_device_guard = True with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): - for config in tqdm(search_space): + for idx, config in enumerate(tqdm(search_space)): try: kernel_time = benchmark_config( config, @@ -506,6 +547,19 @@ class BenchmarkWorker: if kernel_time < best_time: best_time = kernel_time best_config = config + + # Periodically clear Triton JIT cache to prevent OOM + # This is especially important for large models with many experts + if ( + TRITON_CACHE_CLEAR_INTERVAL > 0 + and idx > 0 + and idx % TRITON_CACHE_CLEAR_INTERVAL == 0 + ): + clear_triton_cache() + + # Final cleanup after tuning completes + clear_triton_cache() + now = datetime.now() print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}") assert best_config is not None