diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index 48fb66648..8ff686516 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -72,6 +72,7 @@ from vllm.logger import init_logger
 from vllm.logprobs import Logprob
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.parser import ParserManager
+from vllm.reasoning import ReasoningParser
 from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.tokenizers import TokenizerLike
 from vllm.tokenizers.mistral import (
@@ -132,7 +133,7 @@ class OpenAIServingChat(OpenAIServing):
         self.logits_processors = self.model_config.logits_processors
 
         # set up reasoning parser
-        self.reasoning_parser = ParserManager.get_reasoning_parser(
+        self.reasoning_parser_cls = ParserManager.get_reasoning_parser(
             reasoning_parser_name=reasoning_parser
         )
         # set up tool use
@@ -330,6 +331,24 @@ class OpenAIServingChat(OpenAIServing):
         for the API specification. This API mimics the OpenAI
         Chat Completion API.
         """
+        # Streaming response
+        tokenizer = self.renderer.tokenizer
+        assert tokenizer is not None
+        reasoning_parser: ReasoningParser | None = None
+        try:
+            if self.reasoning_parser_cls:
+                # Pass the same chat template kwargs as used in tokenization
+                chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
+                    request.chat_template_kwargs,
+                    self.default_chat_template_kwargs,
+                )
+                reasoning_parser = self.reasoning_parser_cls(
+                    tokenizer,
+                    chat_template_kwargs=chat_template_kwargs,  # type: ignore[call-arg]
+                )
+        except RuntimeError as e:
+            logger.exception("Error in reasoning parser creation.")
+            return self.create_error_response(str(e))
         result = await self.render_chat_request(request)
         if isinstance(result, ErrorResponse):
             return result
@@ -427,7 +446,12 @@ class OpenAIServingChat(OpenAIServing):
                     priority=request.priority,
                     data_parallel_rank=data_parallel_rank,
                 )
-
+                reasoning_ended = None
+                if reasoning_parser:
+                    reasoning_ended = reasoning_parser.is_reasoning_end(
+                        engine_request.prompt_token_ids or []  # type: ignore[attr-defined]
+                    )
+                engine_request.reasoning_ended = reasoning_ended
                 generator = self.engine_client.generate(
                     engine_request,
                     sampling_params,
@@ -447,10 +471,6 @@ class OpenAIServingChat(OpenAIServing):
         assert len(generators) == 1
         (result_generator,) = generators
 
-        # Streaming response
-        tokenizer = self.renderer.tokenizer
-        assert tokenizer is not None
-
         if request.stream:
             return self.chat_completion_stream_generator(
                 request,
@@ -460,6 +480,7 @@ class OpenAIServingChat(OpenAIServing):
                 conversation,
                 tokenizer,
                 request_metadata,
+                reasoning_parser,
             )
 
         try:
@@ -471,6 +492,7 @@ class OpenAIServingChat(OpenAIServing):
                 conversation,
                 tokenizer,
                 request_metadata,
+                reasoning_parser,
             )
         except GenerationError as e:
             return self._convert_generation_error_to_response(e)
@@ -630,6 +652,7 @@ class OpenAIServingChat(OpenAIServing):
         conversation: list[ConversationMessage],
         tokenizer: TokenizerLike,
         request_metadata: RequestResponseMetadata,
+        reasoning_parser: ReasoningParser | None = None,
     ) -> AsyncGenerator[str, None]:
         from vllm.tokenizers.mistral import MistralTokenizer
 
@@ -673,7 +696,7 @@ class OpenAIServingChat(OpenAIServing):
 
         # Only one of these will be used, thus previous_texts and
         # all_previous_token_ids will not be used twice in the same iteration.
-        if tool_choice_auto or self.reasoning_parser:
+        if tool_choice_auto or reasoning_parser:
             # These are only required in "auto" tool choice case
             all_previous_token_ids = [[]] * num_choices
             # For reasoning parser and tool call all enabled
@@ -683,28 +706,6 @@ class OpenAIServingChat(OpenAIServing):
         else:
             all_previous_token_ids = None
-        try:
-            if self.reasoning_parser:
-                if tokenizer is None:
-                    raise ValueError(
-                        "Tokenizer not available when `skip_tokenizer_init=True`"
-                    )
-
-                # Pass the same chat template kwargs as used in tokenization
-                chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
-                    request.chat_template_kwargs,
-                    self.default_chat_template_kwargs,
-                )
-                reasoning_parser = self.reasoning_parser(
-                    tokenizer,
-                    chat_template_kwargs=chat_template_kwargs or {},  # type: ignore[call-arg]
-                )
-        except RuntimeError as e:
-            logger.exception("Error in reasoning parser creation.")
-            data = self.create_streaming_error_response(str(e))
-            yield f"data: {data}\n\n"
-            yield "data: [DONE]\n\n"
-            return
-
         # Prepare the tool parser if it's needed
         try:
             if tool_choice_auto and self.tool_parser:
@@ -826,7 +827,7 @@ class OpenAIServingChat(OpenAIServing):
                 tool_parser = tool_parsers[i]
 
                 if (
-                    self.reasoning_parser
+                    reasoning_parser
                     and res.prompt_token_ids
                     and prompt_is_reasoning_end_arr[i] is None
                 ):
@@ -888,7 +889,7 @@ class OpenAIServingChat(OpenAIServing):
                 delta_message: DeltaMessage | None
 
                 # just update previous_texts and previous_token_ids
-                if tool_choice_auto or self.reasoning_parser:
+                if tool_choice_auto or reasoning_parser:
                     assert previous_texts is not None
                     assert all_previous_token_ids is not None
                     previous_text = previous_texts[i]
@@ -915,7 +916,7 @@ class OpenAIServingChat(OpenAIServing):
                 # handle streaming deltas for tools with named tool_choice
                 elif tool_choice_function_name:
                     if (
-                        self.reasoning_parser
+                        reasoning_parser
                         and not reasoning_end_arr[i]
                         and not reasoning_parser.is_reasoning_end(
                             previous_token_ids
@@ -952,7 +953,7 @@ class OpenAIServingChat(OpenAIServing):
                         current_text = ""
                     else:
                         # Just to add remaining `content`
-                        if self.reasoning_parser:
+                        if reasoning_parser:
                             delta_text = previous_text + delta_text
                             current_text = ""
 
@@ -998,13 +999,13 @@ class OpenAIServingChat(OpenAIServing):
                     output_token_ids = as_list(output.token_ids)
 
                     if (
-                        self.reasoning_parser is not None
+                        reasoning_parser is not None
                         and not reasoning_end_arr[i]
                         and prompt_is_reasoning_end_arr[i]
                     ):
                         reasoning_end_arr[i] = True
 
-                    if self.reasoning_parser and not reasoning_end_arr[i]:
+                    if reasoning_parser and not reasoning_end_arr[i]:
                         delta_message = (
                             reasoning_parser.extract_reasoning_streaming(
                                 previous_text,
@@ -1047,9 +1048,8 @@ class OpenAIServingChat(OpenAIServing):
 
                 # handle streaming deltas for tools with "auto" tool choice
                 # and reasoning parser
-                elif tool_choice_auto and self.reasoning_parser:
+                elif tool_choice_auto and reasoning_parser:
                     assert tool_parser is not None
-                    assert reasoning_parser is not None
                     assert added_content_delta_arr is not None
                     assert reasoning_end_arr is not None
                     output_token_ids = as_list(output.token_ids)
@@ -1130,7 +1130,7 @@ class OpenAIServingChat(OpenAIServing):
                         tools_streamed[i] = True
 
                 # when only reasoning
-                elif self.reasoning_parser:
+                elif reasoning_parser:
                     delta_message = reasoning_parser.extract_reasoning_streaming(
                         previous_text,
                         current_text,
@@ -1144,9 +1144,7 @@ class OpenAIServingChat(OpenAIServing):
                     delta_message = DeltaMessage(content=delta_text)
 
                 # update the previous values for the next iteration
-                if (
-                    tool_choice_auto or self.reasoning_parser
-                ) and not self.use_harmony:
+                if (tool_choice_auto or reasoning_parser) and not self.use_harmony:
                     assert previous_texts is not None
                     assert all_previous_token_ids is not None
                     previous_texts[i] = current_text
@@ -1400,6 +1398,7 @@ class OpenAIServingChat(OpenAIServing):
         conversation: list[ConversationMessage],
         tokenizer: TokenizerLike,
         request_metadata: RequestResponseMetadata,
+        reasoning_parser: ReasoningParser | None = None,
     ) -> ErrorResponse | ChatCompletionResponse:
         from vllm.tokenizers.mistral import MistralTokenizer
 
@@ -1494,25 +1493,7 @@ class OpenAIServingChat(OpenAIServing):
                 choices.append(choice_data)
                 continue
 
-            if self.reasoning_parser:
-                try:
-                    if tokenizer is None:
-                        raise ValueError(
-                            "Tokenizer not available when `skip_tokenizer_init=True`"
-                        )
-
-                    # Pass the same chat template kwargs as used in tokenization
-                    chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
-                        request.chat_template_kwargs,
-                        self.default_chat_template_kwargs,
-                    )
-                    reasoning_parser = self.reasoning_parser(
-                        tokenizer,
-                        chat_template_kwargs=chat_template_kwargs,  # type: ignore[call-arg]
-                    )
-                except RuntimeError as e:
-                    logger.exception("Error in reasoning parser creation.")
-                    return self.create_error_response(str(e))
+            if reasoning_parser:
                 # If the reasoning parser is enabled,
                 # tool calls are extracted exclusively from the content.
                 reasoning, content = reasoning_parser.extract_reasoning(
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 5328a6735..f5d8ce1ff 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -83,6 +83,8 @@ class EngineCoreRequest(
     # Used in outputs and to support abort(req_id, internal=False).
     external_req_id: str | None = None
 
+    reasoning_ended: bool | None = None
+
     @property
     def params(self) -> SamplingParams | PoolingParams:
         """Return the processed params (sampling or pooling)."""
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index e9d3df442..8e3684d3c 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -74,6 +74,7 @@ class Request:
         trace_headers: Mapping[str, str] | None = None,
         block_hasher: Callable[["Request"], list["BlockHash"]] | None = None,
         resumable: bool = False,
+        reasoning_ended: bool | None = None,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
@@ -86,6 +87,8 @@ class Request:
         self.structured_output_request = StructuredOutputRequest.from_sampling_params(
             sampling_params
         )
+        if self.structured_output_request is not None:
+            self.structured_output_request.reasoning_ended = reasoning_ended
         self.arrival_time = arrival_time if arrival_time is not None else time.time()
 
         self.status = RequestStatus.WAITING
@@ -195,6 +198,7 @@ class Request:
             trace_headers=request.trace_headers,
             block_hasher=block_hasher,
             resumable=request.resumable,
+            reasoning_ended=request.reasoning_ended,
         )
 
     def append_output_token_ids(
diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py
index 9b86d69a7..921bee6a6 100644
--- a/vllm/v1/structured_output/__init__.py
+++ b/vllm/v1/structured_output/__init__.py
@@ -284,12 +284,15 @@ class StructuredOutputManager:
         # NOTE (Hanchen) if enable_in_reasoning is True, it means that
         # the model needs to be constrained in reasoning. So we should always
         # enable the bitmask filling.
-        if self.reasoner is not None:
             if self.enable_in_reasoning:
                 return True
 
             assert request.structured_output_request is not None
             if request.structured_output_request.reasoning_ended is None:
+                # This should be removed here, but since `openai_gptoss`
+                # is an independent code path, it is kept for now.
+                # After unifying the `openai_gptoss` and non-`openai_gptoss` styles,
+                # it can be removed.
                 request.structured_output_request.reasoning_ended = (
                     self.reasoner.is_reasoning_end(request.prompt_token_ids or [])
                 )
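For reviewers skimming the diff, the overall shape of the change is: `create_chat_completion` now constructs the reasoning parser once, uses it to decide from the prompt token ids whether reasoning has already ended, records that verdict on the engine request as `reasoning_ended`, and passes the parser instance into both the streaming and full generators instead of re-creating it in each path. Building the parser once also gives both paths a single error-handling site for parser construction. The sketch below is a minimal, hypothetical illustration of that construct-once-and-pass-down flow; `FakeReasoningParser`, `FakeEngineRequest`, and `prepare_request` are made-up stand-ins for this example, not vLLM APIs.

```python
from dataclasses import dataclass, field


@dataclass
class FakeReasoningParser:
    """Stand-in for a reasoning parser (illustrative only)."""

    end_token_id: int = 2

    def is_reasoning_end(self, token_ids: list[int]) -> bool:
        # Reasoning is considered finished once the end marker has been seen.
        return self.end_token_id in token_ids


@dataclass
class FakeEngineRequest:
    """Stand-in for an engine request carrying the new reasoning_ended flag."""

    prompt_token_ids: list[int] = field(default_factory=list)
    reasoning_ended: bool | None = None


def prepare_request(
    prompt_token_ids: list[int], parser: FakeReasoningParser | None
) -> FakeEngineRequest:
    """Mirror the refactor: the caller builds the parser once, computes
    reasoning_ended from the prompt, and stamps it on the request before
    generation starts."""
    request = FakeEngineRequest(prompt_token_ids=prompt_token_ids)
    if parser is not None:
        request.reasoning_ended = parser.is_reasoning_end(prompt_token_ids)
    return request


if __name__ == "__main__":
    parser = FakeReasoningParser()
    print(prepare_request([1, 2, 3], parser).reasoning_ended)  # True: end marker in prompt
    print(prepare_request([1, 3, 4], parser).reasoning_ended)  # False: still reasoning
    print(prepare_request([1, 3, 4], None).reasoning_ended)    # None: no parser configured
```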