[BugFix] Fix multiple/duplicate stdout prefixes (#36822)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -21,7 +21,6 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.utils.system_utils import decorate_logs, set_process_title
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor import Executor
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
@@ -210,7 +209,6 @@ def run_headless(args: argparse.Namespace):
|
||||
|
||||
# Create the engines.
|
||||
engine_manager = CoreEngineProcManager(
|
||||
target_fn=EngineCoreProc.run_engine_core,
|
||||
local_engine_count=local_engine_count,
|
||||
start_index=vllm_config.parallel_config.data_parallel_rank,
|
||||
local_start_index=0,
|
||||
|
||||
@@ -204,7 +204,8 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
||||
prefix = f"({worker_name} pid={pid}) "
|
||||
else:
|
||||
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
|
||||
file_write = file.write
|
||||
# Use the original write to avoid nesting prefixes on repeated calls.
|
||||
file_write = getattr(file, "_original_write", file.write)
|
||||
|
||||
def write_with_prefix(s: str):
|
||||
if not s:
|
||||
@@ -224,6 +225,7 @@ def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
||||
file.start_new_line = False # type: ignore[attr-defined]
|
||||
|
||||
file.start_new_line = True # type: ignore[attr-defined]
|
||||
file._original_write = file_write # type: ignore[attr-defined]
|
||||
file.write = write_with_prefix # type: ignore[method-assign]
|
||||
|
||||
|
||||
|
||||
@@ -1045,19 +1045,11 @@ class EngineCoreProc(EngineCore):
|
||||
data_parallel = parallel_config.data_parallel_size > 1 or dp_rank > 0
|
||||
if data_parallel:
|
||||
parallel_config.data_parallel_rank_local = local_dp_rank
|
||||
maybe_init_worker_tracer(
|
||||
instrumenting_module_name="vllm.engine_core",
|
||||
process_kind="engine_core",
|
||||
process_name=f"EngineCore_DP{dp_rank}",
|
||||
)
|
||||
set_process_title("EngineCore", f"DP{dp_rank}")
|
||||
process_title = f"EngineCore_DP{dp_rank}"
|
||||
else:
|
||||
maybe_init_worker_tracer(
|
||||
instrumenting_module_name="vllm.engine_core",
|
||||
process_kind="engine_core",
|
||||
process_name="EngineCore",
|
||||
)
|
||||
set_process_title("EngineCore")
|
||||
process_title = "EngineCore"
|
||||
set_process_title(process_title)
|
||||
maybe_init_worker_tracer("vllm.engine_core", "engine_core", process_title)
|
||||
decorate_logs()
|
||||
|
||||
if data_parallel and vllm_config.kv_transfer_config is not None:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import contextlib
|
||||
import os
|
||||
import weakref
|
||||
from collections.abc import Callable, Iterator
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from multiprocessing import Process, connection
|
||||
@@ -85,7 +85,6 @@ class CoreEngineProcManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target_fn: Callable,
|
||||
local_engine_count: int,
|
||||
start_index: int,
|
||||
local_start_index: int,
|
||||
@@ -108,6 +107,10 @@ class CoreEngineProcManager:
|
||||
if client_handshake_address:
|
||||
common_kwargs["client_handshake_address"] = client_handshake_address
|
||||
|
||||
is_dp = vllm_config.parallel_config.data_parallel_size > 1
|
||||
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
|
||||
self.processes: list[BaseProcess] = []
|
||||
local_dp_ranks = []
|
||||
for index in range(local_engine_count):
|
||||
@@ -118,35 +121,27 @@ class CoreEngineProcManager:
|
||||
local_dp_ranks.append(local_index)
|
||||
self.processes.append(
|
||||
context.Process(
|
||||
target=target_fn,
|
||||
name=f"EngineCore_DP{global_index}",
|
||||
target=EngineCoreProc.run_engine_core,
|
||||
name=f"EngineCore_DP{global_index}" if is_dp else "EngineCore",
|
||||
kwargs=common_kwargs
|
||||
| {
|
||||
"dp_rank": global_index,
|
||||
"local_dp_rank": local_index,
|
||||
},
|
||||
| {"dp_rank": global_index, "local_dp_rank": local_index},
|
||||
)
|
||||
)
|
||||
|
||||
self._finalizer = weakref.finalize(self, shutdown, self.processes)
|
||||
|
||||
data_parallel = vllm_config.parallel_config.data_parallel_size > 1
|
||||
try:
|
||||
for proc, local_dp_rank in zip(self.processes, local_dp_ranks):
|
||||
# Adjust device control in DP for non-CUDA platforms
|
||||
# as well as external and ray launchers
|
||||
# For CUDA platforms, we use torch.cuda.set_device()
|
||||
with (
|
||||
set_device_control_env_var(vllm_config, local_dp_rank)
|
||||
if (
|
||||
data_parallel
|
||||
and (
|
||||
not current_platform.is_cuda_alike()
|
||||
or vllm_config.parallel_config.use_ray
|
||||
)
|
||||
)
|
||||
else contextlib.nullcontext()
|
||||
if is_dp and (
|
||||
not current_platform.is_cuda_alike()
|
||||
or vllm_config.parallel_config.use_ray
|
||||
):
|
||||
with set_device_control_env_var(vllm_config, local_dp_rank):
|
||||
proc.start()
|
||||
else:
|
||||
proc.start()
|
||||
finally:
|
||||
# Kill other procs if not all are running.
|
||||
@@ -926,12 +921,9 @@ def launch_core_engines(
|
||||
with zmq_socket_ctx(
|
||||
local_handshake_address, zmq.ROUTER, bind=True
|
||||
) as handshake_socket:
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
|
||||
# Start local engines.
|
||||
if local_engine_count:
|
||||
local_engine_manager = CoreEngineProcManager(
|
||||
EngineCoreProc.run_engine_core,
|
||||
vllm_config=vllm_config,
|
||||
executor_class=executor_class,
|
||||
log_stats=log_stats,
|
||||
|
||||
Reference in New Issue
Block a user