Add PyTorch profiler schedule support with warmup/active iterations (#35240)
This commit is contained in:
@@ -45,8 +45,10 @@ class ProfilerConfig:
|
||||
worker's traces (CPU & GPU) will be saved under this directory. Note that
|
||||
it must be an absolute path."""
|
||||
|
||||
torch_profiler_with_stack: bool = True
|
||||
"""If `True`, enables stack tracing in the torch profiler. Enabled by default."""
|
||||
torch_profiler_with_stack: bool = False
|
||||
"""If `True`, enables stack tracing in the torch profiler. Disabled by default
|
||||
to reduce overhead. Can be enabled via VLLM_TORCH_PROFILER_WITH_STACK=1 env var
|
||||
or --profiler-config.torch_profiler_with_stack=true CLI flag."""
|
||||
|
||||
torch_profiler_with_flops: bool = False
|
||||
"""If `True`, enables FLOPS counting in the torch profiler. Disabled by default."""
|
||||
@@ -81,6 +83,27 @@ class ProfilerConfig:
|
||||
Defaults to 0, meaning no limit.
|
||||
"""
|
||||
|
||||
warmup_iterations: int = Field(default=0, ge=0)
|
||||
"""Number of warmup iterations for PyTorch profiler schedule.
|
||||
During warmup, the profiler runs but data is discarded. This helps reduce
|
||||
noise from JIT compilation and other one-time costs in the profiled trace.
|
||||
Defaults to 0 (schedule-based profiling disabled, recording all iterations).
|
||||
Set to a positive value (e.g., 2) to enable schedule-based profiling.
|
||||
"""
|
||||
|
||||
active_iterations: int = Field(default=5, ge=1)
|
||||
"""Number of active iterations for PyTorch profiler schedule.
|
||||
This is the number of iterations where profiling data is actually collected.
|
||||
Defaults to 5 active iterations.
|
||||
"""
|
||||
|
||||
wait_iterations: int = Field(default=0, ge=0)
|
||||
"""Number of wait iterations for PyTorch profiler schedule.
|
||||
During wait, the profiler is completely off with zero overhead.
|
||||
This allows skipping initial iterations before warmup begins.
|
||||
Defaults to 0 (no wait period).
|
||||
"""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@@ -96,7 +96,9 @@ class WorkerProfiler(ABC):
|
||||
logger.info_once("Starting profiler after delay...", scope="local")
|
||||
self._call_start()
|
||||
|
||||
if self._running:
|
||||
# Call profiler step for schedule-based profiling
|
||||
# Only count iterations where data is actually recorded (not warmup)
|
||||
if self._running and self._profiler_step():
|
||||
self._profiling_for_iters += 1
|
||||
|
||||
if (
|
||||
@@ -113,6 +115,16 @@ class WorkerProfiler(ABC):
|
||||
self._call_stop()
|
||||
return
|
||||
|
||||
def _profiler_step(self) -> bool:
    """Called each step while the profiler is running.

    Override in subclasses to handle schedule-based profiling.

    Returns:
        True if the step was an active profiling step (data recorded),
        False if the step was a warmup step (data discarded).
    """
    # Base implementation has no schedule, so every step records data.
    return True
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Attempt to stop the profiler, accounting for overlapped calls."""
|
||||
if not self._active:
|
||||
@@ -187,8 +199,29 @@ class TorchProfilerWrapper(WorkerProfiler):
|
||||
)
|
||||
|
||||
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
|
||||
|
||||
# Create profiler schedule if warmup or wait iterations are configured
|
||||
profiler_schedule = None
|
||||
if profiler_config.warmup_iterations > 0 or profiler_config.wait_iterations > 0:
|
||||
profiler_schedule = torch.profiler.schedule(
|
||||
skip_first=0,
|
||||
wait=profiler_config.wait_iterations,
|
||||
warmup=profiler_config.warmup_iterations,
|
||||
active=profiler_config.active_iterations,
|
||||
repeat=1,
|
||||
)
|
||||
if local_rank in (None, 0):
|
||||
logger.info_once(
|
||||
"Profiler schedule configured: wait=%d, warmup=%d, active=%d",
|
||||
profiler_config.wait_iterations,
|
||||
profiler_config.warmup_iterations,
|
||||
profiler_config.active_iterations,
|
||||
scope="local",
|
||||
)
|
||||
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[TorchProfilerActivityMap[activity] for activity in activities],
|
||||
schedule=profiler_schedule,
|
||||
record_shapes=profiler_config.torch_profiler_record_shapes,
|
||||
profile_memory=profiler_config.torch_profiler_with_memory,
|
||||
with_stack=profiler_config.torch_profiler_with_stack,
|
||||
@@ -196,6 +229,17 @@ class TorchProfilerWrapper(WorkerProfiler):
|
||||
on_trace_ready=trace_handler,
|
||||
)
|
||||
|
||||
# Track if we're using a schedule (need to call step())
|
||||
self._uses_schedule = profiler_schedule is not None
|
||||
self._warmup_iterations = profiler_config.warmup_iterations
|
||||
# Subtract 1 because profiler.start() already consumes step 0
|
||||
# (WAIT or WARMUP), so only wait + warmup - 1 non-active steps
|
||||
# remain to be advanced through via profiler.step() calls.
|
||||
self._warmup_steps_remaining = max(
|
||||
profiler_config.wait_iterations + profiler_config.warmup_iterations - 1,
|
||||
0,
|
||||
)
|
||||
|
||||
@override
def _start(self) -> None:
    """Begin trace collection by starting the wrapped torch profiler."""
    self.profiler.start()
|
||||
@@ -228,6 +272,22 @@ class TorchProfilerWrapper(WorkerProfiler):
|
||||
)
|
||||
)
|
||||
|
||||
@override
|
||||
def _profiler_step(self) -> bool:
|
||||
"""Call profiler.step() when using schedule-based profiling.
|
||||
|
||||
Returns:
|
||||
True if the step was an active profiling step (data recorded),
|
||||
False if the step was a warmup step (data discarded).
|
||||
"""
|
||||
if self._uses_schedule:
|
||||
self.profiler.step()
|
||||
# Track warmup steps - only count active steps toward max_iterations
|
||||
if self._warmup_steps_remaining > 0:
|
||||
self._warmup_steps_remaining -= 1
|
||||
return False
|
||||
return True
|
||||
|
||||
@override
def annotate_context_manager(self, name: str):
    """Return a context manager that labels the enclosed code region.

    Uses ``torch.profiler.record_function`` so the region appears under
    *name* in the collected trace.
    """
    return torch.profiler.record_function(name)
|
||||
|
||||
Reference in New Issue
Block a user