Revert "rename PromptInputs and inputs with backward compatibility (#8760) (#8810)

This commit is contained in:
Simon Mo
2024-09-25 10:36:26 -07:00
committed by GitHub
parent 873edda6cf
commit 4f1ba0844b
21 changed files with 245 additions and 438 deletions

View File

@@ -61,7 +61,7 @@ async def test_evil_forward(tmp_socket):
# Throws an error in first forward pass.
with pytest.raises(RAISED_ERROR):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -69,7 +69,7 @@ async def test_evil_forward(tmp_socket):
# Engine is errored, should get ENGINE_DEAD_ERROR.
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -118,7 +118,7 @@ async def test_failed_health_check(tmp_socket):
# Generate call should throw ENGINE_DEAD_ERROR
with pytest.raises(MQEngineDeadError):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id=uuid.uuid4()):
pass
@@ -160,7 +160,7 @@ async def test_failed_abort(tmp_socket):
# with reference to the original KeyError("foo")
with pytest.raises(MQEngineDeadError) as execinfo:
async for _ in client.generate(
prompt="Hello my name is",
inputs="Hello my name is",
sampling_params=SamplingParams(max_tokens=10),
request_id=uuid.uuid4()):
pass
@@ -183,7 +183,7 @@ async def test_bad_request(tmp_socket):
# Invalid request should fail, but not crash the server.
with pytest.raises(ValueError):
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-1",
lora_request=LoRARequest(
@@ -192,7 +192,7 @@ async def test_bad_request(tmp_socket):
pass
# This request should be okay.
async for _ in client.generate(prompt="Hello my name is",
async for _ in client.generate(inputs="Hello my name is",
sampling_params=SamplingParams(),
request_id="abcd-2"):
pass