We need to forward the content the old way (via delta_text)

This commit is contained in:
2026-04-14 06:18:32 +00:00
parent d0c9c5c482
commit fcf8fd134e

View File

@@ -16,12 +16,14 @@ what was previously sent.
Key changes vs. the upstream token-count parser:
1. No token-count state machine — the parser is stateless w.r.t.
how many tokens arrived per step.
2. _extract_content() uses partial_tag_overlap to safely handle
section-start tags split across chunk boundaries.
2. Content forwarding uses delta_text (not re-parsed current_text)
so reasoning text is never re-emitted as content.
3. _extract_tool_call_regions() finds both complete and incomplete
tool-call blocks, enabling argument streaming.
4. _compute_args_diff() emits only newly-added characters.
5. Handles singular/plural section marker variants.
6. Returns empty deltas inside open sections to keep the stream
alive while tool call tokens are still arriving.
Drop-in replacement: same class name, same interface.
@@ -133,7 +135,6 @@ class KimiK2ToolParser(ToolParser):
)
# ----- Streaming state (reset per request) -----
self._sent_content_idx: int = 0
self._tool_call_ids: list[str] = []
self.streamed_args_for_tool: list[str] = []
self.prev_tool_call_arr: list[dict[str, Any]] = []
@@ -288,82 +289,11 @@ class KimiK2ToolParser(ToolParser):
# ------------------------------------------------------------------
def _reset_streaming_state(self) -> None:
    """Put every piece of per-request streaming state back to its start value.

    The content cursor returns to 0, the three per-tool-call lists are
    emptied, and ``current_tool_id`` is set to -1.
    """
    self._sent_content_idx = 0
    # ``clear()`` empties each list in place; the list objects themselves
    # are kept, so any other reference to them sees the emptied list.
    for per_call_state in (
        self._tool_call_ids,
        self.streamed_args_for_tool,
        self.prev_tool_call_arr,
    ):
        per_call_state.clear()
    self.current_tool_id = -1
def _get_earliest_section_start_tag(self) -> str:
    """Return the longest section-start tag variant.

    Despite the method name, the choice is made by length, not by
    position in any text: the result is the longest entry of
    ``self.tool_calls_section_start_variants``. When several entries
    share the maximum length, the first of them is returned.

    The longest variant is the one that can produce the largest partial
    overlap with the end of a chunk, which is why it is the variant
    intended for overlap checks.

    Raises:
        ValueError: If ``tool_calls_section_start_variants`` is empty.
    """
    return max(self.tool_calls_section_start_variants, key=len)
def _extract_content(self, current_text: str) -> str | None:
    """Return any non-tool-section text that hasn't been sent yet.

    Walks *current_text* from ``self._sent_content_idx``, collecting the
    text that lies outside tool-call sections (a section runs from any
    of ``self.tool_calls_section_start_variants`` to the next of
    ``self.tool_calls_section_end_variants``). ``self._sent_content_idx``
    is advanced so the same text is not returned twice.

    Characters at the very end of the text that could be the beginning
    of a section-start tag split across chunks are held back. The
    hold-back is computed against *every* start variant: a partial
    prefix of a shorter variant that diverges from the longest one
    would otherwise be emitted as content, and the completed tag would
    then be missed on the next call because the cursor had already
    moved past its first characters.

    Args:
        current_text: The full text accumulated so far.

    Returns:
        The newly available content, or ``None`` when there is none.
    """

    def earliest_match(variants, start: int) -> tuple[int, int]:
        """Return ``(index, tag_length)`` of the first variant found at or
        after *start*, or ``(-1, 0)`` when no variant occurs."""
        found_at, found_len = -1, 0
        for variant in variants:
            idx = current_text.find(variant, start)
            # Strict ``<`` keeps the first-listed variant on a tie.
            if idx != -1 and (found_at == -1 or idx < found_at):
                found_at, found_len = idx, len(variant)
        return found_at, found_len

    content_segments: list[str] = []
    pos = self._sent_content_idx
    text_len = len(current_text)

    while pos < text_len:
        section_start, start_tag_len = earliest_match(
            self.tool_calls_section_start_variants, pos
        )

        if section_start == -1:
            # No further section: the tail is content, except for a
            # trailing run that may be the start of a split tag.
            # NOTE(review): assumes partial_tag_overlap(text, tag) returns
            # the length of the longest proper prefix of ``tag`` that is a
            # suffix of ``text`` (0 when none) -- confirm against its
            # definition.
            tail = current_text[pos:]
            overlap = max(
                (
                    partial_tag_overlap(tail, variant)
                    for variant in self.tool_calls_section_start_variants
                ),
                default=0,
            )
            sendable = tail[: len(tail) - overlap]
            if sendable:
                content_segments.append(sendable)
            pos = text_len - overlap
            break

        # Text in front of the section marker is ordinary content.
        if section_start > pos:
            content_segments.append(current_text[pos:section_start])

        section_end, end_tag_len = earliest_match(
            self.tool_calls_section_end_variants,
            section_start + start_tag_len,
        )
        if section_end == -1:
            # Section still open: park the cursor on its start marker so
            # the next call re-examines it once more text has arrived.
            pos = section_start
            break

        # Closed section: resume scanning right after the end marker.
        pos = section_end + end_tag_len

    if content_segments:
        self._sent_content_idx = pos
        return "".join(content_segments)

    # Nothing to emit, but still remember any section that was skipped.
    if pos > self._sent_content_idx:
        self._sent_content_idx = pos
    return None
def _extract_tool_call_regions(
self, text: str
) -> list[tuple[str, bool]]:
@@ -499,16 +429,15 @@ class KimiK2ToolParser(ToolParser):
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""Extract tool calls from streaming output using re-parse-and-diff.
"""Extract tool calls from streaming output.
On every call we:
1. Re-scan *current_text* for content outside tool-call sections.
2. Find all ``<|tool_call_begin|>`` regions (complete + partial).
3. Parse each region for tool ID and arguments.
4. Diff arguments against previous state, emit deltas.
Because the entire text is re-parsed each time, the result is
correct regardless of how many tokens arrived in this step.
Hybrid approach:
- **Content forwarding** uses ``delta_text`` (same as the
original parser) so we never re-emit text that the reasoning
parser already handled.
- **Tool call detection** re-parses ``current_text`` on every
call (the re-parse-and-diff approach) so it's agnostic to
how many tokens arrived per step — robust against MTP.
"""
logger.debug("delta_text: %s", delta_text)
logger.debug("delta_token_ids: %s", delta_token_ids)
@@ -521,10 +450,15 @@ class KimiK2ToolParser(ToolParser):
if not self._tools_enabled(request):
return DeltaMessage(content=delta_text) if delta_text else None
# 1. Extract any content outside tool-call sections.
content = self._extract_content(current_text)
# ── Determine section state from full text (MTP-safe) ──
inner_start, inner_end = self._find_section_start_end(current_text)
in_open_section = inner_start != -1 and inner_end == -1
# 2. Find all tool-call regions.
# Was the section already open in previous_text?
prev_inner_start, _ = self._find_section_start_end(previous_text)
section_existed_before = prev_inner_start != -1
# ── Re-parse tool calls from current_text (MTP-safe) ──
regions = self._extract_tool_call_regions(current_text)
tool_call_deltas: list[DeltaToolCall] = []
@@ -567,30 +501,54 @@ class KimiK2ToolParser(ToolParser):
if regions:
self.current_tool_id = len(regions) - 1
# 3. Return a delta if we have content or tool-call updates.
if content or tool_call_deltas:
kwargs: dict[str, Any] = {}
if content:
kwargs["content"] = content
if tool_call_deltas:
kwargs["tool_calls"] = tool_call_deltas
return DeltaMessage(**kwargs)
# ── Emit results ──
# If we're inside an OPEN tool section (section start found but
# not yet closed), return an empty delta to keep the stream alive
# while waiting for tool call tokens to arrive. Returning None
# here would cause the serving layer to think there's nothing to
# send, potentially terminating the stream before the tool call
# is complete.
inner_start, inner_end = self._find_section_start_end(current_text)
if inner_start != -1 and inner_end == -1:
# Section is open — still generating tool calls
# Case 1: We have tool call updates — emit them.
if tool_call_deltas:
return DeltaMessage(tool_calls=tool_call_deltas)
# Case 2: No tool section has started yet — forward delta_text
# as content. The reasoning parser handles the reasoning/content
# split; we just pass through whatever delta the serving layer
# gave us.
if inner_start == -1:
return DeltaMessage(content=delta_text) if delta_text else None
# Case 3: The section just appeared in this delta. Extract any
# content that came before the section marker in this delta
# (e.g. "Let me check.<|tool_calls_section_begin|>").
if not section_existed_before:
section_start_in_text = self._find_section_start(current_text)
pre_section = current_text[len(previous_text):section_start_in_text]
if pre_section.strip():
return DeltaMessage(content=pre_section)
return DeltaMessage(content="")
# No tool section, no content — forward delta as content.
# This handles the case where the tool parser is called during
# normal content generation (before any tool calls).
if delta_text:
return DeltaMessage(content=delta_text)
# Case 4: Inside an open tool section but tool calls aren't
# parseable yet — emit empty delta to keep the stream alive.
if in_open_section:
return DeltaMessage(content="")
# Case 5: Section is closed and we're past it — forward any
# new content that appeared after the section end marker.
if inner_end != -1:
for variant in self.tool_calls_section_end_variants:
end_marker_pos = current_text.find(variant, inner_start)
if end_marker_pos != -1:
after_section = current_text[
end_marker_pos + len(variant):
]
# Only emit what's new (not previously seen)
prev_after_len = 0
prev_end_pos = previous_text.find(variant)
if prev_end_pos != -1:
prev_after_len = len(
previous_text[prev_end_pos + len(variant):]
)
new_after = after_section[prev_after_len:]
if new_after:
return DeltaMessage(content=new_after)
break
return DeltaMessage(content="") if delta_text else None
return None