[BUGFIX] [FRONTEND] Correct chat logprobs (#5029)
Co-authored-by: Breno Faria <breno.faria@intrafind.com>
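The updated tests below encode the fix: chat logprobs now follow the OpenAI chat schema, with one typed entry per generated token under choices[0].logprobs.content, rather than the completions-style top_logprobs lists. A minimal client-side sketch of reading that shape, assuming an OpenAI-compatible server at a local URL; the base_url, API key, and model name are placeholders, not values from this commit:

    # Minimal sketch, assuming a local OpenAI-compatible server; the URL,
    # API key, and model name are placeholders.
    import asyncio

    import openai

    async def show_chat_logprobs() -> None:
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="EMPTY")
        chat = await client.chat.completions.create(
            model="placeholder-model",
            messages=[{"role": "user", "content": "what is 1+1?"}],
            max_tokens=5,
            logprobs=True,
            top_logprobs=5,
        )
        # Chat-schema logprobs: one entry per generated token, each carrying
        # its own list of top alternatives.
        for entry in chat.choices[0].logprobs.content:
            print(entry.token, entry.logprob)
            for alt in entry.top_logprobs:
                print("    alt:", alt.token, alt.logprob)

    asyncio.run(show_chat_logprobs())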
@@ -184,6 +184,26 @@ async def test_single_completion(server, client: openai.AsyncOpenAI,
         completion.choices[0].text) >= 5


+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs(server, client: openai.AsyncOpenAI,
+                           model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=None,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is None
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize(
     # first test base model, then test loras
@@ -203,7 +223,72 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
     choice = completion.choices[0]
     assert choice.logprobs is not None
     assert choice.logprobs.token_logprobs is not None
-    assert choice.logprobs.top_logprobs is None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs(server, client: openai.AsyncOpenAI,
+                             model_name: str):
+    # test using token IDs
+    completion = await client.completions.create(
+        model=MODEL_NAME,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+        logprobs=5,
+    )
+    choice = completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.token_logprobs is not None
+    assert choice.logprobs.top_logprobs is not None
+    assert len(choice.logprobs.top_logprobs[0]) <= 6
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_too_many_completion_logprobs(server, client: openai.AsyncOpenAI,
+                                            model_name: str):
+
+    with pytest.raises(
+            (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+        )
+        ...
+    with pytest.raises(
+            (openai.BadRequestError, openai.APIError)):  # test using token IDs
+        stream = await client.completions.create(
+            model=MODEL_NAME,
+            prompt=[0, 0, 0, 0, 0],
+            max_tokens=5,
+            temperature=0.0,
+            logprobs=6,
+            stream=True,
+        )
+        async for chunk in stream:
+            ...
+
+    # the server should still work afterwards
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    completion = completion.choices[0].text
+    assert completion is not None and len(completion) >= 0
+
+
 @pytest.mark.asyncio
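For contrast, the completions endpoint tested above keeps the legacy text-completion logprobs shape: parallel tokens/token_logprobs lists plus one token-to-logprob dict per position. A sketch of what test_some_logprobs asserts, with the same placeholder server and model assumptions as the earlier example; the tests allow up to logprobs + 1 entries per dict, presumably because the sampled token can be returned in addition to the requested alternatives:

    # Sketch only; client setup as in the earlier example.
    async def show_completion_logprobs(client: openai.AsyncOpenAI) -> None:
        completion = await client.completions.create(
            model="placeholder-model",
            prompt=[0, 0, 0, 0, 0],  # prompts may be passed as token IDs
            max_tokens=5,
            temperature=0.0,
            logprobs=5,
        )
        lp = completion.choices[0].logprobs
        print(lp.tokens)           # the sampled tokens
        print(lp.token_logprobs)   # their logprobs, parallel to lp.tokens
        print(lp.top_logprobs[0])  # {token: logprob} dict for position 0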
@@ -233,8 +318,10 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,
         chat_completion.choices) == 1
     assert chat_completion.choices[0].message is not None
     assert chat_completion.choices[0].logprobs is not None
-    assert chat_completion.choices[0].logprobs.top_logprobs is not None
-    assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5
+    assert chat_completion.choices[0].logprobs.content[
+        0].top_logprobs is not None
+    assert len(
+        chat_completion.choices[0].logprobs.content[0].top_logprobs) == 5
     message = chat_completion.choices[0].message
     assert message.content is not None and len(message.content) >= 10
     assert message.role == "assistant"
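The hunk above captures the access-path change in one place: the chat response no longer exposes completions-style logprobs.top_logprobs, and the alternatives hang off each content entry instead. A small helper sketch of the new path; the chat_completion argument is assumed to be a response from client.chat.completions.create(..., logprobs=True):

    # Sketch of the new access path under the chat schema.
    def first_token_alternatives(chat_completion):
        first = chat_completion.choices[0].logprobs.content[0]
        return [(alt.token, alt.logprob) for alt in first.top_logprobs]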
@@ -252,9 +339,13 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI,


 @pytest.mark.asyncio
-@pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
-                                 model_name: str):
+@pytest.mark.parametrize(
+    # first test base model, then test loras
+    "model_name",
+    [MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
+)
+async def test_no_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                model_name: str):
     messages = [{
         "role": "system",
         "content": "you are a helpful assistant"
@@ -263,13 +354,92 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
         "content": "what is 1+1?"
     }]

-    # Default max_logprobs is 5, so this should raise an error
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=False)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is None
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    # just test 1 lora hereafter
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_zero_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=0)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [MODEL_NAME, "zephyr-lora"],
+)
+async def test_some_logprobs_chat(server, client: openai.AsyncOpenAI,
+                                  model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    chat_completion = await client.chat.completions.create(model=model_name,
+                                                           messages=messages,
+                                                           max_tokens=5,
+                                                           temperature=0.0,
+                                                           logprobs=True,
+                                                           top_logprobs=5)
+
+    choice = chat_completion.choices[0]
+    assert choice.logprobs is not None
+    assert choice.logprobs.content is not None
+    assert len(choice.logprobs.content[0].top_logprobs) <= 6
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_too_many_chat_logprobs(server, client: openai.AsyncOpenAI,
+                                      model_name: str):
+    messages = [{
+        "role": "system",
+        "content": "you are a helpful assistant"
+    }, {
+        "role": "user",
+        "content": "what is 1+1?"
+    }]
+
+    # Default max_logprobs is 20, so this should raise an error
     with pytest.raises((openai.BadRequestError, openai.APIError)):
         stream = await client.chat.completions.create(model=model_name,
                                                       messages=messages,
                                                       max_tokens=10,
                                                       logprobs=True,
-                                                      top_logprobs=10,
+                                                      top_logprobs=21,
                                                       stream=True)
         async for chunk in stream:
             ...
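test_too_many_chat_logprobs mirrors the completion-side limit test: asking for more top_logprobs than the server allows (the updated comment suggests a default cap of 20) should fail with a 400, and in the streaming case the error may only surface while iterating the stream. A hedged sketch of that client-side pattern; the value 21 is taken from the test above and the model name is a placeholder:

    # Sketch of probing the server's logprobs cap from the client side.
    async def expect_rejection(client: openai.AsyncOpenAI) -> None:
        try:
            stream = await client.chat.completions.create(
                model="placeholder-model",
                messages=[{"role": "user", "content": "what is 1+1?"}],
                max_tokens=10,
                logprobs=True,
                top_logprobs=21,  # one above the assumed server cap
                stream=True,
            )
            async for _chunk in stream:  # the error may arrive mid-stream
                pass
        except (openai.BadRequestError, openai.APIError):
            print("rejected as expected")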
@@ -279,25 +449,9 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI,
                                                       messages=messages,
                                                       max_tokens=10,
                                                       logprobs=True,
-                                                      top_logprobs=10,
+                                                      top_logprobs=30,
                                                       stream=False)

-    with pytest.raises((openai.BadRequestError, openai.APIError)):
-        stream = await client.completions.create(model=model_name,
-                                                 prompt="Test",
-                                                 max_tokens=10,
-                                                 logprobs=10,
-                                                 stream=True)
-        async for chunk in stream:
-            ...
-
-    with pytest.raises(openai.BadRequestError):
-        await client.completions.create(model=model_name,
-                                        prompt="Test",
-                                        max_tokens=10,
-                                        logprobs=10,
-                                        stream=False)
-
     # the server should still work afterwards
     chat_completion = await client.chat.completions.create(model=model_name,
                                                            messages=messages,
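The trailing context above begins the recovery check that both limit tests share: after a rejected request, an ordinary request must still succeed, confirming the bad request did not wedge the server. As a standalone sketch, with the same placeholder assumptions as before:

    # Sketch of the post-error health check used by these tests.
    async def server_still_works(client: openai.AsyncOpenAI) -> bool:
        chat = await client.chat.completions.create(
            model="placeholder-model",
            messages=[{"role": "user", "content": "ping"}],
            max_tokens=5,
        )
        return chat.choices[0].message.content is not None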
@@ -744,13 +898,12 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI,
         top_logprobs=5,
         extra_body=dict(guided_choice=TEST_CHOICE,
                         guided_decoding_backend=guided_decoding_backend))
-    top_logprobs = chat_completion.choices[0].logprobs.top_logprobs
+    top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs

     # -9999.0 is the minimum logprob returned by OpenAI
     assert all(
-        isinstance(logprob, float) and logprob >= -9999.0
-        for token_dict in top_logprobs
-        for token, logprob in token_dict.items())
+        isinstance(token.logprob, float) and token.logprob >= -9999.0
+        for token in top_logprobs)


 @pytest.mark.asyncio
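The guided-choice hunk shows the same schema change on the assertion side: top_logprobs entries are now typed objects with a .logprob attribute rather than {token: logprob} dicts, and the check keeps the -9999.0 floor noted in the test's comment. A sketch of the updated predicate, with chat_completion again assumed to come from a logprobs-enabled chat request:

    # Sketch; -9999.0 is the minimum logprob returned by OpenAI, per the
    # comment in the test above.
    def logprobs_above_floor(chat_completion) -> bool:
        top = chat_completion.choices[0].logprobs.content[0].top_logprobs
        return all(
            isinstance(alt.logprob, float) and alt.logprob >= -9999.0
            for alt in top)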