[Benchmark] Fix OOM during MoE kernel tuning for large models (#31604)

Signed-off-by: Alfred <massif0601@gmail.com>
This commit is contained in:
Alfred
2026-01-03 06:24:51 +08:00
committed by GitHub
parent a3f2f40947
commit a0e9ee83c7

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import gc
import json
import os
import time
@@ -26,6 +27,46 @@ from vllm.utils.argparse_utils import FlexibleArgumentParser
# FP8 dtype reported by the current platform, resolved once at import time.
FP8_DTYPE = current_platform.fp8_dtype()
# How many tuning iterations pass between Triton JIT cache clears.
# Read once at import time from the environment variable named below
# (default: 50). A value of 0 turns the periodic clearing off.
_CACHE_CLEAR_INTERVAL_ENV = "VLLM_MOE_TUNE_CACHE_CLEAR_INTERVAL"
TRITON_CACHE_CLEAR_INTERVAL = int(os.getenv(_CACHE_CLEAR_INTERVAL_ENV, default="50"))
def clear_triton_cache():
    """Best-effort release of Python, CUDA and Triton cached memory.

    Intended to reduce the risk of OOM while tuning large models (many
    experts). The function never raises for cache-clearing problems:

    * Triton not being installed, or a Triton build that does not expose
      ``triton.runtime.cache.clear``, is ignored silently.
    * Any other failure while clearing the Triton cache is reported as a
      warning on stdout.
    """
    # First pass: collect unreachable Python objects before touching the
    # CUDA / Triton caches.
    gc.collect()

    # Ask the CUDA caching allocator to release its unused cached blocks.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    try:
        import triton

        # The attribute chain raises AttributeError when this Triton version
        # lacks the expected cache API; that case is skipped silently below.
        triton.runtime.cache.clear()
    except (ImportError, AttributeError):
        # Triton missing or cache API unavailable: nothing to clear.
        pass
    except Exception as exc:
        print(f"Warning: Failed to clear Triton cache: {exc}")

    # Second pass: pick up anything that became unreachable after the
    # caches were cleared.
    gc.collect()
def ensure_divisibility(numerator, denominator, text):
"""Ensure that numerator is divisible by the denominator."""
@@ -483,7 +524,7 @@ class BenchmarkWorker:
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space):
for idx, config in enumerate(tqdm(search_space)):
try:
kernel_time = benchmark_config(
config,
@@ -506,6 +547,19 @@ class BenchmarkWorker:
if kernel_time < best_time:
best_time = kernel_time
best_config = config
# Periodically clear Triton JIT cache to prevent OOM
# This is especially important for large models with many experts
if (
TRITON_CACHE_CLEAR_INTERVAL > 0
and idx > 0
and idx % TRITON_CACHE_CLEAR_INTERVAL == 0
):
clear_triton_cache()
# Final cleanup after tuning completes
clear_triton_cache()
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None