[Core][Bugfix][Perf] Introduce MQLLMEngine to avoid asyncio OH (#8157)

Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Alexander Matveev
2024-09-18 09:56:58 -04:00
committed by GitHub
parent 9d104b5beb
commit 7c7714d856
36 changed files with 1464 additions and 1169 deletions

View File

@@ -0,0 +1,67 @@
"""Test that aborting is handled properly."""
import asyncio
import tempfile
import uuid
import pytest
from tests.mq_llm_engine.utils import RemoteMQLLMEngine, generate
from vllm.engine.arg_utils import AsyncEngineArgs
# Model identifier used to build the engine arguments below.
MODEL = "google/gemma-1.1-2b-it"
# Engine configuration handed to RemoteMQLLMEngine in test_abort.
ENGINE_ARGS = AsyncEngineArgs(model=MODEL)
# NOTE(review): RAISED_ERROR / RAISED_VALUE are not referenced by any code
# visible in this module -- presumably copied from a sibling test; confirm
# before relying on or removing them.
RAISED_ERROR = KeyError
RAISED_VALUE = "foo"
# Passed to generate() for every request and compared against the count that
# generate() reports back for the non-aborted requests.
EXPECTED_TOKENS = 250
@pytest.fixture(scope="function")
def tmp_socket():
    """Yield a unique ``ipc://`` address located in a temporary directory.

    The directory is created per test and removed again when the
    ``TemporaryDirectory`` context exits after the test has finished.
    """
    with tempfile.TemporaryDirectory() as socket_dir:
        socket_name = uuid.uuid4()
        yield f"ipc://{socket_dir}/{socket_name}"
@pytest.mark.asyncio
async def test_abort(tmp_socket):
    """Abort one in-flight request and check the others still complete.

    Ten requests are started, then the request that will be aborted, then
    ten more. After a short delay the middle request is aborted through the
    client. Every other request must report ``EXPECTED_TOKENS`` tokens.

    Args:
        tmp_socket: IPC path supplied by the ``tmp_socket`` fixture and
            passed to ``RemoteMQLLMEngine`` as ``ipc_path``.
    """
    with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
                           ipc_path=tmp_socket) as engine:
        client = await engine.make_client()

        request_id_to_be_aborted = "request-aborted"
        request_ids_a = [f"request-a-{idx}" for idx in range(10)]
        request_ids_b = [f"request-b-{idx}" for idx in range(10)]

        tasks = []
        task_aborted = None
        try:
            # Requests started before one to be aborted.
            tasks += [
                asyncio.create_task(
                    generate(client, request_id, EXPECTED_TOKENS))
                for request_id in request_ids_a
            ]

            # Aborted.
            task_aborted = asyncio.create_task(
                generate(client, request_id_to_be_aborted, EXPECTED_TOKENS))

            # Requests started after one to be aborted.
            tasks += [
                asyncio.create_task(
                    generate(client, request_id, EXPECTED_TOKENS))
                for request_id in request_ids_b
            ]

            # Actually abort, after giving the requests time to start.
            await asyncio.sleep(0.5)
            await client.abort(request_id_to_be_aborted)

            # Confirm that we got all the EXPECTED tokens from the requests.
            for task in tasks:
                count, request_id = await task
                assert count == EXPECTED_TOKENS, (
                    f"{request_id} generated only {count} tokens")
        finally:
            # Cleanup must run even when an assertion above fails: the
            # aborted request's task would otherwise hang indefinitely and
            # the client would never be closed.
            if task_aborted is not None:
                task_aborted.cancel()
            # cancel() is a no-op for tasks that already finished, so this
            # only affects requests still pending after a failure.
            for task in tasks:
                task.cancel()

            # Shutdown.
            client.close()