support returning token IDs in responses api (#33212)

Signed-off-by: Christian Munley <cmunley@nvidia.com>
This commit is contained in:
cmunley1
2026-01-29 00:52:39 -08:00
committed by GitHub
parent 53fc166402
commit 3bba2edb0f
3 changed files with 50 additions and 9 deletions

View File

@@ -121,3 +121,42 @@ async def test_chat_return_tokens_as_token_ids_completion(server_fixture):
for logprob_content in response.choices[0].logprobs.content:
token_ids.append(int(logprob_content.token.removeprefix("token_id:")))
assert tokenizer.decode(token_ids, skip_special_tokens=True) == text
def test_responses_api_logprobs_with_return_tokens_as_token_ids():
    """Verify the Responses API logprob path honors `return_tokens_as_token_ids`.

    When the flag is set, every logprob entry's `token` field must be the
    literal string ``token_id:<id>`` instead of the decoded token text.
    """
    from unittest.mock import MagicMock

    from vllm.entrypoints.openai.engine.serving import OpenAIServing
    from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
    from vllm.logprobs import Logprob as SampleLogprob

    # Stand-in serving object: only the attributes read by the method under
    # test are populated; `_get_decoded_token` is the real implementation.
    mock_serving = MagicMock(spec=OpenAIServingResponses)
    mock_serving.return_tokens_as_token_ids = True
    mock_serving._get_decoded_token = OpenAIServing._get_decoded_token

    # Tokenizer stub — decode output is irrelevant here because token-id mode
    # formats the token from its integer id rather than decoding it.
    fake_tokenizer = MagicMock()
    fake_tokenizer.decode = lambda token_id: "decoded"

    # (token_id, logprob, decoded_token) triples driving both the input
    # construction and the expected assertions below.
    expected = [
        (100, -0.5, "hello"),
        (200, -1.2, "world"),
        (300, -0.8, "!"),
    ]
    ids = [tid for tid, _, _ in expected]
    per_position_logprobs = [
        {tid: SampleLogprob(logprob=lp, decoded_token=tok)}
        for tid, lp, tok in expected
    ]

    result = OpenAIServingResponses._create_response_logprobs(
        mock_serving,
        token_ids=ids,
        logprobs=per_position_logprobs,
        tokenizer=fake_tokenizer,
        top_logprobs=1,
    )

    assert len(result) == len(expected)
    for entry, (tid, lp, _) in zip(result, expected):
        assert entry.token == f"token_id:{tid}"
        assert entry.logprob == lp

View File

@@ -1528,7 +1528,7 @@ class OpenAIServing:
"Unable to get tokenizer because `skip_tokenizer_init=True`"
)
return tokenizer.decode(token_id)
return tokenizer.decode([token_id])
def _is_model_supported(self, model_name: str | None) -> bool:
if not model_name:

View File

@@ -836,10 +836,11 @@ class OpenAIServingResponses(OpenAIServing):
for i, (token_id, _logprob) in enumerate(logprobs.items()):
if i >= top_logprobs:
break
text = (
_logprob.decoded_token
if _logprob.decoded_token is not None
else tokenizer.decode([token_id])
text = self._get_decoded_token(
logprob=_logprob,
token_id=token_id,
tokenizer=tokenizer,
return_as_token_id=self.return_tokens_as_token_ids,
)
out.append(
LogprobTopLogprob(
@@ -865,10 +866,11 @@ class OpenAIServingResponses(OpenAIServing):
for i, token_id in enumerate(token_ids):
logprob = logprobs[i]
token_logprob = logprob[token_id]
text = (
token_logprob.decoded_token
if token_logprob.decoded_token is not None
else tokenizer.decode([token_id])
text = self._get_decoded_token(
logprob=token_logprob,
token_id=token_id,
tokenizer=tokenizer,
return_as_token_id=self.return_tokens_as_token_ids,
)
out.append(
Logprob(