[Benchmark] Add --async-engine option to benchmark_throughput.py (#7964)
This commit is contained in:
@@ -67,7 +67,7 @@ _running_tasks: Set[asyncio.Task] = set()
|
||||
|
||||
|
||||
def model_is_embedding(model_name: str, trust_remote_code: bool,
|
||||
quantization: str) -> bool:
|
||||
quantization: Optional[str]) -> bool:
|
||||
return ModelConfig(model=model_name,
|
||||
tokenizer=model_name,
|
||||
tokenizer_mode="auto",
|
||||
@@ -96,13 +96,6 @@ async def lifespan(app: FastAPI):
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client(
|
||||
args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
|
||||
"""
|
||||
Create AsyncEngineClient, either:
|
||||
- in-process using the AsyncLLMEngine Directly
|
||||
- multiprocess using AsyncLLMEngine RPC
|
||||
|
||||
Returns the Client or None if the creation failed.
|
||||
"""
|
||||
|
||||
# Context manager to handle async_engine_client lifecycle
|
||||
# Ensures everything is shutdown and cleaned up on error/exit
|
||||
@@ -112,14 +105,37 @@ async def build_async_engine_client(
|
||||
# Backend itself still global for the silly lil' health handler
|
||||
global async_engine_client
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, args.disable_frontend_multiprocessing) as engine:
|
||||
|
||||
async_engine_client = engine # type: ignore[assignment]
|
||||
yield engine
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def build_async_engine_client_from_engine_args(
|
||||
engine_args: AsyncEngineArgs,
|
||||
disable_frontend_multiprocessing: bool = False,
|
||||
) -> AsyncIterator[Optional[AsyncEngineClient]]:
|
||||
"""
|
||||
Create AsyncEngineClient, either:
|
||||
- in-process using the AsyncLLMEngine Directly
|
||||
- multiprocess using AsyncLLMEngine RPC
|
||||
|
||||
Returns the Client or None if the creation failed.
|
||||
"""
|
||||
|
||||
# If manually triggered or embedding model, use AsyncLLMEngine in process.
|
||||
# TODO: support embedding model via RPC.
|
||||
if (model_is_embedding(args.model, args.trust_remote_code,
|
||||
args.quantization)
|
||||
or args.disable_frontend_multiprocessing):
|
||||
async_engine_client = AsyncLLMEngine.from_engine_args(
|
||||
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
|
||||
engine_args.quantization)
|
||||
or disable_frontend_multiprocessing):
|
||||
engine_client = AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
|
||||
yield async_engine_client
|
||||
try:
|
||||
yield engine_client
|
||||
finally:
|
||||
engine_client.shutdown_background_loop()
|
||||
return
|
||||
|
||||
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||
@@ -148,7 +164,6 @@ async def build_async_engine_client(
|
||||
# NOTE: Actually, this is not true yet. We still need to support
|
||||
# embedding models via RPC (see TODO above)
|
||||
rpc_client = AsyncEngineRPCClient(rpc_path)
|
||||
async_engine_client = rpc_client # type: ignore
|
||||
|
||||
# Start RPCServer in separate process (holds the AsyncLLMEngine).
|
||||
context = multiprocessing.get_context("spawn")
|
||||
@@ -174,7 +189,7 @@ async def build_async_engine_client(
|
||||
yield None
|
||||
return
|
||||
|
||||
yield async_engine_client
|
||||
yield rpc_client # type: ignore[misc]
|
||||
finally:
|
||||
# Ensure rpc server process was terminated
|
||||
rpc_server_process.terminate()
|
||||
|
||||
@@ -7,6 +7,7 @@ from uuid import uuid4
|
||||
import cloudpickle
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
from zmq import Frame # type: ignore[attr-defined]
|
||||
from zmq.asyncio import Socket
|
||||
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
@@ -214,6 +215,7 @@ class AsyncEngineRPCClient:
|
||||
|
||||
# Await the data from the Server.
|
||||
frame = await socket.recv(copy=False)
|
||||
assert isinstance(frame, Frame)
|
||||
data = pickle.loads(frame.buffer)
|
||||
|
||||
if isinstance(data, Exception):
|
||||
@@ -247,6 +249,7 @@ class AsyncEngineRPCClient:
|
||||
f"{self._data_timeout} ms")
|
||||
|
||||
frame = await socket.recv(copy=False)
|
||||
assert isinstance(frame, Frame)
|
||||
return pickle.loads(frame.buffer)
|
||||
|
||||
# Make a new socket connection.
|
||||
@@ -395,6 +398,7 @@ class AsyncEngineRPCClient:
|
||||
# Stream back the results from the RPC Server.
|
||||
while not finished:
|
||||
message = await socket.recv(copy=False)
|
||||
assert isinstance(message, Frame)
|
||||
request_output = pickle.loads(message.buffer)
|
||||
|
||||
if isinstance(request_output, Exception):
|
||||
|
||||
Reference in New Issue
Block a user