[Bugfix] Do not crash V0 engine on input errors (#13101)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2025-02-26 04:07:29 -07:00
committed by GitHub
parent ec8a5e5386
commit 3f808cc044
5 changed files with 172 additions and 6 deletions

View File

@@ -18,6 +18,7 @@ from vllm.engine.multiprocessing.engine import MQLLMEngine
from vllm.entrypoints.openai.api_server import build_async_engine_client
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroupMetadata
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser
@@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
await client.check_health()
client.close()
def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
ipc_path: str):
"""Simulate an exception while preparing inputs for the model.
In the wild, this could be something like a multimodal input processor
failing on invalid image data."""
# Make engine.
engine = MQLLMEngine.from_engine_args(
engine_args=engine_args,
usage_context=UsageContext.UNKNOWN_CONTEXT,
ipc_path=ipc_path)
runner = engine.engine.model_executor.driver_worker.worker.model_runner
# Raise error in the model runner when adding a sequence group.
# See class ModelInputForGPUBuilder
def raiser(_, seq_group_metadata: SequenceGroupMetadata):
if seq_group_metadata.request_id.startswith("evil"):
raise RAISED_ERROR(RAISED_VALUE)
runner.builder.per_seq_group_compute_fns.append(raiser)
# Run engine.
engine.start()
@pytest.mark.asyncio
async def test_failed_inputs(tmp_socket):
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
ipc_path=tmp_socket,
run_fn=run_with_evil_input_processing) as engine:
client = await engine.make_client()
assert client.is_running
# Engine should be healthy
await client.check_health()
async def run_failing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id="evil" + str(uuid.uuid4())):
pass
async def run_passing_request():
async for _ in client.generate(
prompt="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=str(uuid.uuid4())):
pass
passing_tasks = [
asyncio.create_task(run_passing_request()) for _ in range(10)
]
failing_tasks = [
asyncio.create_task(run_failing_request()) for _ in range(10)
]
await asyncio.gather(*failing_tasks, return_exceptions=True)
await asyncio.gather(*passing_tasks)
# All the bad inputs should have raised
for task in failing_tasks:
with pytest.raises(RAISED_ERROR):
task.result()
# But all good inputs should have still succeeded
for task in passing_tasks:
task.result()
# And the engine should remain healthy
assert not client.errored
await client.check_health()
client.close()