Support custom URI schemes and trace handlers for profiler (#32393)

This commit is contained in:
David Ramon Prados
2026-01-22 12:45:40 -05:00
committed by GitHub
parent 803e3f3f68
commit 3a63be0faa
3 changed files with 74 additions and 19 deletions

View File

@@ -18,6 +18,20 @@ logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
def _is_uri_path(path: str) -> bool:
"""Check if path is a URI (scheme://...), excluding Windows drive letters.
Supports custom URI schemes like gs://, s3://, hdfs://, etc.
These paths should not be converted to absolute paths.
"""
if "://" in path:
scheme = path.split("://")[0]
# Windows drive letters are single characters (e.g., C://)
# Valid URI schemes have more than one character
return len(scheme) > 1
return False
@config
@dataclass
class ProfilerConfig:
@@ -54,7 +68,7 @@ class ProfilerConfig:
Disabled by default."""
ignore_frontend: bool = False
"""If `True`, disables the front-end profiling of AsyncLLM when using the
"""If `True`, disables the front-end profiling of AsyncLLM when using the
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
since the front-end profiling does not track iterations and will capture the
entire range.
@@ -185,15 +199,9 @@ class ProfilerConfig:
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
if profiler_dir:
is_gs_path = (
profiler_dir.startswith("gs://")
and profiler_dir[5:]
and profiler_dir[5] != "/"
)
if not is_gs_path:
self.torch_profiler_dir = os.path.abspath(
os.path.expanduser(profiler_dir)
)
# Support any URI scheme (gs://, s3://, hdfs://, etc.)
# These paths should not be converted to absolute paths
if profiler_dir and not _is_uri_path(profiler_dir):
self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir))
return self

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import nullcontext
from typing import Literal
@@ -9,6 +10,7 @@ import torch
from typing_extensions import override
from vllm.config import ProfilerConfig
from vllm.config.profiler import _is_uri_path
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -151,6 +153,7 @@ class TorchProfilerWrapper(WorkerProfiler):
worker_name: str,
local_rank: int,
activities: list[TorchProfilerActivity],
on_trace_ready: Callable[[torch.profiler.profile], None] | None = None,
) -> None:
super().__init__(profiler_config)
@@ -172,6 +175,17 @@ class TorchProfilerWrapper(WorkerProfiler):
profiler_config.torch_profiler_with_flops,
)
# Determine trace handler: use custom handler if provided,
# otherwise default to tensorboard trace handler
if on_trace_ready is not None:
trace_handler = on_trace_ready
else:
trace_handler = torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=profiler_config.torch_profiler_use_gzip,
)
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
self.profiler = torch.profiler.profile(
activities=[TorchProfilerActivityMap[activity] for activity in activities],
@@ -179,11 +193,7 @@ class TorchProfilerWrapper(WorkerProfiler):
profile_memory=profiler_config.torch_profiler_with_memory,
with_stack=profiler_config.torch_profiler_with_stack,
with_flops=profiler_config.torch_profiler_with_flops,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=profiler_config.torch_profiler_use_gzip,
),
on_trace_ready=trace_handler,
)
@override
@@ -198,12 +208,15 @@ class TorchProfilerWrapper(WorkerProfiler):
rank = self.local_rank
if profiler_config.torch_profiler_dump_cuda_time_total:
profiler_dir = profiler_config.torch_profiler_dir
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
sort_key = "self_cuda_time_total"
table = self.profiler.key_averages().table(sort_by=sort_key)
with open(profiler_out_file, "w") as f:
print(table, file=f)
# Skip file write for URI paths (gs://, s3://, etc.)
# as standard file I/O doesn't work with URI schemes
if not _is_uri_path(profiler_dir):
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
with open(profiler_out_file, "w") as f:
print(table, file=f)
# only print profiler results on rank 0
if rank == 0: