[Bugfix] Do not crash V0 engine on input errors (#13101)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.engine.multiprocessing.engine import MQLLMEngine
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.sequence import SequenceGroupMetadata
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@@ -292,3 +293,80 @@ async def test_engine_process_death(tmp_socket):
|
||||
await client.check_health()
|
||||
|
||||
client.close()
|
||||
|
||||
|
||||
def run_with_evil_input_processing(engine_args: AsyncEngineArgs,
|
||||
ipc_path: str):
|
||||
"""Simulate an exception while preparing inputs for the model.
|
||||
In the wild, this could be something like a multimodal input processor
|
||||
failing on invalid image data."""
|
||||
|
||||
# Make engine.
|
||||
engine = MQLLMEngine.from_engine_args(
|
||||
engine_args=engine_args,
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT,
|
||||
ipc_path=ipc_path)
|
||||
|
||||
runner = engine.engine.model_executor.driver_worker.worker.model_runner
|
||||
|
||||
# Raise error in the model runner when adding a sequence group.
|
||||
# See class ModelInputForGPUBuilder
|
||||
def raiser(_, seq_group_metadata: SequenceGroupMetadata):
|
||||
if seq_group_metadata.request_id.startswith("evil"):
|
||||
raise RAISED_ERROR(RAISED_VALUE)
|
||||
|
||||
runner.builder.per_seq_group_compute_fns.append(raiser)
|
||||
|
||||
# Run engine.
|
||||
engine.start()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_inputs(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket,
|
||||
run_fn=run_with_evil_input_processing) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
assert client.is_running
|
||||
|
||||
# Engine should be healthy
|
||||
await client.check_health()
|
||||
|
||||
async def run_failing_request():
|
||||
async for _ in client.generate(
|
||||
prompt="Hello my name is",
|
||||
sampling_params=SamplingParams(max_tokens=10),
|
||||
request_id="evil" + str(uuid.uuid4())):
|
||||
pass
|
||||
|
||||
async def run_passing_request():
|
||||
async for _ in client.generate(
|
||||
prompt="Hello my name is",
|
||||
sampling_params=SamplingParams(max_tokens=10),
|
||||
request_id=str(uuid.uuid4())):
|
||||
pass
|
||||
|
||||
passing_tasks = [
|
||||
asyncio.create_task(run_passing_request()) for _ in range(10)
|
||||
]
|
||||
failing_tasks = [
|
||||
asyncio.create_task(run_failing_request()) for _ in range(10)
|
||||
]
|
||||
await asyncio.gather(*failing_tasks, return_exceptions=True)
|
||||
await asyncio.gather(*passing_tasks)
|
||||
|
||||
# All the bad inputs should have raised
|
||||
for task in failing_tasks:
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
task.result()
|
||||
|
||||
# But all good inputs should have still succeeded
|
||||
for task in passing_tasks:
|
||||
task.result()
|
||||
|
||||
# And the engine should remain healthy
|
||||
assert not client.errored
|
||||
await client.check_health()
|
||||
|
||||
client.close()
|
||||
|
||||
Reference in New Issue
Block a user