[Frontend] Don't log duplicate error stacktrace for every request in the batch (#9023)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
@@ -59,15 +59,7 @@ async def test_evil_forward(tmp_socket):
|
||||
await asyncio.sleep(2.0)
|
||||
await client.check_health()
|
||||
|
||||
# Throws an error in first forward pass.
|
||||
with pytest.raises(RAISED_ERROR):
|
||||
async for _ in client.generate(prompt="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
assert client.errored
|
||||
|
||||
# Engine is errored, should get ENGINE_DEAD_ERROR.
|
||||
# Throws an error that should get ENGINE_DEAD_ERROR.
|
||||
with pytest.raises(MQEngineDeadError):
|
||||
async for _ in client.generate(prompt="Hello my name is",
|
||||
sampling_params=SamplingParams(),
|
||||
@@ -149,7 +141,7 @@ async def test_failed_abort(tmp_socket):
|
||||
client = await engine.make_client()
|
||||
assert client.is_running
|
||||
|
||||
# Firsh check health should work.
|
||||
# First check health should work.
|
||||
await client.check_health()
|
||||
|
||||
# Trigger an abort on the client side.
|
||||
@@ -174,6 +166,45 @@ async def test_failed_abort(tmp_socket):
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_error(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
ipc_path=tmp_socket,
|
||||
run_fn=run_with_evil_abort) as engine:
|
||||
|
||||
client = await engine.make_client()
|
||||
assert client.is_running
|
||||
|
||||
# First check health should work.
|
||||
await client.check_health()
|
||||
|
||||
# Batch of requests
|
||||
async def do_generate(client):
|
||||
# min_tokens=2048 to keep the engine busy
|
||||
# long enough for it to process a request
|
||||
# that will crash the engine
|
||||
params = SamplingParams(min_tokens=2048, max_tokens=2048)
|
||||
async for _ in client.generate(prompt="Hello my name is",
|
||||
sampling_params=params,
|
||||
request_id=uuid.uuid4()):
|
||||
pass
|
||||
|
||||
tasks = [asyncio.create_task(do_generate(client)) for _ in range(10)]
|
||||
|
||||
# This request will force a processing batch to raise
|
||||
# an exception, after which the engine becomes errored
|
||||
await client.abort(request_id="foo")
|
||||
|
||||
# The batch containing those requests failed, so they
|
||||
# should all get the same exception, as an MQEngineDeadError.
|
||||
errors = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
for e in errors:
|
||||
assert isinstance(e, MQEngineDeadError)
|
||||
assert "KeyError" in repr(e)
|
||||
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_request(tmp_socket):
|
||||
with RemoteMQLLMEngine(engine_args=ENGINE_ARGS,
|
||||
|
||||
Reference in New Issue
Block a user