From 1ab8fc8197f5aa0ae471ef0ff1c725b733594f7a Mon Sep 17 00:00:00 2001
From: Yifei Zhang
Date: Mon, 1 Dec 2025 12:30:46 +0800
Subject: [PATCH] Make PyTorch profiler gzip and CUDA time dump configurable (#29568)

Signed-off-by: Yifei Zhang
---
 docs/contributing/profiling.md |  2 ++
 vllm/envs.py                   | 13 +++++++++++++
 vllm/profiler/gpu_profiler.py  | 25 ++++++++++++++-----------
 vllm/v1/engine/async_llm.py    |  4 +++-
 vllm/v1/worker/xpu_worker.py   |  4 +++-
 5 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/docs/contributing/profiling.md b/docs/contributing/profiling.md
index 7634cc085..65382afbe 100644
--- a/docs/contributing/profiling.md
+++ b/docs/contributing/profiling.md
@@ -11,6 +11,8 @@ We support tracing vLLM workers using the `torch.profiler` module. You can enabl
 - `VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY=1` to record memory, off by default
 - `VLLM_TORCH_PROFILER_WITH_STACK=1` to enable recording stack information, on by default
 - `VLLM_TORCH_PROFILER_WITH_FLOPS=1` to enable recording FLOPs, off by default
+- `VLLM_TORCH_PROFILER_USE_GZIP=0` to disable gzip-compressing profiling files, on by default
+- `VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0` to disable dumping and printing the aggregated CUDA self time table, on by default
 
 The OpenAI server also needs to be started with the `VLLM_TORCH_PROFILER_DIR` environment variable set.
 
diff --git a/vllm/envs.py b/vllm/envs.py
index 541d5e20d..46f1aa322 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -100,6 +100,8 @@ if TYPE_CHECKING:
     VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
     VLLM_PROFILER_DELAY_ITERS: int = 0
     VLLM_PROFILER_MAX_ITERS: int = 0
+    VLLM_TORCH_PROFILER_USE_GZIP: bool = True
+    VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL: bool = True
     VLLM_USE_TRITON_AWQ: bool = False
     VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
     VLLM_SKIP_P2P_CHECK: bool = False
@@ -890,6 +892,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # Maximum number of iterations to profile when using the torch/torch CUDA profiler.
     # If set to 0, will not limit the number of iterations.
     "VLLM_PROFILER_MAX_ITERS": lambda: int(os.getenv("VLLM_PROFILER_MAX_ITERS", "0")),
+    # Control whether torch profiler gzip-compresses profiling files.
+    # Set VLLM_TORCH_PROFILER_USE_GZIP=0 to disable gzip (enabled by default).
+    "VLLM_TORCH_PROFILER_USE_GZIP": lambda: bool(
+        os.getenv("VLLM_TORCH_PROFILER_USE_GZIP", "1") != "0"
+    ),
+    # Control whether torch profiler dumps the self_cuda_time_total table.
+    # Set VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL=0 to disable dumping
+    # (enabled by default).
+    "VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL": lambda: bool(
+        os.getenv("VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL", "1") != "0"
+    ),
     # If set, vLLM will use Triton implementations of AWQ.
     "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
     # If set, allow loading or unloading lora adapters in runtime,
diff --git a/vllm/profiler/gpu_profiler.py b/vllm/profiler/gpu_profiler.py
index 3e2cbe729..798c61522 100644
--- a/vllm/profiler/gpu_profiler.py
+++ b/vllm/profiler/gpu_profiler.py
@@ -162,7 +162,9 @@ class TorchProfilerWrapper(WorkerProfiler):
             with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
             with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
             on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
+                torch_profiler_trace_dir,
+                worker_name=worker_name,
+                use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
             ),
         )
 
@@ -174,18 +176,19 @@ class TorchProfilerWrapper(WorkerProfiler):
     def _stop(self) -> None:
         self.profiler.stop()
 
-        rank = self.local_rank
-        profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
-        profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
-        sort_key = "self_cuda_time_total"
-        table = self.profiler.key_averages().table(sort_by=sort_key)
+        if envs.VLLM_TORCH_PROFILER_DUMP_CUDA_TIME_TOTAL:
+            rank = self.local_rank
+            profiler_dir = envs.VLLM_TORCH_PROFILER_DIR
+            profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
+            sort_key = "self_cuda_time_total"
+            table = self.profiler.key_averages().table(sort_by=sort_key)
 
-        with open(profiler_out_file, "w") as f:
-            print(table, file=f)
+            with open(profiler_out_file, "w") as f:
+                print(table, file=f)
 
-        # only print profiler results on rank 0
-        if rank == 0:
-            print(table)
+            # only print profiler results on rank 0
+            if rank == 0:
+                print(table)
 
     @override
     def annotate_context_manager(self, name: str):
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 336d3e9fa..d0708a8a0 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -190,7 +190,9 @@ class AsyncLLM(EngineClient):
                 ],
                 with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    envs.VLLM_TORCH_PROFILER_DIR, worker_name=worker_name, use_gzip=True
+                    envs.VLLM_TORCH_PROFILER_DIR,
+                    worker_name=worker_name,
+                    use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
                 ),
             )
         else:
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index 4d7864e90..267369c73 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -64,7 +64,9 @@ class XPUWorker(Worker):
                 with_stack=envs.VLLM_TORCH_PROFILER_WITH_STACK,
                 with_flops=envs.VLLM_TORCH_PROFILER_WITH_FLOPS,
                 on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                    torch_profiler_trace_dir, worker_name=worker_name, use_gzip=True
+                    torch_profiler_trace_dir,
+                    worker_name=worker_name,
+                    use_gzip=envs.VLLM_TORCH_PROFILER_USE_GZIP,
                 ),
             )
         else: