Support custom URI schemes and trace handlers for profiler (#32393)
commit 3a63be0faa
parent 803e3f3f68
@@ -18,6 +18,20 @@ logger = init_logger(__name__)
 
 ProfilerKind = Literal["torch", "cuda"]
 
+
+def _is_uri_path(path: str) -> bool:
+    """Check if path is a URI (scheme://...), excluding Windows drive letters.
+
+    Supports custom URI schemes like gs://, s3://, hdfs://, etc.
+    These paths should not be converted to absolute paths.
+    """
+    if "://" in path:
+        scheme = path.split("://")[0]
+        # Windows drive letters are single characters (e.g., C://)
+        # Valid URI schemes have more than one character
+        return len(scheme) > 1
+    return False
+
+
 @config
 @dataclass
 class ProfilerConfig:
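For reference, the helper's intended behavior on a few representative inputs (illustrative asserts, not part of the diff):

    # Illustrative behavior of _is_uri_path as defined above
    assert _is_uri_path("gs://bucket/traces")       # multi-char scheme -> URI
    assert _is_uri_path("s3://bucket/traces")
    assert not _is_uri_path("C://Users/traces")     # single-char scheme: Windows drive letter
    assert not _is_uri_path("/tmp/traces")          # no scheme separator at all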
@@ -54,7 +68,7 @@ class ProfilerConfig:
     Disabled by default."""
 
     ignore_frontend: bool = False
-    """If `True`, disables the front-end profiling of AsyncLLM when using the
+    """If `True`, disables the front-end profiling of AsyncLLM when using the
     'torch' profiler. This is needed to reduce overhead when using delay/limit options,
     since the front-end profiling does not track iterations and will capture the
     entire range.
@@ -185,15 +199,9 @@ class ProfilerConfig:
         if self.profiler == "torch" and not profiler_dir:
             raise ValueError("torch_profiler_dir must be set when profiler is 'torch'")
 
-        if profiler_dir:
-            is_gs_path = (
-                profiler_dir.startswith("gs://")
-                and profiler_dir[5:]
-                and profiler_dir[5] != "/"
-            )
-            if not is_gs_path:
-                self.torch_profiler_dir = os.path.abspath(
-                    os.path.expanduser(profiler_dir)
-                )
+        # Support any URI scheme (gs://, s3://, hdfs://, etc.)
+        # These paths should not be converted to absolute paths
+        if profiler_dir and not _is_uri_path(profiler_dir):
+            self.torch_profiler_dir = os.path.abspath(os.path.expanduser(profiler_dir))
 
         return self
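In effect, local paths are still expanded and absolutized during validation, while any recognized URI is passed through untouched. A minimal sketch, assuming the dataclass is constructed directly with the fields visible in this diff:

    import os

    cfg = ProfilerConfig(profiler="torch", torch_profiler_dir="~/traces")
    # After validation, the local path is normalized:
    # cfg.torch_profiler_dir == os.path.abspath(os.path.expanduser("~/traces"))

    cfg = ProfilerConfig(profiler="torch", torch_profiler_dir="s3://bucket/traces")
    # After validation, the URI survives unchanged:
    # cfg.torch_profiler_dir == "s3://bucket/traces"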
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from abc import ABC, abstractmethod
+from collections.abc import Callable
 from contextlib import nullcontext
 from typing import Literal
 
@@ -9,6 +10,7 @@ import torch
 from typing_extensions import override
 
 from vllm.config import ProfilerConfig
+from vllm.config.profiler import _is_uri_path
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -151,6 +153,7 @@ class TorchProfilerWrapper(WorkerProfiler):
         worker_name: str,
         local_rank: int,
         activities: list[TorchProfilerActivity],
+        on_trace_ready: Callable[[torch.profiler.profile], None] | None = None,
     ) -> None:
         super().__init__(profiler_config)
 
@@ -172,6 +175,17 @@ class TorchProfilerWrapper(WorkerProfiler):
             profiler_config.torch_profiler_with_flops,
         )
 
+        # Determine trace handler: use custom handler if provided,
+        # otherwise default to tensorboard trace handler
+        if on_trace_ready is not None:
+            trace_handler = on_trace_ready
+        else:
+            trace_handler = torch.profiler.tensorboard_trace_handler(
+                torch_profiler_trace_dir,
+                worker_name=worker_name,
+                use_gzip=profiler_config.torch_profiler_use_gzip,
+            )
+
         self.dump_cpu_time_total = "CPU" in activities and len(activities) == 1
         self.profiler = torch.profiler.profile(
             activities=[TorchProfilerActivityMap[activity] for activity in activities],
@@ -179,11 +193,7 @@ class TorchProfilerWrapper(WorkerProfiler):
             profile_memory=profiler_config.torch_profiler_with_memory,
             with_stack=profiler_config.torch_profiler_with_stack,
             with_flops=profiler_config.torch_profiler_with_flops,
-            on_trace_ready=torch.profiler.tensorboard_trace_handler(
-                torch_profiler_trace_dir,
-                worker_name=worker_name,
-                use_gzip=profiler_config.torch_profiler_use_gzip,
-            ),
+            on_trace_ready=trace_handler,
         )
 
     @override
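Together, the two hunks above let a caller override where traces go instead of always using the tensorboard handler. A hedged sketch of wiring in a custom handler, mirroring the constructor parameters visible in this diff (the uploader is hypothetical, and the activity names are assumed from the "CPU" check above):

    # A custom on_trace_ready callback receives the finished profile.
    def my_trace_handler(prof: torch.profiler.profile) -> None:
        prof.export_chrome_trace("/tmp/trace.json")  # standard torch.profiler API
        # upload("/tmp/trace.json", "gs://bucket/traces/")  # hypothetical uploader

    wrapper = TorchProfilerWrapper(
        profiler_config,              # assumed first positional arg, per super().__init__ above
        worker_name="worker_0",
        local_rank=0,
        activities=["CPU", "CUDA"],   # assumed activity names
        on_trace_ready=my_trace_handler,
    )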
@@ -198,12 +208,15 @@ class TorchProfilerWrapper(WorkerProfiler):
         rank = self.local_rank
         if profiler_config.torch_profiler_dump_cuda_time_total:
             profiler_dir = profiler_config.torch_profiler_dir
-            profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
             sort_key = "self_cuda_time_total"
             table = self.profiler.key_averages().table(sort_by=sort_key)
 
-            with open(profiler_out_file, "w") as f:
-                print(table, file=f)
+            # Skip file write for URI paths (gs://, s3://, etc.)
+            # as standard file I/O doesn't work with URI schemes
+            if not _is_uri_path(profiler_dir):
+                profiler_out_file = f"{profiler_dir}/profiler_out_{rank}.txt"
+                with open(profiler_out_file, "w") as f:
+                    print(table, file=f)
 
             # only print profiler results on rank 0
             if rank == 0:
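The guard matters because builtins.open treats a URI as an ordinary relative path; without it, a remote trace directory would hit the local filesystem:

    # Without the guard, a URI directory fails at file-open time:
    open("gs://bucket/profiler_out_0.txt", "w")
    # raises FileNotFoundError, since no local "gs:/bucket" directory exists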