diff --git a/kimi_k2_tool_parser.py b/kimi_k2_tool_parser.py index fd77030..3128bdb 100644 --- a/kimi_k2_tool_parser.py +++ b/kimi_k2_tool_parser.py @@ -16,12 +16,14 @@ what was previously sent. Key changes vs. the upstream token-count parser: 1. No token-count state machine — the parser is stateless w.r.t. how many tokens arrived per step. - 2. _extract_content() uses partial_tag_overlap to safely handle - section-start tags split across chunk boundaries. + 2. Content forwarding uses delta_text (not re-parsed current_text) + so reasoning text is never re-emitted as content. 3. _extract_tool_call_regions() finds both complete and incomplete tool-call blocks, enabling argument streaming. 4. _compute_args_diff() emits only newly-added characters. 5. Handles singular/plural section marker variants. + 6. Returns empty deltas inside open sections to keep the stream + alive while tool call tokens are still arriving. Drop-in replacement: same class name, same interface. @@ -133,7 +135,6 @@ class KimiK2ToolParser(ToolParser): ) # ----- Streaming state (reset per request) ----- - self._sent_content_idx: int = 0 self._tool_call_ids: list[str] = [] self.streamed_args_for_tool: list[str] = [] self.prev_tool_call_arr: list[dict[str, Any]] = [] @@ -288,82 +289,11 @@ class KimiK2ToolParser(ToolParser): # ------------------------------------------------------------------ def _reset_streaming_state(self) -> None: - self._sent_content_idx = 0 self._tool_call_ids.clear() self.streamed_args_for_tool.clear() self.prev_tool_call_arr.clear() self.current_tool_id = -1 - def _get_earliest_section_start_tag(self) -> str: - """Return the shortest section-start variant (used for overlap - checking — the shortest tag has the smallest possible overlap, - so we check against the longest to be safe).""" - return max( - self.tool_calls_section_start_variants, key=len - ) - - def _extract_content(self, current_text: str) -> str | None: - """Return any non-tool-section text that hasn't been sent yet. - - Walks *current_text* from ``_sent_content_idx``, collecting - text outside ``<|tool_calls_section_begin|>`` … - ``<|tool_calls_section_end|>`` regions. Uses - ``partial_tag_overlap`` to avoid emitting bytes that might be - the start of a section tag split across chunks. - """ - content_segments: list[str] = [] - pos = self._sent_content_idx - overlap_tag = self._get_earliest_section_start_tag() - - while pos < len(current_text): - # Find next section-start marker from pos. - best_start = -1 - best_variant_len = 0 - for variant in self.tool_calls_section_start_variants: - idx = current_text.find(variant, pos) - if idx != -1 and (best_start == -1 or idx < best_start): - best_start = idx - best_variant_len = len(variant) - - if best_start == -1: - # No more section regions — emit tail minus overlap. - tail = current_text[pos:] - overlap = partial_tag_overlap(tail, overlap_tag) - sendable = tail[: len(tail) - overlap] if overlap else tail - if sendable: - content_segments.append(sendable) - pos = len(current_text) - overlap - break - - # Text before the section start is content. - if best_start > pos: - content_segments.append(current_text[pos:best_start]) - - # Skip past the section region. - inner_start = best_start + best_variant_len - # Find matching section-end. - best_end = -1 - best_end_variant_len = 0 - for variant in self.tool_calls_section_end_variants: - idx = current_text.find(variant, inner_start) - if idx != -1 and (best_end == -1 or idx < best_end): - best_end = idx - best_end_variant_len = len(variant) - - if best_end != -1: - pos = best_end + best_end_variant_len - else: - # Section still open — park cursor, stop. - pos = best_start - break - - if content_segments: - self._sent_content_idx = pos - return "".join(content_segments) - if pos > self._sent_content_idx: - self._sent_content_idx = pos - return None - def _extract_tool_call_regions( self, text: str ) -> list[tuple[str, bool]]: @@ -499,16 +429,15 @@ class KimiK2ToolParser(ToolParser): delta_token_ids: Sequence[int], request: ChatCompletionRequest, ) -> DeltaMessage | None: - """Extract tool calls from streaming output using re-parse-and-diff. + """Extract tool calls from streaming output. - On every call we: - 1. Re-scan *current_text* for content outside tool-call sections. - 2. Find all ``<|tool_call_begin|>`` regions (complete + partial). - 3. Parse each region for tool ID and arguments. - 4. Diff arguments against previous state, emit deltas. - - Because the entire text is re-parsed each time, the result is - correct regardless of how many tokens arrived in this step. + Hybrid approach: + - **Content forwarding** uses ``delta_text`` (same as the + original parser) so we never re-emit text that the reasoning + parser already handled. + - **Tool call detection** re-parses ``current_text`` on every + call (the re-parse-and-diff approach) so it's agnostic to + how many tokens arrived per step — robust against MTP. """ logger.debug("delta_text: %s", delta_text) logger.debug("delta_token_ids: %s", delta_token_ids) @@ -521,10 +450,15 @@ class KimiK2ToolParser(ToolParser): if not self._tools_enabled(request): return DeltaMessage(content=delta_text) if delta_text else None - # 1. Extract any content outside tool-call sections. - content = self._extract_content(current_text) + # ── Determine section state from full text (MTP-safe) ── + inner_start, inner_end = self._find_section_start_end(current_text) + in_open_section = inner_start != -1 and inner_end == -1 - # 2. Find all tool-call regions. + # Was the section already open in previous_text? + prev_inner_start, _ = self._find_section_start_end(previous_text) + section_existed_before = prev_inner_start != -1 + + # ── Re-parse tool calls from current_text (MTP-safe) ── regions = self._extract_tool_call_regions(current_text) tool_call_deltas: list[DeltaToolCall] = [] @@ -567,30 +501,54 @@ class KimiK2ToolParser(ToolParser): if regions: self.current_tool_id = len(regions) - 1 - # 3. Return a delta if we have content or tool-call updates. - if content or tool_call_deltas: - kwargs: dict[str, Any] = {} - if content: - kwargs["content"] = content - if tool_call_deltas: - kwargs["tool_calls"] = tool_call_deltas - return DeltaMessage(**kwargs) + # ── Emit results ── - # If we're inside an OPEN tool section (section start found but - # not yet closed), return an empty delta to keep the stream alive - # while waiting for tool call tokens to arrive. Returning None - # here would cause the serving layer to think there's nothing to - # send, potentially terminating the stream before the tool call - # is complete. - inner_start, inner_end = self._find_section_start_end(current_text) - if inner_start != -1 and inner_end == -1: - # Section is open — still generating tool calls + # Case 1: We have tool call updates — emit them. + if tool_call_deltas: + return DeltaMessage(tool_calls=tool_call_deltas) + + # Case 2: No tool section has started yet — forward delta_text + # as content. The reasoning parser handles the reasoning/content + # split; we just pass through whatever delta the serving layer + # gave us. + if inner_start == -1: + return DeltaMessage(content=delta_text) if delta_text else None + + # Case 3: The section just appeared in this delta. Extract any + # content that came before the section marker in this delta + # (e.g. "Let me check.<|tool_calls_section_begin|>"). + if not section_existed_before: + section_start_in_text = self._find_section_start(current_text) + pre_section = current_text[len(previous_text):section_start_in_text] + if pre_section.strip(): + return DeltaMessage(content=pre_section) return DeltaMessage(content="") - # No tool section, no content — forward delta as content. - # This handles the case where the tool parser is called during - # normal content generation (before any tool calls). - if delta_text: - return DeltaMessage(content=delta_text) + # Case 4: Inside an open tool section but tool calls aren't + # parseable yet — emit empty delta to keep the stream alive. + if in_open_section: + return DeltaMessage(content="") + + # Case 5: Section is closed and we're past it — forward any + # new content that appeared after the section end marker. + if inner_end != -1: + for variant in self.tool_calls_section_end_variants: + end_marker_pos = current_text.find(variant, inner_start) + if end_marker_pos != -1: + after_section = current_text[ + end_marker_pos + len(variant): + ] + # Only emit what's new (not previously seen) + prev_after_len = 0 + prev_end_pos = previous_text.find(variant) + if prev_end_pos != -1: + prev_after_len = len( + previous_text[prev_end_pos + len(variant):] + ) + new_after = after_section[prev_after_len:] + if new_after: + return DeltaMessage(content=new_after) + break + return DeltaMessage(content="") if delta_text else None return None \ No newline at end of file