[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)
This commit is contained in:
@@ -10,6 +10,7 @@ import importlib.metadata
|
||||
import importlib.util
|
||||
import inspect
|
||||
import ipaddress
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import resource
|
||||
@@ -20,6 +21,7 @@ import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
@@ -29,8 +31,9 @@ from collections.abc import Hashable, Iterable, Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache, partial, wraps
|
||||
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
|
||||
Dict, Generator, Generic, List, Literal, NamedTuple,
|
||||
Optional, Tuple, Type, TypeVar, Union, overload)
|
||||
Dict, Generator, Generic, Iterator, List, Literal,
|
||||
NamedTuple, Optional, Tuple, Type, TypeVar, Union,
|
||||
overload)
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
@@ -39,6 +42,8 @@ import psutil
|
||||
import torch
|
||||
import torch.types
|
||||
import yaml
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from packaging.version import Version
|
||||
from torch.library import Library
|
||||
from typing_extensions import ParamSpec, TypeIs, assert_never
|
||||
@@ -1844,7 +1849,7 @@ def memory_profiling(
|
||||
result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501Curre
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
resource_type = resource.RLIMIT_NOFILE
|
||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||
@@ -1859,3 +1864,82 @@ def set_ulimit(target_soft_limit=65535):
|
||||
"with error %s. This can cause fd limit errors like"
|
||||
"`OSError: [Errno 24] Too many open files`. Consider "
|
||||
"increasing with ulimit -n", current_soft, e)
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
|
||||
def get_exception_traceback():
|
||||
etype, value, tb = sys.exc_info()
|
||||
err_str = "".join(traceback.format_exception(etype, value, tb))
|
||||
return err_str
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
|
||||
def make_zmq_socket(
|
||||
ctx: Union[zmq.asyncio.Context, zmq.Context], # type: ignore[name-defined]
|
||||
path: str,
|
||||
type: Any,
|
||||
) -> Union[zmq.Socket, zmq.asyncio.Socket]: # type: ignore[name-defined]
|
||||
"""Make a ZMQ socket with the proper bind/connect semantics."""
|
||||
|
||||
mem = psutil.virtual_memory()
|
||||
socket = ctx.socket(type)
|
||||
|
||||
# Calculate buffer size based on system memory
|
||||
total_mem = mem.total / 1024**3
|
||||
available_mem = mem.available / 1024**3
|
||||
# For systems with substantial memory (>32GB total, >16GB available):
|
||||
# - Set a large 0.5GB buffer to improve throughput
|
||||
# For systems with less memory:
|
||||
# - Use system default (-1) to avoid excessive memory consumption
|
||||
if total_mem > 32 and available_mem > 16:
|
||||
buf_size = int(0.5 * 1024**3) # 0.5GB in bytes
|
||||
else:
|
||||
buf_size = -1 # Use system default buffer size
|
||||
|
||||
if type == zmq.constants.PULL:
|
||||
socket.setsockopt(zmq.constants.RCVHWM, 0)
|
||||
socket.setsockopt(zmq.constants.RCVBUF, buf_size)
|
||||
socket.connect(path)
|
||||
elif type == zmq.constants.PUSH:
|
||||
socket.setsockopt(zmq.constants.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.constants.SNDBUF, buf_size)
|
||||
socket.bind(path)
|
||||
else:
|
||||
raise ValueError(f"Unknown Socket Type: {type}")
|
||||
|
||||
return socket
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_socket_ctx(
|
||||
path: str,
|
||||
type: Any) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx = zmq.Context(io_threads=2) # type: ignore[attr-defined]
|
||||
try:
|
||||
yield make_zmq_socket(ctx, path, type)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Got Keyboard Interrupt.")
|
||||
|
||||
finally:
|
||||
ctx.destroy(linger=0)
|
||||
|
||||
|
||||
def _check_multiproc_method():
|
||||
if (cuda_is_initialized()
|
||||
and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
|
||||
logger.warning("CUDA was previously initialized. We must use "
|
||||
"the `spawn` multiprocessing start method. Setting "
|
||||
"VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
|
||||
"See https://docs.vllm.ai/en/latest/getting_started/"
|
||||
"debugging.html#python-multiprocessing "
|
||||
"for more information.")
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
def get_mp_context():
|
||||
_check_multiproc_method()
|
||||
mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
|
||||
return multiprocessing.get_context(mp_method)
|
||||
|
||||
Reference in New Issue
Block a user