diff --git a/tests/v1/worker/test_gpu_profiler.py b/tests/v1/worker/test_gpu_profiler.py index 933ea42f1..ca22f3c9d 100644 --- a/tests/v1/worker/test_gpu_profiler.py +++ b/tests/v1/worker/test_gpu_profiler.py @@ -3,6 +3,7 @@ import pytest from vllm.config import ProfilerConfig +from vllm.config.profiler import _is_uri_path from vllm.profiler.wrapper import WorkerProfiler @@ -202,3 +203,36 @@ def test_mixed_delay_and_stop(default_profiler_config): profiler.step() assert profiler.start_call_count == 0 + + +class TestIsUriPath: + """Tests for the _is_uri_path helper function.""" + + @pytest.mark.parametrize( + "path,expected", + [ + # Valid URI schemes - should return True + ("gs://bucket/path", True), + ("s3://bucket/path", True), + ("hdfs://cluster/path", True), + ("abfs://container/path", True), + ("http://example.com/path", True), + ("https://example.com/path", True), + # Local paths - should return False + ("/tmp/local/path", False), + ("./relative/path", False), + ("relative/path", False), + ("/absolute/path", False), + # Windows drive letters - should return False (single char scheme) + ("C://windows/path", False), + ("D://drive/path", False), + # Edge cases + ("", False), + ("no-scheme", False), + ("scheme-no-slashes:", False), + ("://no-scheme", False), + ], + ) + def test_is_uri_path(self, path, expected): + """Test that _is_uri_path correctly identifies URI vs local paths.""" + assert _is_uri_path(path) == expected diff --git a/vllm/config/profiler.py b/vllm/config/profiler.py index 76cc546f3..b35d88269 100644 --- a/vllm/config/profiler.py +++ b/vllm/config/profiler.py @@ -18,6 +18,20 @@ logger = init_logger(__name__) ProfilerKind = Literal["torch", "cuda"] +def _is_uri_path(path: str) -> bool: + """Check if path is a URI (scheme://...), excluding Windows drive letters. + + Supports custom URI schemes like gs://, s3://, hdfs://, etc. + These paths should not be converted to absolute paths. + """ + if "://" in path: + scheme = path.split("://")[0] + # Windows drive letters are single characters (e.g., C://) + # Valid URI schemes have more than one character + return len(scheme) > 1 + return False + + @config @dataclass class ProfilerConfig: @@ -54,7 +68,7 @@ class ProfilerConfig: Disabled by default.""" ignore_frontend: bool = False - """If `True`, disables the front-end profiling of AsyncLLM when using the + """If `True`, disables the front-end profiling of AsyncLLM when using the 'torch' profiler. This is needed to reduce overhead when using delay/limit options, since the front-end profiling does not track iterations and will capture the entire range. @@ -185,15 +199,9 @@ class ProfilerConfig: if self.profiler == "torch" and not profiler_dir: raise ValueError("torch_profiler_dir must be set when profiler is 'torch'") - if profiler_dir: - is_gs_path = ( - profiler_dir.startswith("gs://") - and profiler_dir[5:] - and profiler_dir[5] != "/" - ) - if not is_gs_path: - self.torch_profiler_dir = os.path.abspath( - os.path.expanduser(profiler_dir) - ) + # Support any URI scheme (gs://, s3://, hdfs://, etc.) + # These paths should not be converted to absolute paths + if profiler_dir and not _is_uri_path(profiler_dir): + self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir)) return self diff --git a/vllm/profiler/wrapper.py b/vllm/profiler/wrapper.py index f891a88f9..45aa88eef 100644 --- a/vllm/profiler/wrapper.py +++ b/vllm/profiler/wrapper.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from abc import ABC, abstractmethod +from collections.abc import Callable from contextlib import nullcontext from typing import Literal @@ -9,6 +10,7 @@ import torch from typing_extensions import override from vllm.config import ProfilerConfig +from vllm.config.profiler import _is_uri_path from vllm.logger import init_logger logger = init_logger(__name__) @@ -151,6 +153,7 @@ class TorchProfilerWrapper(WorkerProfiler): worker_name: str, local_rank: int, activities: list[TorchProfilerActivity], + on_trace_ready: Callable[[torch.profiler.profile], None] | None = None, ) -> None: super().__init__(profiler_config) @@ -172,6 +175,17 @@ class TorchProfilerWrapper(WorkerProfiler): profiler_config.torch_profiler_with_flops, ) + # Determine trace handler: use custom handler if provided, + # otherwise default to tensorboard trace handler + if on_trace_ready is not None: + trace_handler = on_trace_ready + else: + trace_handler = torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, + worker_name=worker_name, + use_gzip=profiler_config.torch_profiler_use_gzip, + ) + self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1 self.profiler = torch.profiler.profile( activities=[TorchProfilerActivityMap[activity] for activity in activities], @@ -179,11 +193,7 @@ class TorchProfilerWrapper(WorkerProfiler): profile_memory=profiler_config.torch_profiler_with_memory, with_stack=profiler_config.torch_profiler_with_stack, with_flops=profiler_config.torch_profiler_with_flops, - on_trace_ready=torch.profiler.tensorboard_trace_handler( - torch_profiler_trace_dir, - worker_name=worker_name, - use_gzip=profiler_config.torch_profiler_use_gzip, - ), + on_trace_ready=trace_handler, ) @override @@ -198,12 +208,15 @@ class TorchProfilerWrapper(WorkerProfiler): rank = self.local_rank if profiler_config.torch_profiler_dump_cuda_time_total: profiler_dir = profiler_config.torch_profiler_dir - profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt" sort_key = "self_cuda_time_total" table = self.profiler.key_averages().table(sort_by=sort_key) - with open(profiler_out_file, "w") as f: - print(table, file=f) + # Skip file write for URI paths (gs://, s3://, etc.) + # as standard file I/O doesn't work with URI schemes + if not _is_uri_path(profiler_dir): + profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt" + with open(profiler_out_file, "w") as f: + print(table, file=f) # only print profiler results on rank 0 if rank == 0: