[Fix] Fix logprobs=0 handling for /inference/v1/generate endpoint (#34010)
Signed-off-by: SumanthRH <sumanthrh99@gmail.com>
This commit is contained in:
@@ -87,6 +87,32 @@ async def test_generate_endpoint(client):
|
|||||||
assert "choices" in data
|
assert "choices" in data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("logprobs_value", [0, 1, 5])
|
||||||
|
async def test_generate_logprobs(client, logprobs_value):
|
||||||
|
payload = {
|
||||||
|
"model": MODEL_NAME,
|
||||||
|
"token_ids": [1, 2, 3],
|
||||||
|
"sampling_params": {
|
||||||
|
"max_tokens": 5,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"logprobs": logprobs_value,
|
||||||
|
},
|
||||||
|
"stream": False,
|
||||||
|
}
|
||||||
|
resp = await client.post(GEN_ENDPOINT, json=payload)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
choice = data["choices"][0]
|
||||||
|
assert choice["logprobs"] is not None
|
||||||
|
logprobs_content = choice["logprobs"]["content"]
|
||||||
|
assert len(logprobs_content) == len(choice["token_ids"])
|
||||||
|
for entry in logprobs_content:
|
||||||
|
assert "logprob" in entry
|
||||||
|
assert len(entry["top_logprobs"]) >= 1
|
||||||
|
assert len(entry["top_logprobs"]) == max(logprobs_value, 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_same_response_as_chat_completions(client, tokenizer, messages):
|
async def test_same_response_as_chat_completions(client, tokenizer, messages):
|
||||||
token_ids = tokenizer.apply_chat_template(
|
token_ids = tokenizer.apply_chat_template(
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ class ServingTokens(OpenAIServing):
|
|||||||
out_logprobs = output.logprobs
|
out_logprobs = output.logprobs
|
||||||
|
|
||||||
# This is top_logprobs in completions API
|
# This is top_logprobs in completions API
|
||||||
if sampling_params.logprobs:
|
if sampling_params.logprobs is not None:
|
||||||
assert out_logprobs is not None, "Did not output logprobs"
|
assert out_logprobs is not None, "Did not output logprobs"
|
||||||
logprobs = self._create_tokens_logprobs(
|
logprobs = self._create_tokens_logprobs(
|
||||||
token_ids=token_ids,
|
token_ids=token_ids,
|
||||||
@@ -284,7 +284,8 @@ class ServingTokens(OpenAIServing):
|
|||||||
logprob=max(p[1].logprob, -9999.0),
|
logprob=max(p[1].logprob, -9999.0),
|
||||||
)
|
)
|
||||||
for i, p in enumerate(step_top_logprobs.items())
|
for i, p in enumerate(step_top_logprobs.items())
|
||||||
if num_output_top_logprobs and i < num_output_top_logprobs
|
if num_output_top_logprobs is not None
|
||||||
|
and i < max(num_output_top_logprobs, 1)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user