break execute_model in gpu_model_runner into sub-functions for custom scopes (#24265)
Co-authored-by: Bangsheng Tang <bangsheng@meta.com>
This commit is contained in:
@@ -1,17 +1,21 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import contextlib
|
||||
import multiprocessing
|
||||
import time
|
||||
import weakref
|
||||
from collections.abc import Sequence
|
||||
from contextlib import AbstractContextManager
|
||||
from multiprocessing import connection
|
||||
from multiprocessing.process import BaseProcess
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar,
|
||||
Union, overload)
|
||||
|
||||
import torch
|
||||
from torch.autograd.profiler import record_function
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
@@ -155,7 +159,7 @@ def get_engine_client_zmq_addr(local_only: bool,
|
||||
|
||||
class APIServerProcessManager:
|
||||
"""Manages a group of API server processes.
|
||||
|
||||
|
||||
Handles creation, monitoring, and termination of API server worker
|
||||
processes. Also monitors extra processes to check if they are healthy.
|
||||
"""
|
||||
@@ -172,7 +176,7 @@ class APIServerProcessManager:
|
||||
stats_update_address: Optional[str] = None,
|
||||
):
|
||||
"""Initialize and start API server worker processes.
|
||||
|
||||
|
||||
Args:
|
||||
target_server_fn: Function to call for each API server process
|
||||
listen_address: Address to listen for client connections
|
||||
@@ -181,7 +185,7 @@ class APIServerProcessManager:
|
||||
num_servers: Number of API server processes to start
|
||||
input_addresses: Input addresses for each API server
|
||||
output_addresses: Output addresses for each API server
|
||||
stats_update_address: Optional stats update address
|
||||
stats_update_address: Optional stats update address
|
||||
"""
|
||||
self.listen_address = listen_address
|
||||
self.sock = sock
|
||||
@@ -225,7 +229,7 @@ def wait_for_completion_or_failure(
|
||||
"CoreEngineActorManager"]] = None,
|
||||
coordinator: Optional["DPCoordinator"] = None) -> None:
|
||||
"""Wait for all processes to complete or detect if any fail.
|
||||
|
||||
|
||||
Raises an exception if any process exits with a non-zero status.
|
||||
|
||||
Args:
|
||||
@@ -368,3 +372,10 @@ def report_usage_stats(
|
||||
"disable_custom_all_reduce":
|
||||
vllm_config.parallel_config.disable_custom_all_reduce,
|
||||
})
|
||||
|
||||
|
||||
def record_function_or_nullcontext(name: str) -> AbstractContextManager:
    """Return a profiler scope labelled ``name``, or a no-op context manager.

    Args:
        name: Label passed to ``record_function`` when custom scopes are on.

    Returns:
        ``record_function(name)`` when ``envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING``
        is truthy; otherwise ``contextlib.nullcontext()``.
    """
    # Guard clause: with custom profiling scopes disabled, hand back a
    # context manager that does nothing on enter/exit.
    if not envs.VLLM_CUSTOM_SCOPES_FOR_PROFILING:
        return contextlib.nullcontext()
    return record_function(name)
|
||||
|
||||
Reference in New Issue
Block a user