diff --git a/tests/entrypoints/openai/responses/test_simple.py b/tests/entrypoints/openai/responses/test_simple.py
index a5bec6dfd..db536d2fa 100644
--- a/tests/entrypoints/openai/responses/test_simple.py
+++ b/tests/entrypoints/openai/responses/test_simple.py
@@ -134,6 +134,53 @@ async def test_streaming_output_consistency(client: OpenAI, model_name: str):
     )
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_streaming_reasoning_tokens_e2e(client: OpenAI, model_name: str):
+    """Verify final usage includes reasoning_tokens in streaming mode."""
+    response = await client.responses.create(
+        model=model_name,
+        input="Compute 17 * 19 and explain briefly.",
+        reasoning={"effort": "low"},
+        temperature=0.0,
+        stream=True,
+    )
+
+    completed_event = None
+    async for event in response:
+        if event.type == "response.completed":
+            completed_event = event
+
+    assert completed_event is not None
+    assert completed_event.response.status == "completed"
+    assert completed_event.response.usage is not None
+    assert completed_event.response.usage.output_tokens_details is not None
+    assert completed_event.response.usage.output_tokens_details.reasoning_tokens > 0, (
+        "Expected reasoning_tokens > 0 for streamed Qwen3 response."
+    )
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_non_streaming_reasoning_tokens_e2e(client: OpenAI, model_name: str):
+    """Verify usage includes reasoning_tokens in non-streaming mode."""
+    response = await client.responses.create(
+        model=model_name,
+        input="Compute 23 * 17 and explain briefly.",
+        reasoning={"effort": "low"},
+        temperature=0.0,
+        stream=False,
+    )
+
+    assert response is not None
+    assert response.status == "completed"
+    assert response.usage is not None
+    assert response.usage.output_tokens_details is not None
+    assert response.usage.output_tokens_details.reasoning_tokens > 0, (
+        "Expected reasoning_tokens > 0 for non-streamed Qwen3 response."
+    )
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
 async def test_max_tokens(client: OpenAI, model_name: str):
diff --git a/tests/entrypoints/openai/test_serving_responses.py b/tests/entrypoints/openai/test_serving_responses.py
index ff0da632e..5cf07ac0f 100644
--- a/tests/entrypoints/openai/test_serving_responses.py
+++ b/tests/entrypoints/openai/test_serving_responses.py
@@ -13,9 +13,13 @@ from openai.types.responses.tool import (
     Tool,
 )
 
+import vllm.envs as envs
 from vllm.entrypoints.mcp.tool_server import ToolServer
-from vllm.entrypoints.openai.engine.protocol import ErrorResponse
-from vllm.entrypoints.openai.responses.context import ConversationContext
+from vllm.entrypoints.openai.engine.protocol import (
+    ErrorResponse,
+    RequestResponseMetadata,
+)
+from vllm.entrypoints.openai.responses.context import ConversationContext, SimpleContext
 from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
 from vllm.entrypoints.openai.responses.serving import (
     OpenAIServingResponses,
@@ -23,6 +27,8 @@ from vllm.entrypoints.openai.responses.serving import (
     extract_tool_types,
 )
 from vllm.inputs.data import TokensPrompt
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
 
 
 class MockConversationContext(ConversationContext):
@@ -259,6 +265,87 @@ class TestValidateGeneratorInput:
         assert isinstance(result, ErrorResponse)
 
 
+@pytest.mark.asyncio
+async def test_reasoning_tokens_counted_for_text_reasoning_model(monkeypatch):
+    """Ensure reasoning_tokens usage is derived from thinking token spans."""
+
+    class FakeTokenizer:
+        def __init__(self):
+            self._vocab = {"<think>": 1, "</think>": 2, "reason": 3, "final": 4}
+
+        def get_vocab(self):
+            return self._vocab
+
+    # Force non-harmony, SimpleContext path
+    monkeypatch.setattr(envs, "VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT", False)
+
+    engine_client = MagicMock()
+    model_config = MagicMock()
+    model_config.hf_config.model_type = "test"
+    model_config.hf_text_config = MagicMock()
+    model_config.get_diff_sampling_param.return_value = {}
+    engine_client.model_config = model_config
+    engine_client.input_processor = MagicMock()
+    engine_client.io_processor = MagicMock()
+    engine_client.renderer = MagicMock()
+
+    tokenizer = FakeTokenizer()
+    engine_client.renderer.get_tokenizer.return_value = tokenizer
+
+    models = MagicMock()
+
+    serving = OpenAIServingResponses(
+        engine_client=engine_client,
+        models=models,
+        request_logger=None,
+        chat_template=None,
+        chat_template_content_format="auto",
+        reasoning_parser="qwen3",
+    )
+
+    # Build a SimpleContext with thinking tokens in the output.
+    context = SimpleContext()
+    token_ids = [1, 10, 2, 20]  # <think> 10 </think> 20 -> reasoning token count = 1
+    completion = CompletionOutput(
+        index=0,
+        text="<think>reason</think>final",
+        token_ids=token_ids,
+        cumulative_logprob=0.0,
+        logprobs=None,
+        finish_reason="stop",
+        stop_reason=None,
+    )
+    req_output = RequestOutput(
+        request_id="req",
+        prompt="hi",
+        prompt_token_ids=[7, 8],
+        prompt_logprobs=None,
+        outputs=[completion],
+        finished=True,
+        num_cached_tokens=0,
+    )
+    context.append_output(req_output)
+
+    async def dummy_result_generator():
+        yield None
+
+    request = ResponsesRequest(input="hi", tools=[], stream=False)
+    sampling_params = SamplingParams(max_tokens=16)
+    metadata = RequestResponseMetadata(request_id="req")
+
+    response = await serving.responses_full_generator(
+        request=request,
+        sampling_params=sampling_params,
+        result_generator=dummy_result_generator(),
+        context=context,
+        model_name="test-model",
+        tokenizer=tokenizer,
+        request_metadata=metadata,
+    )
+
+    assert response.usage.output_tokens_details.reasoning_tokens == 1
+
+
 class TestExtractAllowedToolsFromMcpRequests:
     """Test class for _extract_allowed_tools_from_mcp_requests function"""
 
diff --git a/tests/reasoning/test_base_thinking_reasoning_parser.py b/tests/reasoning/test_base_thinking_reasoning_parser.py
index 8c69f75a3..f4d74ceee 100644
--- a/tests/reasoning/test_base_thinking_reasoning_parser.py
+++ b/tests/reasoning/test_base_thinking_reasoning_parser.py
@@ -167,6 +167,23 @@ class TestBaseThinkingReasoningParserMethods:
             is False
         )
 
+    def test_count_reasoning_tokens(self, test_tokenizer):
+        """Count tokens between start/end markers."""
+        parser = TestThinkingReasoningParser(test_tokenizer)
+        start = parser.start_token_id
+        end = parser.end_token_id
+        token_ids = [0, start, 11, 12, end, 99]
+        assert parser.count_reasoning_tokens(token_ids) == 2
+
+    def test_count_reasoning_tokens_nested(self, test_tokenizer):
+        """Ensure nested thinking spans count all inner tokens safely."""
+        parser = TestThinkingReasoningParser(test_tokenizer)
+        s = parser.start_token_id
+        e = parser.end_token_id
+        token_ids = [s, 1, s, 2, e, 3, e]
+        # Tokens 1,2,3 are inside reasoning (depth>0) => 3 tokens
+        assert parser.count_reasoning_tokens(token_ids) == 3
+
     def test_extract_content_ids(self, test_tokenizer):
         """Test the extract_content_ids method."""
         parser = TestThinkingReasoningParser(test_tokenizer)
diff --git a/vllm/entrypoints/openai/responses/context.py b/vllm/entrypoints/openai/responses/context.py
index c09d0fb97..9559e7948 100644
--- a/vllm/entrypoints/openai/responses/context.py
+++ b/vllm/entrypoints/openai/responses/context.py
@@ -280,7 +280,6 @@ class ParsableContext(ConversationContext):
         self.num_prompt_tokens = 0
         self.num_output_tokens = 0
         self.num_cached_tokens = 0
-        # TODO: num_reasoning_tokens is not implemented yet.
         self.num_reasoning_tokens = 0
         # not implemented yet for ParsableContext
         self.all_turn_metrics: list[TurnMetrics] = []
@@ -308,12 +307,15 @@ class ParsableContext(ConversationContext):
 
         self.input_messages: list[ResponseRawMessageAndToken] = []
         self.output_messages: list[ResponseRawMessageAndToken] = []
+        self._accumulated_token_ids: list[int] = []
 
     def append_output(self, output: RequestOutput) -> None:
         self.num_prompt_tokens = len(output.prompt_token_ids or [])
         self.num_cached_tokens = output.num_cached_tokens or 0
         self.num_output_tokens += len(output.outputs[0].token_ids or [])
         self.parser.process(output.outputs[0])
+        output_token_ids = output.outputs[0].token_ids or []
+        self._accumulated_token_ids.extend(output_token_ids)
 
         # only store if enable_response_messages is True, save memory
         if self.request.enable_response_messages:
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index ea422a2b7..e40b6b8f0 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -759,6 +759,19 @@ class OpenAIServingResponses(OpenAIServing):
         num_generated_tokens = context.num_output_tokens
         num_cached_tokens = context.num_cached_tokens
         num_reasoning_tokens = context.num_reasoning_tokens
+        # For text-based reasoning parsers (e.g., <think>...</think>),
+        # HarmonyContext already counts reasoning tokens via channels.
+        # For Simple/Parsable contexts, derive reasoning_tokens from
+        # accumulated output token IDs using the parser if not already set.
+        if (
+            num_reasoning_tokens == 0
+            and self.parser is not None
+            and self.parser.reasoning_parser_cls is not None
+            and isinstance(context, (SimpleContext, ParsableContext))
+        ):
+            reasoning_parser = self.parser.reasoning_parser_cls(tokenizer)
+            accumulated = getattr(context, "_accumulated_token_ids", []) or []
+            num_reasoning_tokens = reasoning_parser.count_reasoning_tokens(accumulated)
 
         usage = ResponseUsage(
             input_tokens=num_prompt_tokens,
diff --git a/vllm/reasoning/abs_reasoning_parsers.py b/vllm/reasoning/abs_reasoning_parsers.py
index bd13ecf02..496eaaf3f 100644
--- a/vllm/reasoning/abs_reasoning_parsers.py
+++ b/vllm/reasoning/abs_reasoning_parsers.py
@@ -104,6 +104,25 @@ class ReasoningParser:
             The extracted content from the input_ids.
         """
 
+    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
+        """Count the number of reasoning tokens in a sequence.
+
+        Text-based reasoning models typically wrap their chain-of-thought
+        between special start/end tokens (e.g., ``<think>`` ... ``</think>``).
+        Implementations that support reasoning token counting should override
+        this method. The default implementation returns ``0`` so existing
+        parsers remain unchanged unless they explicitly opt in.
+
+        Args:
+            token_ids: Sequence of generated token ids (excluding prompt).
+
+        Returns:
+            int: Number of tokens that belong to reasoning content.
+        """
+
+        # By default, assume the parser cannot detect reasoning spans.
+        return 0
+
     @abstractmethod
     def extract_reasoning(
         self,
diff --git a/vllm/reasoning/basic_parsers.py b/vllm/reasoning/basic_parsers.py
index 18bf96d78..c066032fb 100644
--- a/vllm/reasoning/basic_parsers.py
+++ b/vllm/reasoning/basic_parsers.py
@@ -175,3 +175,23 @@ class BaseThinkingReasoningParser(ReasoningParser):
             # If generation stops right after end-of-think, return null content
             final_content = content or None
             return reasoning, final_content
+
+    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
+        """Count tokens that fall within start/end thinking markers.
+
+        Uses a depth counter so nested spans are handled safely and stray end
+        tokens do not drive the counter negative.
+        """
+        count = 0
+        depth = 0
+        for token_id in token_ids:
+            if token_id == self.start_token_id:
+                depth += 1
+                continue
+            if token_id == self.end_token_id:
+                if depth > 0:
+                    depth -= 1
+                continue
+            if depth > 0:
+                count += 1
+        return count