diff --git a/vllm/distributed/utils.py b/vllm/distributed/utils.py index 8df9d638a..17375259e 100644 --- a/vllm/distributed/utils.py +++ b/vllm/distributed/utils.py @@ -524,3 +524,43 @@ def stateless_destroy_torch_distributed_process_group(pg: ProcessGroup) -> None: """ pg.shutdown() _unregister_process_group(pg.group_name) + + +def get_worker_rank_suffix(global_rank: int | None = None) -> str: + """Generate a descriptive rank suffix for worker identification. + + Returns a string like 'dp0_pp0_tp0_dcp0_ep0_rank0' including all + parallel dimensions: DP, PP, TP, DCP, EP. + + Args: + global_rank: Optional global rank to append. If not provided, + only parallel dimension ranks are included. + + Returns: + A string suffix identifying the worker's position in the + distributed topology. + """ + from vllm.distributed.parallel_state import ( + get_dcp_group, + get_dp_group, + get_ep_group, + get_pp_group, + get_tp_group, + ) + + try: + dp_rank = get_dp_group().rank_in_group + pp_rank = get_pp_group().rank_in_group + tp_rank = get_tp_group().rank_in_group + dcp_rank = get_dcp_group().rank_in_group + ep_rank = get_ep_group().rank_in_group + + suffix = f"dp{dp_rank}_pp{pp_rank}_tp{tp_rank}_dcp{dcp_rank}_ep{ep_rank}" + if global_rank is not None: + suffix = f"{suffix}_rank{global_rank}" + return suffix + except Exception: + # Fallback if parallel state not initialized + if global_rank is not None: + return f"rank{global_rank}" + return "" diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9cb40448b..f54d9121c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1685,8 +1685,15 @@ class LLM: tokenization_kwargs=encode_kwargs, ) - def start_profile(self) -> None: - self.llm_engine.start_profile() + def start_profile(self, profile_prefix: str | None = None) -> None: + """Start profiling with optional custom trace prefix. + + Args: + profile_prefix: Optional prefix for the trace file names. 
If provided, + trace files will be named as "<profile_prefix>_dp<D>_pp<P>_tp<T>_dcp<C>_ep<E>_rank<R>". + If not provided, default naming will be used. + """ + self.llm_engine.start_profile(profile_prefix) def stop_profile(self) -> None: self.llm_engine.stop_profile() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index 44853ec88..bab898da6 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -911,8 +911,8 @@ class AsyncLLM(EngineClient): if self.errored: raise self.dead_error - async def start_profile(self) -> None: - coros = [self.engine_core.profile_async(True)] + async def start_profile(self, profile_prefix: str | None = None) -> None: + coros = [self.engine_core.profile_async(True, profile_prefix)] if self.profiler is not None: coros.append(asyncio.to_thread(self.profiler.start)) await asyncio.gather(*coros) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index afa59d52d..7553c7332 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -568,8 +568,8 @@ class EngineCore: if self.scheduler: self.scheduler.shutdown() - def profile(self, is_start: bool = True): - self.model_executor.profile(is_start) + def profile(self, is_start: bool = True, profile_prefix: str | None = None): + self.model_executor.profile(is_start, profile_prefix) def reset_mm_cache(self): # NOTE: Since this is mainly for debugging, we don't attempt to diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index b31f1c406..e9187c4e8 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -135,7 +135,7 @@ class EngineCoreClient(ABC): def add_request(self, request: EngineCoreRequest) -> None: raise NotImplementedError - def profile(self, is_start: bool = True) -> None: + def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: raise NotImplementedError def reset_mm_cache(self) -> None: raise NotImplementedError @@ -210,7 +210,9 @@ class EngineCoreClient(ABC): async def add_request_async(self, request: EngineCoreRequest) -> 
None: raise NotImplementedError - async def profile_async(self, is_start: bool = True) -> None: + async def profile_async( + self, is_start: bool = True, profile_prefix: str | None = None + ) -> None: raise NotImplementedError async def reset_mm_cache_async(self) -> None: @@ -295,8 +297,8 @@ class InprocClient(EngineCoreClient): def shutdown(self) -> None: self.engine_core.shutdown() - def profile(self, is_start: bool = True) -> None: - self.engine_core.profile(is_start) + def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: + self.engine_core.profile(is_start, profile_prefix) def reset_mm_cache(self) -> None: self.engine_core.reset_mm_cache() @@ -765,8 +767,8 @@ class SyncMPClient(MPClient): if request_ids and not self.resources.engine_dead: self._send_input(EngineCoreRequestType.ABORT, request_ids) - def profile(self, is_start: bool = True) -> None: - self.call_utility("profile", is_start) + def profile(self, is_start: bool = True, profile_prefix: str | None = None) -> None: + self.call_utility("profile", is_start, profile_prefix) def reset_mm_cache(self) -> None: self.call_utility("reset_mm_cache") @@ -987,8 +989,10 @@ class AsyncMPClient(MPClient): """Resume the scheduler after a pause.""" await self.call_utility_async("resume_scheduler") - async def profile_async(self, is_start: bool = True) -> None: - await self.call_utility_async("profile", is_start) + async def profile_async( + self, is_start: bool = True, profile_prefix: str | None = None + ) -> None: + await self.call_utility_async("profile", is_start, profile_prefix) async def reset_mm_cache_async(self) -> None: await self.call_utility_async("reset_mm_cache") diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 51f39c929..76aa8f438 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -326,8 +326,8 @@ class LLMEngine: return processed_outputs.request_outputs - def start_profile(self): - self.engine_core.profile(True) + 
def start_profile(self, profile_prefix: str | None = None): + self.engine_core.profile(True, profile_prefix) def stop_profile(self): self.engine_core.profile(False) diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 32fa87e9d..91bd019f8 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -238,8 +238,8 @@ class Executor(ABC): def max_concurrent_batches(self) -> int: return 1 - def profile(self, is_start: bool = True): - self.collective_rpc("profile", args=(is_start,)) + def profile(self, is_start: bool = True, profile_prefix: str | None = None): + self.collective_rpc("profile", args=(is_start, profile_prefix)) def save_sharded_state( self, diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 49b97e8f3..229b5742d 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -1305,8 +1305,8 @@ class StatLoggerManager: ): if engine_idx is None: engine_idx = 0 - for logger in self.stat_loggers: - logger.record( + for stat_logger in self.stat_loggers: + stat_logger.record( scheduler_stats, iteration_stats, mm_cache_stats=mm_cache_stats, diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index 2fbcc9c44..752b692f8 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -212,7 +212,7 @@ class CPUWorker(Worker): ) return ",".join([str(x.id) for x in logical_cpu_list]) - def profile(self, is_start: bool = True): + def profile(self, is_start: bool = True, profile_prefix: str | None = None): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") if is_start: diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 635402f3d..2507b7f20 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -103,20 +103,14 @@ class Worker(WorkerBase): ) # Torch/CUDA profiler. Enabled and configured through profiler_config. 
+ # Profiler wrapper is created lazily in profile() when start is called, + # so we have all the information needed for proper trace naming. self.profiler: Any | None = None - profiler_config = vllm_config.profiler_config - if profiler_config.profiler == "torch": - worker_name = f"{vllm_config.instance_id}-rank-{self.rank}" - self.profiler = TorchProfilerWrapper( - profiler_config, - worker_name=worker_name, - local_rank=self.local_rank, - activities=["CPU", "CUDA"], - ) - elif profiler_config.profiler == "cuda": - self.profiler = CudaProfilerWrapper(profiler_config) - else: - self.profiler = None + self.profiler_config = vllm_config.profiler_config + + # Only validate profiler config is valid, don't instantiate yet + if self.profiler_config.profiler not in ("torch", "cuda", None): + raise ValueError(f"Unknown profiler type: {self.profiler_config.profiler}") self.use_v2_model_runner = envs.VLLM_USE_V2_MODEL_RUNNER @@ -677,17 +671,52 @@ class Worker(WorkerBase): def take_draft_token_ids(self) -> DraftTokenIds | None: return self.model_runner.take_draft_token_ids() - def profile(self, is_start: bool = True): - if self.profiler is None: + def profile(self, is_start: bool = True, profile_prefix: str | None = None): + # Check if profiling is enabled + if self.profiler_config is None or self.profiler_config.profiler is None: raise RuntimeError( "Profiling is not enabled. Please set --profiler-config to enable " "profiling. 
Example: " "'--profiler-config.profiler=torch --profiler-config.torch_profiler_dir" "=YOUR_DIR_PATH_TO_DUMP_TRACE'" ) + if is_start: - self.profiler.start() + # Generate the trace name by combining prefix with comprehensive rank suffix + from vllm.distributed.utils import get_worker_rank_suffix + + rank_suffix = get_worker_rank_suffix(global_rank=self.rank) + + # Build the full trace name + if profile_prefix: + trace_name = f"{profile_prefix}_{rank_suffix}" + else: + trace_name = rank_suffix + + # Create the profiler wrapper only on the first start call + if self.profiler is None: + if self.profiler_config.profiler == "torch": + self.profiler = TorchProfilerWrapper( + self.profiler_config, + worker_name=trace_name, + local_rank=self.local_rank, + activities=["CPU", "CUDA"], + ) + logger.debug( + "Starting torch profiler with trace name: %s", trace_name + ) + elif self.profiler_config.profiler == "cuda": + self.profiler = CudaProfilerWrapper(self.profiler_config) + logger.debug("Starting CUDA profiler") + self.profiler.start() + else: + # Profiler already initialized. Restart profiling but keep + # the original trace name from the first initialization. + self.profiler.start() else: + if self.profiler is None: + logger.warning("Profiler was not started, nothing to stop.") + return self.profiler.stop() def execute_dummy_batch(self) -> None: