From ae2e93f89b06529dba41b5c0adc1a6d27e921320 Mon Sep 17 00:00:00 2001
From: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
Date: Fri, 6 Feb 2026 12:33:40 -0800
Subject: [PATCH] [Fix] Fix `logprobs=0` handling for `/inference/v1/generate` endpoint (#34010)

Signed-off-by: SumanthRH
---
 .../entrypoints/openai/test_serving_tokens.py | 26 +++++++++++++++++++
 vllm/entrypoints/serve/disagg/serving.py      |  5 ++--
 2 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/tests/entrypoints/openai/test_serving_tokens.py b/tests/entrypoints/openai/test_serving_tokens.py
index ee3c60556..aa56dfd6b 100644
--- a/tests/entrypoints/openai/test_serving_tokens.py
+++ b/tests/entrypoints/openai/test_serving_tokens.py
@@ -87,6 +87,32 @@ async def test_generate_endpoint(client):
     assert "choices" in data
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("logprobs_value", [0, 1, 5])
+async def test_generate_logprobs(client, logprobs_value):
+    payload = {
+        "model": MODEL_NAME,
+        "token_ids": [1, 2, 3],
+        "sampling_params": {
+            "max_tokens": 5,
+            "temperature": 0.0,
+            "logprobs": logprobs_value,
+        },
+        "stream": False,
+    }
+    resp = await client.post(GEN_ENDPOINT, json=payload)
+    resp.raise_for_status()
+    data = resp.json()
+    choice = data["choices"][0]
+    assert choice["logprobs"] is not None
+    logprobs_content = choice["logprobs"]["content"]
+    assert len(logprobs_content) == len(choice["token_ids"])
+    for entry in logprobs_content:
+        assert "logprob" in entry
+        assert len(entry["top_logprobs"]) >= 1
+        assert len(entry["top_logprobs"]) == max(logprobs_value, 1)
+
+
 @pytest.mark.asyncio
 async def test_same_response_as_chat_completions(client, tokenizer, messages):
     token_ids = tokenizer.apply_chat_template(
diff --git a/vllm/entrypoints/serve/disagg/serving.py b/vllm/entrypoints/serve/disagg/serving.py
index b74b50611..0e61f5ec0 100644
--- a/vllm/entrypoints/serve/disagg/serving.py
+++ b/vllm/entrypoints/serve/disagg/serving.py
@@ -184,7 +184,7 @@ class ServingTokens(OpenAIServing):
         out_logprobs = output.logprobs
 
         # This is top_logprobs in completions API
-        if sampling_params.logprobs:
+        if sampling_params.logprobs is not None:
             assert out_logprobs is not None, "Did not output logprobs"
             logprobs = self._create_tokens_logprobs(
                 token_ids=token_ids,
@@ -284,7 +284,8 @@
                         logprob=max(p[1].logprob, -9999.0),
                     )
                     for i, p in enumerate(step_top_logprobs.items())
-                    if num_output_top_logprobs and i < num_output_top_logprobs
+                    if num_output_top_logprobs is not None
+                    and i < max(num_output_top_logprobs, 1)
                 ],
             )
         )
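
For context on the gating change: in the OpenAI-style APIs, `logprobs=0` is a
valid request meaning "return the sampled token's logprob, but no additional
top alternatives", so the previous truthiness checks
(`if sampling_params.logprobs:` and `if num_output_top_logprobs and ...`)
silently dropped it. A minimal sketch of the corrected semantics follows; the
helper name is illustrative only and not part of vLLM:

    def entries_to_emit(num_output_top_logprobs):
        # None -> logprobs were not requested; emit no logprobs at all.
        if num_output_top_logprobs is None:
            return 0
        # 0 -> still emit the sampled token's own logprob;
        # k > 0 -> emit the top-k entries.
        return max(num_output_top_logprobs, 1)

    assert entries_to_emit(None) == 0  # logprobs omitted entirely
    assert entries_to_emit(0) == 1     # only the sampled token
    assert entries_to_emit(5) == 5     # top-5 entries

This mirrors the `max(num_output_top_logprobs, 1)` bound introduced in
`_create_tokens_logprobs` and the `is not None` check in the generate handler
above, which together distinguish "not requested" (`None`) from "requested
with zero alternatives" (`0`).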