Add PyTorch profiler schedule support with warmup/active iterations (#35240)

This commit is contained in:
fenypatel99
2026-03-04 12:53:38 -08:00
committed by GitHub
parent 636ee223ac
commit 7eca859110
2 changed files with 86 additions and 3 deletions

View File

@@ -45,8 +45,10 @@ class ProfilerConfig:
worker's traces (CPU & GPU) will be saved under this directory. Note that
it must be an absolute path."""
torch_profiler_with_stack: bool = True
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
torch_profiler_with_stack: bool = False
"""If `True`, enables stack tracing in the torch profiler. Disabled by default
to reduce overhead. Can be enabled via VLLM_TORCH_PROFILER_WITH_STACK=1 env var
or --profiler-config.torch_profiler_with_stack=true CLI flag."""
torch_profiler_with_flops: bool = False
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
@@ -81,6 +83,27 @@ class ProfilerConfig:
Defaults to 0, meaning no limit.
"""
warmup_iterations: int = Field(default=0, ge=0)
"""Number of warmup iterations for PyTorch profiler schedule.
During warmup, the profiler runs but data is discarded. This helps reduce
noise from JIT compilation and other one-time costs in the profiled trace.
Defaults to 0 (schedule-based profiling disabled, recording all iterations).
Set to a positive value (e.g., 2) to enable schedule-based profiling.
"""
active_iterations: int = Field(default=5, ge=1)
"""Number of active iterations for PyTorch profiler schedule.
This is the number of iterations where profiling data is actually collected.
Defaults to 5 active iterations.
"""
wait_iterations: int = Field(default=0, ge=0)
"""Number of wait iterations for PyTorch profiler schedule.
During wait, the profiler is completely off with zero overhead.
This allows skipping initial iterations before warmup begins.
Defaults to 0 (no wait period).
"""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@@ -96,7 +96,9 @@ class WorkerProfiler(ABC):
logger.info_once("Starting profiler after delay...", scope="local")
self._call_start()
if self._running:
# Call profiler step for schedule-based profiling
# Only count iterations where data is actually recorded (not warmup)
if self._running and self._profiler_step():
self._profiling_for_iters += 1
if (
@@ -113,6 +115,16 @@ class WorkerProfiler(ABC):
self._call_stop()
return
def _profiler_step(self) -> bool:
"""Called each step when profiler is running.
Override in subclasses to handle schedule-based profiling.
Returns:
True if the step was an active profiling step (data recorded),
False if the step was a warmup step (data discarded).
"""
return True
def stop(self) -> None:
"""Attempt to stop the profiler, accounting for overlapped calls."""
if not self._active:
@@ -187,8 +199,29 @@ class TorchProfilerWrapper(WorkerProfiler):
)
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
# Create profiler schedule if warmup or wait iterations are configured
profiler_schedule = None
if profiler_config.warmup_iterations > 0 or profiler_config.wait_iterations > 0:
profiler_schedule = torch.profiler.schedule(
skip_first=0,
wait=profiler_config.wait_iterations,
warmup=profiler_config.warmup_iterations,
active=profiler_config.active_iterations,
repeat=1,
)
if local_rank in (None, 0):
logger.info_once(
"Profiler schedule configured: wait=%d, warmup=%d, active=%d",
profiler_config.wait_iterations,
profiler_config.warmup_iterations,
profiler_config.active_iterations,
scope="local",
)
self.profiler = torch.profiler.profile(
activities=[TorchProfilerActivityMap[activity] for activity in activities],
schedule=profiler_schedule,
record_shapes=profiler_config.torch_profiler_record_shapes,
profile_memory=profiler_config.torch_profiler_with_memory,
with_stack=profiler_config.torch_profiler_with_stack,
@@ -196,6 +229,17 @@ class TorchProfilerWrapper(WorkerProfiler):
on_trace_ready=trace_handler,
)
# Track if we're using a schedule (need to call step())
self._uses_schedule = profiler_schedule is not None
self._warmup_iterations = profiler_config.warmup_iterations
# Subtract 1 because profiler.start() already consumes step 0
# (WAIT or WARMUP), so only wait + warmup - 1 non-active steps
# remain to be advanced through via profiler.step() calls.
self._warmup_steps_remaining = max(
profiler_config.wait_iterations + profiler_config.warmup_iterations - 1,
0,
)
@override
def _start(self) -> None:
    """Start the wrapped torch profiler.

    NOTE(review): per the schedule bookkeeping in ``__init__``,
    ``profiler.start()`` itself consumes schedule step 0, which is why
    the remaining-warmup counter there is initialized with a ``- 1``.
    """
    self.profiler.start()
@@ -228,6 +272,22 @@ class TorchProfilerWrapper(WorkerProfiler):
)
)
@override
def _profiler_step(self) -> bool:
"""Call profiler.step() when using schedule-based profiling.
Returns:
True if the step was an active profiling step (data recorded),
False if the step was a warmup step (data discarded).
"""
if self._uses_schedule:
self.profiler.step()
# Track warmup steps - only count active steps toward max_iterations
if self._warmup_steps_remaining > 0:
self._warmup_steps_remaining -= 1
return False
return True
@override
def annotate_context_manager(self, name: str):
    """Return a ``torch.profiler.record_function`` context manager that
    labels the enclosed region of the trace with *name*."""
    return torch.profiler.record_function(name)