[V1] [4/N] API Server: ZMQ/MP Utilities (#11541)

This commit is contained in:
Robert Shaw
2024-12-27 20:45:08 -05:00
committed by GitHub
parent a60731247f
commit df04dffade
12 changed files with 242 additions and 210 deletions


@@ -10,6 +10,7 @@ import importlib.metadata
import importlib.util
import inspect
import ipaddress
import multiprocessing
import os
import re
import resource
@@ -20,6 +21,7 @@ import sys
import tempfile
import threading
import time
import traceback
import uuid
import warnings
import weakref
@@ -29,8 +31,9 @@ from collections.abc import Hashable, Iterable, Mapping
from dataclasses import dataclass, field
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
                    Dict, Generator, Generic, List, Literal, NamedTuple,
                    Optional, Tuple, Type, TypeVar, Union, overload)
                    Dict, Generator, Generic, Iterator, List, Literal,
                    NamedTuple, Optional, Tuple, Type, TypeVar, Union,
                    overload)
from uuid import uuid4
import numpy as np
@@ -39,6 +42,8 @@ import psutil
import torch
import torch.types
import yaml
import zmq
import zmq.asyncio
from packaging.version import Version
from torch.library import Library
from typing_extensions import ParamSpec, TypeIs, assert_never
@@ -1844,7 +1849,7 @@ def memory_profiling(
    result.non_kv_cache_memory_in_bytes = result.non_torch_increase_in_bytes + result.torch_peak_increase_in_bytes + result.weights_memory_in_bytes # noqa
# Adapted from: https://github.com/sgl-project/sglang/blob/f46f394f4d4dbe4aae85403dec006199b34d2840/python/sglang/srt/utils.py#L630 # noqa: E501
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L630 # noqa: E501
def set_ulimit(target_soft_limit=65535):
    resource_type = resource.RLIMIT_NOFILE
    current_soft, current_hard = resource.getrlimit(resource_type)
@@ -1859,3 +1864,82 @@ def set_ulimit(target_soft_limit=65535):
"with error %s. This can cause fd limit errors like"
"`OSError: [Errno 24] Too many open files`. Consider "
"increasing with ulimit -n", current_soft, e)
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/utils.py#L28 # noqa: E501
def get_exception_traceback():
    etype, value, tb = sys.exc_info()
    err_str = "".join(traceback.format_exception(etype, value, tb))
    return err_str
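Not part of the diff: a sketch of the intended call pattern. The helper reads sys.exc_info(), so it is only meaningful inside an except block; engine_loop and step are made-up names.

import logging

def engine_loop(step):
    try:
        while True:
            step()
    except Exception:
        # Capture the full formatted traceback so the parent process can log
        # it, instead of losing it when this child exits.
        logging.getLogger(__name__).error(
            "Engine loop died:\n%s", get_exception_traceback())
        raise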
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
def make_zmq_socket(
    ctx: Union[zmq.asyncio.Context, zmq.Context],  # type: ignore[name-defined]
    path: str,
    type: Any,
) -> Union[zmq.Socket, zmq.asyncio.Socket]:  # type: ignore[name-defined]
    """Make a ZMQ socket with the proper bind/connect semantics."""

    mem = psutil.virtual_memory()
    socket = ctx.socket(type)

    # Calculate buffer size based on system memory
    total_mem = mem.total / 1024**3
    available_mem = mem.available / 1024**3
    # For systems with substantial memory (>32GB total, >16GB available):
    # - Set a large 0.5GB buffer to improve throughput
    # For systems with less memory:
    # - Use system default (-1) to avoid excessive memory consumption
    if total_mem > 32 and available_mem > 16:
        buf_size = int(0.5 * 1024**3)  # 0.5GB in bytes
    else:
        buf_size = -1  # Use system default buffer size

    if type == zmq.constants.PULL:
        socket.setsockopt(zmq.constants.RCVHWM, 0)
        socket.setsockopt(zmq.constants.RCVBUF, buf_size)
        socket.connect(path)
    elif type == zmq.constants.PUSH:
        socket.setsockopt(zmq.constants.SNDHWM, 0)
        socket.setsockopt(zmq.constants.SNDBUF, buf_size)
        socket.bind(path)
    else:
        raise ValueError(f"Unknown Socket Type: {type}")

    return socket
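Not part of the diff: a round-trip sketch of the bind/connect convention above (PUSH binds the endpoint, PULL connects to it). The ipc endpoint is a made-up example path.

import zmq

ctx = zmq.Context(io_threads=2)
path = "ipc:///tmp/example_zmq_demo.sock"  # hypothetical endpoint

push = make_zmq_socket(ctx, path, zmq.constants.PUSH)  # binds
pull = make_zmq_socket(ctx, path, zmq.constants.PULL)  # connects

push.send(b"hello")
assert pull.recv() == b"hello"
ctx.destroy(linger=0)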
@contextlib.contextmanager
def zmq_socket_ctx(
        path: str,
        type: Any) -> Iterator[zmq.Socket]:  # type: ignore[name-defined]
    """Context manager for a ZMQ socket"""

    ctx = zmq.Context(io_threads=2)  # type: ignore[attr-defined]
    try:
        yield make_zmq_socket(ctx, path, type)

    except KeyboardInterrupt:
        logger.debug("Got Keyboard Interrupt.")

    finally:
        ctx.destroy(linger=0)
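Not part of the diff: a sketch of the context-manager form. The zmq.Context it owns is destroyed with linger=0 on exit, so the caller only touches the socket; the endpoint and poll timeout are illustrative.

import zmq

with zmq_socket_ctx("ipc:///tmp/example_input.sock", zmq.constants.PULL) as sock:
    # Poll with a timeout so this sketch returns even if no PUSH peer exists.
    if sock.poll(timeout=100):
        frames = sock.recv_multipart()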
def _check_multiproc_method():
    if (cuda_is_initialized()
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
        logger.warning("CUDA was previously initialized. We must use "
                       "the `spawn` multiprocessing start method. Setting "
                       "VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. "
                       "See https://docs.vllm.ai/en/latest/getting_started/"
                       "debugging.html#python-multiprocessing "
                       "for more information.")
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

def get_mp_context():
    _check_multiproc_method()
    mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD
    return multiprocessing.get_context(mp_method)
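Not part of the diff: a sketch of launching a worker through the returned context, so the CUDA/spawn check above runs before the process starts; the worker body is a placeholder.

def _dummy_worker(ready_event):
    # Placeholder body; the real server would run an engine core loop here.
    ready_event.set()

if __name__ == "__main__":
    ctx = get_mp_context()  # honors VLLM_WORKER_MULTIPROC_METHOD
    ready = ctx.Event()
    proc = ctx.Process(target=_dummy_worker, args=(ready,), daemon=True)
    proc.start()
    ready.wait(timeout=10)
    proc.join()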