[Frontend] Fix reasoning_tokens for text-based parsers in Responses API (#33513)
Signed-off-by: Jaeyeon Kim <anencore94@gmail.com>
This commit is contained in:
@@ -13,9 +13,13 @@ from openai.types.responses.tool import (
|
||||
Tool,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.entrypoints.mcp.tool_server import ToolServer
|
||||
from vllm.entrypoints.openai.engine.protocol import ErrorResponse
|
||||
from vllm.entrypoints.openai.responses.context import ConversationContext
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.context import ConversationContext, SimpleContext
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
from vllm.entrypoints.openai.responses.serving import (
|
||||
OpenAIServingResponses,
|
||||
@@ -23,6 +27,8 @@ from vllm.entrypoints.openai.responses.serving import (
|
||||
extract_tool_types,
|
||||
)
|
||||
from vllm.inputs.data import TokensPrompt
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
|
||||
class MockConversationContext(ConversationContext):
|
||||
@@ -259,6 +265,87 @@ class TestValidateGeneratorInput:
|
||||
assert isinstance(result, ErrorResponse)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
    """Check that usage.reasoning_tokens reflects the thinking-token span.

    A pre-populated ``SimpleContext`` carrying one generated token between the
    ``<think>`` and ``</think>`` ids is fed to ``responses_full_generator``;
    the reported ``reasoning_tokens`` must equal 1.
    """

    class FakeTokenizer:
        """Tokenizer stand-in that only exposes a fixed vocabulary."""

        def __init__(self):
            self._vocab = {"<think>": 1, "</think>": 2, "reason": 3, "final": 4}

        def get_vocab(self):
            return self._vocab

    async def _one_empty_result():
        # The context below is already filled in, so the generator only
        # needs to yield a single placeholder item.
        yield None

    # Force non-harmony, SimpleContext path
    monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)

    fake_tokenizer = FakeTokenizer()

    # Mocked model configuration consumed while constructing the serving layer.
    mock_model_config = MagicMock()
    mock_model_config.hf_config.model_type = "test"
    mock_model_config.hf_text_config = MagicMock()
    mock_model_config.get_diff_sampling_param.return_value = {}

    # The renderer hands out the fake tokenizer.
    mock_renderer = MagicMock()
    mock_renderer.get_tokenizer.return_value = fake_tokenizer

    mock_engine = MagicMock(
        model_config=mock_model_config,
        input_processor=MagicMock(),
        io_processor=MagicMock(),
        renderer=mock_renderer,
    )

    serving = OpenAIServingResponses(
        engine_client=mock_engine,
        models=MagicMock(),
        request_logger=None,
        chat_template=None,
        chat_template_content_format="auto",
        reasoning_parser="qwen3",
    )

    # Build a SimpleContext with thinking tokens in the output.
    # Ids 1 and 2 are the <think> / </think> vocab entries, so exactly one
    # token (10) sits inside the thinking span -> reasoning token count = 1.
    generated_ids = [1, 10, 2, 20]
    context = SimpleContext()
    context.append_output(
        RequestOutput(
            request_id="req",
            prompt="hi",
            prompt_token_ids=[7, 8],
            prompt_logprobs=None,
            outputs=[
                CompletionOutput(
                    index=0,
                    text="<think>reason</think>final",
                    token_ids=generated_ids,
                    cumulative_logprob=0.0,
                    logprobs=None,
                    finish_reason="stop",
                    stop_reason=None,
                )
            ],
            finished=True,
            num_cached_tokens=0,
        )
    )

    response = await serving.responses_full_generator(
        request=ResponsesRequest(input="hi", tools=[], stream=False),
        sampling_params=SamplingParams(max_tokens=16),
        result_generator=_one_empty_result(),
        context=context,
        model_name="test-model",
        tokenizer=fake_tokenizer,
        request_metadata=RequestResponseMetadata(request_id="req"),
    )

    token_details = response.usage.output_tokens_details
    assert token_details.reasoning_tokens == 1
|
||||
|
||||
|
||||
class TestExtractAllowedToolsFromMcpRequests:
|
||||
"""Test class for _extract_allowed_tools_from_mcp_requests function"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user