Push logprob generation to LLMEngine (#3065)

Co-authored-by: Avnish Narayan <avnish@anyscale.com>
This commit is contained in:
Antoni Baum
2024-03-04 11:54:06 -08:00
committed by GitHub
parent 76e8a70476
commit 22de45235c
13 changed files with 551 additions and 331 deletions

View File

@@ -213,14 +213,14 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10)
top_logprobs=5)
assert chat_completion.id is not None
assert chat_completion.choices is not None and len(
chat_completion.choices) == 1
assert chat_completion.choices[0].message is not None
assert chat_completion.choices[0].logprobs is not None
assert chat_completion.choices[0].logprobs.top_logprobs is not None
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 10
assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 10
assert message.role == "assistant"
@@ -229,7 +229,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
# test multi-turn dialogue
messages.append({"role": "user", "content": "express your result in json"})
chat_completion = await client.chat.completions.create(
model=MODEL_NAME,
model=model_name,
messages=messages,
max_tokens=10,
)
@@ -237,6 +237,61 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
model_name: str):
messages = [{
"role": "system",
"content": "you are a helpful assistant"
}, {
"role": "user",
"content": "what is 1+1?"
}]
# Default max_logprobs is 5, so this should raise an error
with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
stream=True)
async for chunk in stream:
...
with pytest.raises(openai.BadRequestError):
await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
logprobs=True,
top_logprobs=10,
stream=False)
with pytest.raises((openai.BadRequestError, openai.APIError)):
stream = await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=True)
async for chunk in stream:
...
with pytest.raises(openai.BadRequestError):
await client.completions.create(model=model_name,
prompt="Test",
max_tokens=10,
logprobs=10,
stream=False)
# the server should still work afterwards
chat_completion = await client.chat.completions.create(model=model_name,
messages=messages,
max_tokens=10,
stream=False)
message = chat_completion.choices[0].message
assert message.content is not None and len(message.content) >= 0
@pytest.mark.parametrize(
# just test 1 lora hereafter
"model_name",