Support custom URI schemes and trace handlers for profiler (#32393)

This commit is contained in:
David Ramon Prados
2026-01-22 12:45:40 -05:00
committed by GitHub
parent 803e3f3f68
commit 3a63be0faa
3 changed files with 74 additions and 19 deletions

View File

@@ -3,6 +3,7 @@
import pytest
from vllm.config import ProfilerConfig
from vllm.config.profiler import _is_uri_path
from vllm.profiler.wrapper import WorkerProfiler
@@ -202,3 +203,36 @@ def test_mixed_delay_and_stop(default_profiler_config):
profiler.step()
assert profiler.start_call_count == 0
class TestIsUriPath:
"""Tests for the _is_uri_path helper function."""
@pytest.mark.parametrize(
"path,expected",
[
# Valid URI schemes - should return True
("gs://bucket/path", True),
("s3://bucket/path", True),
("hdfs://cluster/path", True),
("abfs://container/path", True),
("http://example.com/path", True),
("https://example.com/path", True),
# Local paths - should return False
("/tmp/local/path", False),
("./relative/path", False),
("relative/path", False),
("/absolute/path", False),
# Windows drive letters - should return False (single char scheme)
("C://windows/path", False),
("D://drive/path", False),
# Edge cases
("", False),
("no-scheme", False),
("scheme-no-slashes:", False),
("://no-scheme", False),
],
)
def test_is_uri_path(self, path, expected):
"""Test that _is_uri_path correctly identifies URI vs local paths."""
assert _is_uri_path(path) == expected

View File

@@ -18,6 +18,20 @@ logger = init_logger(__name__)
ProfilerKind = Literal["torch", "cuda"]
def _is_uri_path(path: str) -> bool:
"""Check if path is a URI (scheme://...), excluding Windows drive letters.
Supports custom URI schemes like gs://, s3://, hdfs://, etc.
These paths should not be converted to absolute paths.
"""
if "://" in path:
scheme = path.split("://")[0]
# Windows drive letters are single characters (e.g., C://)
# Valid URI schemes have more than one character
return len(scheme) > 1
return False
@config
@dataclass
class ProfilerConfig:
@@ -54,7 +68,7 @@ class ProfilerConfig:
Disabled by default."""
ignore_frontend: bool = False
"""If `True`, disables the front-end profiling of AsyncLLM when using the
"""If `True`, disables the front-end profiling of AsyncLLM when using the
'torch' profiler. This is needed to reduce overhead when using delay/limit options,
since the front-end profiling does not track iterations and will capture the
entire range.
@@ -185,15 +199,9 @@ class ProfilerConfig:
if self.profiler == "torch" and not profiler_dir:
raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
if profiler_dir:
is_gs_path = (
profiler_dir.startswith("gs://")
and profiler_dir[5:]
and profiler_dir[5] != "/"
)
if not is_gs_path:
self.torch_profiler_dir = os.path.abspath(
os.path.expanduser(profiler_dir)
)
# Support any URI scheme (gs://, s3://, hdfs://, etc.)
# These paths should not be converted to absolute paths
if profiler_dir and not _is_uri_path(profiler_dir):
self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir))
return self

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from contextlib import nullcontext
from typing import Literal
@@ -9,6 +10,7 @@ import torch
from typing_extensions import override
from vllm.config import ProfilerConfig
from vllm.config.profiler import _is_uri_path
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -151,6 +153,7 @@ class TorchProfilerWrapper(WorkerProfiler):
worker_name: str,
local_rank: int,
activities: list[TorchProfilerActivity],
on_trace_ready: Callable[[torch.profiler.profile], None] | None = None,
) -> None:
super().__init__(profiler_config)
@@ -172,6 +175,17 @@ class TorchProfilerWrapper(WorkerProfiler):
profiler_config.torch_profiler_with_flops,
)
# Determine trace handler: use custom handler if provided,
# otherwise default to tensorboard trace handler
if on_trace_ready is not None:
trace_handler = on_trace_ready
else:
trace_handler = torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=profiler_config.torch_profiler_use_gzip,
)
self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
self.profiler = torch.profiler.profile(
activities=[TorchProfilerActivityMap[activity] for activity in activities],
@@ -179,11 +193,7 @@ class TorchProfilerWrapper(WorkerProfiler):
profile_memory=profiler_config.torch_profiler_with_memory,
with_stack=profiler_config.torch_profiler_with_stack,
with_flops=profiler_config.torch_profiler_with_flops,
on_trace_ready=torch.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir,
worker_name=worker_name,
use_gzip=profiler_config.torch_profiler_use_gzip,
),
on_trace_ready=trace_handler,
)
@override
@@ -198,12 +208,15 @@ class TorchProfilerWrapper(WorkerProfiler):
rank = self.local_rank
if profiler_config.torch_profiler_dump_cuda_time_total:
profiler_dir = profiler_config.torch_profiler_dir
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
sort_key = "self_cuda_time_total"
table = self.profiler.key_averages().table(sort_by=sort_key)
with open(profiler_out_file, "w") as f:
print(table, file=f)
# Skip file write for URI paths (gs://, s3://, etc.)
# as standard file I/O doesn't work with URI schemes
if not _is_uri_path(profiler_dir):
profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
with open(profiler_out_file, "w") as f:
print(table, file=f)
# only print profiler results on rank 0
if rank == 0: