diff --git a/tests/reasoning/test_kimi_k2_reasoning_parser.py b/tests/reasoning/test_kimi_k2_reasoning_parser.py
new file mode 100644
index 000000000..0f80bb885
--- /dev/null
+++ b/tests/reasoning/test_kimi_k2_reasoning_parser.py
@@ -0,0 +1,155 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+from vllm.reasoning.identity_reasoning_parser import IdentityReasoningParser
+from vllm.reasoning.kimi_k2_reasoning_parser import KimiK2ReasoningParser
+from vllm.tokenizers import get_tokenizer
+
+REASONING_MODEL_NAME = "moonshotai/Kimi-K2.5"
+
+
+@pytest.fixture(scope="module")
+def kimi_k2_tokenizer():
+    return get_tokenizer(tokenizer_name=REASONING_MODEL_NAME, trust_remote_code=True)
+
+
+def test_parser_selection_thinking_enabled(kimi_k2_tokenizer):
+    parser = KimiK2ReasoningParser(
+        kimi_k2_tokenizer, chat_template_kwargs={"thinking": True}
+    )
+    assert parser._identity_parser is None
+
+
+def test_parser_selection_thinking_disabled(kimi_k2_tokenizer):
+    parser = KimiK2ReasoningParser(
+        kimi_k2_tokenizer, chat_template_kwargs={"thinking": False}
+    )
+    assert isinstance(parser._identity_parser, IdentityReasoningParser)
+
+
+def test_extract_reasoning_with_think_tags(kimi_k2_tokenizer):
+    parser = KimiK2ReasoningParser(kimi_k2_tokenizer)
+    request = ChatCompletionRequest(model="test-model", messages=[], temperature=1.0)
+
+    reasoning, content = parser.extract_reasoning(
+        "<think>step by step reasoning</think>final answer", request
+    )
+    assert reasoning == "step by step reasoning"
+    assert content == "final answer"
+
+
+def test_extract_reasoning_empty_thinking(kimi_k2_tokenizer):
+    parser = KimiK2ReasoningParser(kimi_k2_tokenizer)
+    request = ChatCompletionRequest(model="test-model", messages=[], temperature=1.0)
+
+    reasoning, content = parser.extract_reasoning(
+        "<think></think>final answer", request
+    )
+    assert reasoning == ""
+    assert content == "final answer"
+
+
+def test_extract_reasoning_implicit_start(kimi_k2_tokenizer):
+    """When there's no <think> tag, everything is treated as reasoning."""
+    parser = KimiK2ReasoningParser(kimi_k2_tokenizer)
+    request = ChatCompletionRequest(model="test-model", messages=[], temperature=1.0)
+
+    reasoning, content = parser.extract_reasoning(
+        "implicit reasoning with no tags", request
+    )
+    assert reasoning == "implicit reasoning with no tags"
+    assert content is None
+
+
+def test_extract_reasoning_tool_section_ends_reasoning(kimi_k2_tokenizer):
+    """<|tool_calls_section_begin|> implicitly ends reasoning."""
+    parser = KimiK2ReasoningParser(kimi_k2_tokenizer)
+    request = ChatCompletionRequest(model="test-model", messages=[], temperature=1.0)
+
+    text = "some reasoning<|tool_calls_section_begin|>tool call data"
+    reasoning, content = parser.extract_reasoning(text, request)
+    assert reasoning == "some reasoning"
+    assert content == "<|tool_calls_section_begin|>tool call data"
+
+
+def test_streaming_reasoning_then_content(kimi_k2_tokenizer):
+    """Token-by-token streaming: reasoning tokens, then content after </think>."""
+    parser = KimiK2ReasoningParser(kimi_k2_tokenizer)
+
+    think_id = parser._start_token_id
+    end_think_id = parser._end_token_id
+    # Use a real token ID from the tokenizer for regular content
+    regular_id = kimi_k2_tokenizer.encode("hello", add_special_tokens=False)[0]
+
+    # First token: <think> — single special token should be skipped
+    result = parser.extract_reasoning_streaming(
+        previous_text="",
+        current_text="<think>",
+        delta_text="<think>",
+        previous_token_ids=[],
+        current_token_ids=[think_id],
+        delta_token_ids=[think_id],
+    )
+    assert result is None
+
+    # Reasoning token
+    result = parser.extract_reasoning_streaming(
+        previous_text="<think>",
+        current_text="<think>step one",
+        delta_text="step one",
+        previous_token_ids=[think_id],
+        current_token_ids=[think_id, regular_id],
+        delta_token_ids=[regular_id],
+    )
+    assert isinstance(result, DeltaMessage)
+    assert result.reasoning == "step one"
+    assert result.content is None
+
+    # End token as single token — should be skipped
+    result = parser.extract_reasoning_streaming(
+        previous_text="<think>step one",
+        current_text="<think>step one</think>",
+        delta_text="</think>",
+        previous_token_ids=[think_id, regular_id],
+        current_token_ids=[think_id, regular_id, end_think_id],
+        delta_token_ids=[end_think_id],
+    )
+    assert result is None
+
+    # Content after </think>
+    content_id = kimi_k2_tokenizer.encode("world", add_special_tokens=False)[0]
+    result = parser.extract_reasoning_streaming(
+        previous_text="<think>step one</think>",
+        current_text="<think>step one</think>answer",
+        delta_text="answer",
+        previous_token_ids=[think_id, regular_id, end_think_id],
+        current_token_ids=[think_id, regular_id, end_think_id, content_id],
+        delta_token_ids=[content_id],
+    )
+    assert isinstance(result, DeltaMessage)
+    assert result.content == "answer"
+
+
+def test_streaming_tool_section_ends_reasoning(kimi_k2_tokenizer):
+    """<|tool_calls_section_begin|> in delta ends reasoning during streaming."""
+    parser = KimiK2ReasoningParser(kimi_k2_tokenizer)
+
+    think_id = parser._start_token_id
+    tool_begin_id = parser._tool_section_start_token_id
+    regular_id = kimi_k2_tokenizer.encode("hello", add_special_tokens=False)[0]
+
+    # Tool section token arrives — should transition from reasoning to content
+    result = parser.extract_reasoning_streaming(
+        previous_text="<think>thinking",
+        current_text="<think>thinking<|tool_calls_section_begin|>",
+        delta_text="<|tool_calls_section_begin|>",
+        previous_token_ids=[think_id, regular_id],
+        current_token_ids=[think_id, regular_id, tool_begin_id],
+        delta_token_ids=[tool_begin_id],
+    )
+    assert isinstance(result, DeltaMessage)
+    assert result.content == "<|tool_calls_section_begin|>"
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 4839fc80c..6af762991 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -1660,6 +1660,20 @@ def get_history_tool_calls_cnt(conversation: list[ConversationMessage]):
     return idx
 
 
+_KIMI_MODEL_TYPES = ("kimi_k2", "kimi_k25")
+
+
+def get_tool_call_id_type(model_config: ModelConfig) -> str:
+    """Return the tool-call ID type for a given model configuration."""
+    hf_overrides = getattr(model_config, "hf_overrides", None)
+    if model_config.hf_text_config.model_type in _KIMI_MODEL_TYPES or (
+        isinstance(hf_overrides, dict)
+        and hf_overrides.get("model_type") in _KIMI_MODEL_TYPES
+    ):
+        return "kimi_k2"
+    return "random"
+
+
 def make_tool_call_id(id_type: str = "random", func_name=None, idx=None):
     if id_type == "kimi_k2":
         return f"functions.{func_name}:{idx}"
diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
index ad7982b61..62a0192e7 100644
--- a/vllm/entrypoints/openai/chat_completion/serving.py
+++ b/vllm/entrypoints/openai/chat_completion/serving.py
@@ -19,6 +19,7 @@ from vllm.entrypoints.chat_utils import (
     ChatTemplateContentFormatOption,
     ConversationMessage,
     get_history_tool_calls_cnt,
+    get_tool_call_id_type,
     make_tool_call_id,
 )
 from vllm.entrypoints.logger import RequestLogger
@@ -152,15 +153,7 @@ class OpenAIServingChat(OpenAIServing):
             get_stop_tokens_for_assistant_actions()
         )
 
-        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
-        hf_overrides = getattr(self.model_config, "hf_overrides", None)
-        if self.model_config.hf_text_config.model_type == "kimi_k2" or (
-            isinstance(hf_overrides, dict)
-            and hf_overrides.get("model_type") == "kimi_k2"
-        ):
-            self.tool_call_id_type = "kimi_k2"
-        else:
-            self.tool_call_id_type = "random"
+        self.tool_call_id_type = get_tool_call_id_type(self.model_config)
 
         # NOTE(woosuk): While OpenAI's chat completion API supports browsing
         # for some models, currently vLLM doesn't support it. Please use the
diff --git a/vllm/entrypoints/openai/responses/serving.py b/vllm/entrypoints/openai/responses/serving.py
index b2428e97e..574282c4c 100644
--- a/vllm/entrypoints/openai/responses/serving.py
+++ b/vllm/entrypoints/openai/responses/serving.py
@@ -46,6 +46,7 @@ from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (
     ChatCompletionMessageParam,
     ChatTemplateContentFormatOption,
+    get_tool_call_id_type,
 )
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.mcp.tool_server import ToolServer
@@ -241,15 +242,7 @@ class OpenAIServingResponses(OpenAIServing):
             get_stop_tokens_for_assistant_actions()
         )
 
-        # Handle tool call ID type for Kimi K2 (supporting test mocking via overrides)
-        hf_overrides = getattr(self.model_config, "hf_overrides", None)
-        if self.model_config.hf_text_config.model_type == "kimi_k2" or (
-            isinstance(hf_overrides, dict)
-            and hf_overrides.get("model_type") == "kimi_k2"
-        ):
-            self.tool_call_id_type = "kimi_k2"
-        else:
-            self.tool_call_id_type = "random"
+        self.tool_call_id_type = get_tool_call_id_type(self.model_config)
 
         self.enable_auto_tools = enable_auto_tools
         # HACK(woosuk): This is a hack. We should use a better store.