diff --git a/tests/reasoning/test_nemotron_v3_reasoning_parser.py b/tests/reasoning/test_nemotron_v3_reasoning_parser.py
new file mode 100644
index 000000000..3fe383a08
--- /dev/null
+++ b/tests/reasoning/test_nemotron_v3_reasoning_parser.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import TypedDict
+
+import pytest
+import regex as re
+
+from tests.reasoning.utils import run_reasoning_extraction
+from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+from vllm.reasoning import ReasoningParser, ReasoningParserManager
+
+parser_name = "nemotron_v3"
+
+
+class ReasoningCase(TypedDict):
+    output: str
+    reasoning: str | None
+    content: str | None
+
+
+class FakeNemotronTokenizer:
+    def __init__(self):
+        # NOTE(review): the <think>/</think> delimiter tokens were stripped from
+        # the extracted text (vocab mapped empty strings, regex was r"(|)");
+        # restored here — confirm token spellings against the real tokenizer.
+        self._vocab = {
+            "<think>": 1,
+            "</think>": 2,
+        }
+        self._pattern = re.compile(r"(<think>|</think>)")
+
+    def get_vocab(self) -> dict[str, int]:
+        return self._vocab
+
+    def tokenize(self, text: str) -> list[str]:
+        tokens: list[str] = []
+        for part in self._pattern.split(text):
+            if part:
+                tokens.append(part)
+        return tokens
+
+    def convert_tokens_to_string(self, tokens: list[str]) -> str:
+        return "".join(tokens)
+
+
+@pytest.fixture
+def tokenizer():
+    return FakeNemotronTokenizer()
+
+
+@pytest.mark.parametrize(
+    "streaming,param_dict",
+    [
+        pytest.param(
+            False,
+            {
+                "output": "This is a reasoning section</think>This is the rest",
+                "reasoning": "This is a reasoning section",
+                "content": "This is the rest",
+            },
+            id="without_start_token",
+        ),
+        pytest.param(
+            True,
+            {
+                "output": "This is a reasoning section</think>This is the rest",
+                "reasoning": "This is a reasoning section",
+                "content": "This is the rest",
+            },
+            id="without_start_token_streaming",
+        ),
+        pytest.param(
+            False,
+            {
+                "output": "<think>This is a reasoning section</think>This is the rest",
+                "reasoning": "This is a reasoning section",
+                "content": "This is the rest",
+            },
+            id="with_start_token",
+        ),
+        pytest.param(
+            True,
+            {
+                "output": "<think>This is a reasoning section</think>This is the rest",
+                "reasoning": "This is a reasoning section",
+                "content": "This is the rest",
+            },
+            id="with_start_token_streaming",
+        ),
+    ],
+)
+def test_nemotron_v3_reasoning(
+    tokenizer: FakeNemotronTokenizer,
+    streaming: bool,
+    param_dict: ReasoningCase,
+):
+    output = tokenizer.tokenize(param_dict["output"])
+    model_output = [tokenizer.convert_tokens_to_string([token]) for token in output]
+    parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
+        tokenizer
+    )
+
+    reasoning, content = run_reasoning_extraction(
+        parser, model_output, streaming=streaming
+    )
+
+    assert reasoning == param_dict["reasoning"]
+    assert content == param_dict["content"]
+
+
+def test_nemotron_v3_without_thinking_returns_content(
+    tokenizer: FakeNemotronTokenizer,
+):
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(tokenizer)
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[],
+        chat_template_kwargs={"enable_thinking": False},
+    )
+
+    reasoning, content = run_reasoning_extraction(
+        parser,
+        ["This is plain content"],
+        request=request,
+        streaming=False,
+    )
+
+    assert reasoning is None
+    assert content == "This is plain content"
+
+
+def test_nemotron_v3_with_thinking_keeps_truncated_reasoning(
+    tokenizer: FakeNemotronTokenizer,
+):
+    parser_cls = ReasoningParserManager.get_reasoning_parser(parser_name)
+    parser = parser_cls(tokenizer)
+    request = ChatCompletionRequest(
+        model="test-model",
+        messages=[],
+        chat_template_kwargs={"enable_thinking": True},
+    )
+
+    reasoning, content = run_reasoning_extraction(
+        parser,
+        ["<think>This is truncated reasoning"],
+        request=request,
+        streaming=False,
+    )
+
+    assert reasoning == "This is truncated reasoning"
+    assert content is None
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index df75e8584..8c78db6f1 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -68,6 +68,10 @@ _REASONING_PARSERS_TO_REGISTER = {
         "mistral_reasoning_parser",
         "MistralReasoningParser",
     ),
+    "nemotron_v3": (
+        "nemotron_v3_reasoning_parser",
+        "NemotronV3ReasoningParser",
+    ),
     "olmo3": (
         "olmo3_reasoning_parser",
         "Olmo3ReasoningParser",
diff --git a/vllm/reasoning/nemotron_v3_reasoning_parser.py b/vllm/reasoning/nemotron_v3_reasoning_parser.py
new file mode 100644
index 000000000..a929793bf
--- /dev/null
+++ b/vllm/reasoning/nemotron_v3_reasoning_parser.py
@@ -0,0 +1,32 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.entrypoints.openai.chat_completion.protocol import (
+    ChatCompletionRequest,
+)
+from vllm.entrypoints.openai.responses.protocol import (
+    ResponsesRequest,
+)
+from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+
+
+class NemotronV3ReasoningParser(DeepSeekR1ReasoningParser):
+    """
+    Reasoning parser for Nemotron V3 models.
+    """
+
+    def extract_reasoning(
+        self, model_output: str, request: ChatCompletionRequest | ResponsesRequest
+    ) -> tuple[str | None, str | None]:
+        reasoning_content, final_content = super().extract_reasoning(
+            model_output, request
+        )
+        chat_template_kwargs = getattr(request, "chat_template_kwargs", None)
+
+        # When thinking is disabled via chat_template_kwargs, the base parser
+        # still classifies delimiter-free output as reasoning; reinterpret it
+        # as final content instead.
+        if (
+            chat_template_kwargs
+            and chat_template_kwargs.get("enable_thinking") is False
+            and final_content is None
+        ):
+            reasoning_content, final_content = final_content, reasoning_content
+
+        return reasoning_content, final_content