break execute_model in gpu_model_runner into sub-functions for custom scopes (#24265)

Co-authored-by: Bangsheng Tang <bangsheng@meta.com>
This commit is contained in:
Bangsheng Tang
2025-09-06 14:02:47 -07:00
committed by GitHub
parent e68dc2f014
commit 848562bd49
3 changed files with 208 additions and 109 deletions

View File

@@ -1,17 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import multiprocessing
import time
import weakref
from collections.abc import Sequence
from contextlib import AbstractContextManager
from multiprocessing import connection
from multiprocessing.process import BaseProcess
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
Union, overload)
import torch
from torch.autograd.profiler import record_function
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
@@ -155,7 +159,7 @@ def get_engine_client_zmq_addr(local_only: bool,
class APIServerProcessManager:
"""Manages a group of API server processes.
Handles creation, monitoring, and termination of API server worker
processes. Also monitors extra processes to check if they are healthy.
"""
@@ -172,7 +176,7 @@ class APIServerProcessManager:
stats_update_address: Optional[str] = None,
):
"""Initialize and start API server worker processes.
Args:
target_server_fn: Function to call for each API server process
listen_address: Address to listen for client connections
@@ -181,7 +185,7 @@ class APIServerProcessManager:
num_servers: Number of API server processes to start
input_addresses: Input addresses for each API server
output_addresses: Output addresses for each API server
stats_update_address: Optional stats update address
stats_update_address: Optional stats update address
"""
self.listen_address = listen_address
self.sock = sock
@@ -225,7 +229,7 @@ def wait_for_completion_or_failure(
"CoreEngineActorManager"]] = None,
coordinator: Optional["DPCoordinator"] = None) -> None:
"""Wait for all processes to complete or detect if any fail.
Raises an exception if any process exits with a non-zero status.
Args:
@@ -368,3 +372,10 @@ def report_usage_stats(
"disable_custom_all_reduce":
vllm_config.parallel_config.disable_custom_all_reduce,
})
def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a profiler scope for ``name``, or a no-op context manager.

    The choice is driven by ``envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING``: when
    it is truthy the result is ``record_function(name)``; otherwise it is
    ``contextlib.nullcontext()``, so callers can always use the return value
    in a ``with`` statement.

    Args:
        name: Label handed to ``record_function`` when custom scopes are
            enabled.

    Returns:
        A context manager to wrap the code region of interest.
    """
    # Guard clause: custom scopes disabled -> hand back a do-nothing context.
    if not envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
        return contextlib.nullcontext()
    return record_function(name)