[Bugfix] Support prompt_logprobs==0 (#5217)
This commit is contained in:
@@ -224,7 +224,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI,
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert len(choice.logprobs.top_logprobs[0]) <= 1
|
||||
assert len(choice.logprobs.top_logprobs[0]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -246,7 +246,7 @@ async def test_some_logprobs(server, client: openai.AsyncOpenAI,
|
||||
assert choice.logprobs is not None
|
||||
assert choice.logprobs.token_logprobs is not None
|
||||
assert choice.logprobs.top_logprobs is not None
|
||||
assert len(choice.logprobs.top_logprobs[0]) <= 6
|
||||
assert 5 <= len(choice.logprobs.top_logprobs[0]) <= 6
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -1217,8 +1217,9 @@ number: "1" | "2"
|
||||
"model_name",
|
||||
[MODEL_NAME, "zephyr-lora", "zephyr-lora2"],
|
||||
)
|
||||
@pytest.mark.parametrize("logprobs_arg", [1, 0])
|
||||
async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
|
||||
model_name: str):
|
||||
model_name: str, logprobs_arg: int):
|
||||
tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME)
|
||||
# test using text and token IDs
|
||||
for prompt in ("Hello, my name is", [0, 0, 0, 0, 0]):
|
||||
@@ -1227,7 +1228,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
|
||||
max_tokens=5,
|
||||
temperature=0.0,
|
||||
echo=True,
|
||||
logprobs=1)
|
||||
logprobs=logprobs_arg)
|
||||
|
||||
prompt_text = tokenizer.decode(prompt) if isinstance(prompt,
|
||||
list) else prompt
|
||||
@@ -1240,6 +1241,9 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI,
|
||||
and logprobs.token_logprobs[0] is None)
|
||||
assert (len(logprobs.top_logprobs) > 5
|
||||
and logprobs.top_logprobs[0] is None)
|
||||
for top_logprobs in logprobs.top_logprobs[1:]:
|
||||
assert max(logprobs_arg,
|
||||
1) <= len(top_logprobs) <= logprobs_arg + 1
|
||||
assert len(logprobs.tokens) > 5
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user