[Core][Bugfix][Perf] Introduce MQLLMEngine to avoid asyncio OH (#8157)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Alexander Matveev
2024-09-18 09:56:58 -04:00
committed by GitHub
parent 9d104b5beb
commit 7c7714d856
36 changed files with 1464 additions and 1169 deletions

View File

@@ -26,7 +26,9 @@ import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.engine.multiprocessing.client import MQLLMEngineClient
from vllm.engine.multiprocessing.engine import run_mp_engine
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
@@ -44,8 +46,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
TokenizeRequest,
TokenizeResponse,
UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
@@ -67,29 +67,16 @@ logger = init_logger('vllm.entrypoints.openai.api_server')
_running_tasks: Set[asyncio.Task] = set()
def model_is_embedding(model_name: str, trust_remote_code: bool,
quantization: Optional[str],
revision: Optional[str]) -> bool:
return ModelConfig(model=model_name,
revision=revision,
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=trust_remote_code,
quantization=quantization,
seed=0,
dtype="auto").embedding_mode
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
async_engine_client = app.state.engine_client
engine_client: EngineClient = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
await asyncio.sleep(10.)
await engine_client.do_log_stats()
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
@@ -108,9 +95,9 @@ async def lifespan(app: FastAPI):
@asynccontextmanager
async def build_async_engine_client(
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
args: Namespace) -> AsyncIterator[Optional[EngineClient]]:
# Context manager to handle async_engine_client lifecycle
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
@@ -123,19 +110,18 @@ async def build_async_engine_client(
async def build_async_engine_client_from_engine_args(
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
) -> AsyncIterator[Optional[EngineClient]]:
"""
Create AsyncEngineClient, either:
Create EngineClient, either:
- in-process using the AsyncLLMEngine Directly
- multiprocess using AsyncLLMEngine RPC
Returns the Client or None if the creation failed.
"""
# If manually triggered or embedding model, use AsyncLLMEngine in process.
# TODO: support embedding model via RPC.
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
engine_args.quantization, engine_args.revision)
# Fall back
# TODO: fill out feature matrix.
if (MQLLMEngineClient.is_unsupported_config(engine_args)
or disable_frontend_multiprocessing):
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
@@ -173,56 +159,60 @@ async def build_async_engine_client_from_engine_args(
"and vLLM will properly handle cleanup.")
# Select random path for IPC.
rpc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for RPC Path.",
rpc_path)
ipc_path = get_open_zmq_ipc_path()
logger.info("Multiprocessing frontend to use %s for IPC Path.",
ipc_path)
# Build RPCClient, which conforms to AsyncEngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
rpc_client = AsyncEngineRPCClient(rpc_path)
# Start RPCServer in separate process (holds the AsyncLLMEngine).
context = multiprocessing.get_context("spawn")
# Start RPCServer in separate process (holds the LLMEngine).
# the current process might have CUDA context,
# so we need to spawn a new process
rpc_server_process = context.Process(
target=run_rpc_server,
args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
rpc_server_process.start()
logger.info("Started engine process with PID %d",
rpc_server_process.pid)
context = multiprocessing.get_context("spawn")
engine_process = context.Process(target=run_mp_engine,
args=(engine_args,
UsageContext.OPENAI_API_SERVER,
ipc_path))
engine_process.start()
logger.info("Started engine process with PID %d", engine_process.pid)
# Build RPCClient, which conforms to EngineClient Protocol.
# NOTE: Actually, this is not true yet. We still need to support
# embedding models via RPC (see TODO above)
engine_config = engine_args.create_engine_config()
mp_engine_client = MQLLMEngineClient(ipc_path, engine_config)
try:
while True:
try:
await rpc_client.setup()
await mp_engine_client.setup()
break
except TimeoutError:
if not rpc_server_process.is_alive():
logger.error(
"RPCServer process died before responding "
"to readiness probe")
if not engine_process.is_alive():
logger.error("Engine process died before responding "
"to readiness probe")
yield None
return
yield rpc_client # type: ignore[misc]
yield mp_engine_client # type: ignore[misc]
finally:
# Ensure rpc server process was terminated
rpc_server_process.terminate()
engine_process.terminate()
# Close all open connections to the backend
rpc_client.close()
mp_engine_client.close()
# Wait for server process to join
rpc_server_process.join()
# Wait for engine process to join
engine_process.join(4)
if engine_process.exitcode is None:
# Kill if taking longer than 5 seconds to stop
engine_process.kill()
# Lazy import for prometheus multiprocessing.
# We need to set PROMETHEUS_MULTIPROC_DIR environment variable
# before prometheus_client is imported.
# See https://prometheus.github.io/client_python/multiprocess/
from prometheus_client import multiprocess
multiprocess.mark_process_dead(rpc_server_process.pid)
multiprocess.mark_process_dead(engine_process.pid)
router = APIRouter()
@@ -270,7 +260,7 @@ def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def engine_client(request: Request) -> AsyncEngineClient:
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@@ -473,7 +463,7 @@ def build_app(args: Namespace) -> FastAPI:
def init_app_state(
async_engine_client: AsyncEngineClient,
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
@@ -488,11 +478,11 @@ def init_app_state(
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
state.engine_client = async_engine_client
state.engine_client = engine_client
state.log_stats = not args.disable_log_stats
state.openai_serving_chat = OpenAIServingChat(
async_engine_client,
engine_client,
model_config,
served_model_names,
args.response_role,
@@ -504,7 +494,7 @@ def init_app_state(
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
state.openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@@ -513,13 +503,13 @@ def init_app_state(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
state.openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client,
engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
state.openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client,
engine_client,
model_config,
served_model_names,
lora_modules=args.lora_modules,
@@ -541,21 +531,20 @@ async def run_server(args, **uvicorn_kwargs) -> None:
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as async_engine_client:
async with build_async_engine_client(args) as engine_client:
# If None, creation of the client failed and we exit.
if async_engine_client is None:
if engine_client is None:
return
app = build_app(args)
model_config = await async_engine_client.get_model_config()
init_app_state(async_engine_client, model_config, app.state, args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
temp_socket.close()
shutdown_task = await serve_http(
app,
limit_concurrency=async_engine_client.limit_concurrency,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,