We need to forward the content using the old way
This commit is contained in:
@@ -16,12 +16,14 @@ what was previously sent.
|
||||
Key changes vs. the upstream token-count parser:
|
||||
1. No token-count state machine — the parser is stateless w.r.t.
|
||||
how many tokens arrived per step.
|
||||
2. _extract_content() uses partial_tag_overlap to safely handle
|
||||
section-start tags split across chunk boundaries.
|
||||
2. Content forwarding uses delta_text (not re-parsed current_text)
|
||||
so reasoning text is never re-emitted as content.
|
||||
3. _extract_tool_call_regions() finds both complete and incomplete
|
||||
tool-call blocks, enabling argument streaming.
|
||||
4. _compute_args_diff() emits only newly-added characters.
|
||||
5. Handles singular/plural section marker variants.
|
||||
6. Returns empty deltas inside open sections to keep the stream
|
||||
alive while tool call tokens are still arriving.
|
||||
|
||||
Drop-in replacement: same class name, same interface.
|
||||
|
||||
@@ -133,7 +135,6 @@ class KimiK2ToolParser(ToolParser):
|
||||
)
|
||||
|
||||
# ----- Streaming state (reset per request) -----
|
||||
self._sent_content_idx: int = 0
|
||||
self._tool_call_ids: list[str] = []
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
self.prev_tool_call_arr: list[dict[str, Any]] = []
|
||||
@@ -288,82 +289,11 @@ class KimiK2ToolParser(ToolParser):
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _reset_streaming_state(self) -> None:
|
||||
self._sent_content_idx = 0
|
||||
self._tool_call_ids.clear()
|
||||
self.streamed_args_for_tool.clear()
|
||||
self.prev_tool_call_arr.clear()
|
||||
self.current_tool_id = -1
|
||||
|
||||
def _get_earliest_section_start_tag(self) -> str:
|
||||
"""Return the shortest section-start variant (used for overlap
|
||||
checking — the shortest tag has the smallest possible overlap,
|
||||
so we check against the longest to be safe)."""
|
||||
return max(
|
||||
self.tool_calls_section_start_variants, key=len
|
||||
)
|
||||
|
||||
def _extract_content(self, current_text: str) -> str | None:
    """Return any non-tool-section text that hasn't been sent yet.

    Walks *current_text* from ``_sent_content_idx``, collecting text
    that lies outside section-start … section-end regions, and moves
    ``_sent_content_idx`` forward past everything that was consumed.

    Trailing characters that could be the beginning of a section-start
    tag split across chunks are held back. The overlap is measured
    against *every* start variant and the largest value wins: checking
    a single variant is not enough, because a partial tag of one
    variant is not necessarily a prefix of another.

    Args:
        current_text: The full text accumulated so far.

    Returns:
        The newly available content, or ``None`` when there is none.
    """
    content_segments: list[str] = []
    pos = self._sent_content_idx
    start_variants = self.tool_calls_section_start_variants

    while pos < len(current_text):
        # Find the nearest section-start marker at or after pos.
        best_start = -1
        best_variant_len = 0
        for variant in start_variants:
            idx = current_text.find(variant, pos)
            if idx != -1 and (best_start == -1 or idx < best_start):
                best_start = idx
                best_variant_len = len(variant)

        if best_start == -1:
            # No more section regions: emit the tail, minus any suffix
            # that may still grow into a start tag of any variant.
            # NOTE(review): assumes partial_tag_overlap(text, tag)
            # returns how many trailing characters of text match a
            # prefix of tag; confirm against the helper.
            tail = current_text[pos:]
            overlap = max(
                (
                    partial_tag_overlap(tail, variant)
                    for variant in start_variants
                ),
                default=0,
            )
            sendable = tail[: len(tail) - overlap] if overlap else tail
            if sendable:
                content_segments.append(sendable)
            # Park the cursor before the held-back suffix so it is
            # re-examined on the next call.
            pos = len(current_text) - overlap
            break

        # Text before the section start is content.
        if best_start > pos:
            content_segments.append(current_text[pos:best_start])

        # Skip past the section region: find the nearest section-end.
        inner_start = best_start + best_variant_len
        best_end = -1
        best_end_variant_len = 0
        for variant in self.tool_calls_section_end_variants:
            idx = current_text.find(variant, inner_start)
            if idx != -1 and (best_end == -1 or idx < best_end):
                best_end = idx
                best_end_variant_len = len(variant)

        if best_end != -1:
            pos = best_end + best_end_variant_len
        else:
            # Section still open: park the cursor on its start tag and
            # stop; nothing beyond this point is content yet.
            pos = best_start
            break

    if content_segments:
        self._sent_content_idx = pos
        return "".join(content_segments)
    # Nothing to emit, but remember any closed section we skipped.
    if pos > self._sent_content_idx:
        self._sent_content_idx = pos
    return None
|
||||
|
||||
def _extract_tool_call_regions(
|
||||
self, text: str
|
||||
) -> list[tuple[str, bool]]:
|
||||
@@ -499,16 +429,15 @@ class KimiK2ToolParser(ToolParser):
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
"""Extract tool calls from streaming output using re-parse-and-diff.
|
||||
"""Extract tool calls from streaming output.
|
||||
|
||||
On every call we:
|
||||
1. Re-scan *current_text* for content outside tool-call sections.
|
||||
2. Find all ``<|tool_call_begin|>`` regions (complete + partial).
|
||||
3. Parse each region for tool ID and arguments.
|
||||
4. Diff arguments against previous state, emit deltas.
|
||||
|
||||
Because the entire text is re-parsed each time, the result is
|
||||
correct regardless of how many tokens arrived in this step.
|
||||
Hybrid approach:
|
||||
- **Content forwarding** uses ``delta_text`` (same as the
|
||||
original parser) so we never re-emit text that the reasoning
|
||||
parser already handled.
|
||||
- **Tool call detection** re-parses ``current_text`` on every
|
||||
call (the re-parse-and-diff approach) so it's agnostic to
|
||||
how many tokens arrived per step — robust against MTP.
|
||||
"""
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||
@@ -521,10 +450,15 @@ class KimiK2ToolParser(ToolParser):
|
||||
if not self._tools_enabled(request):
|
||||
return DeltaMessage(content=delta_text) if delta_text else None
|
||||
|
||||
# 1. Extract any content outside tool-call sections.
|
||||
content = self._extract_content(current_text)
|
||||
# ── Determine section state from full text (MTP-safe) ──
|
||||
inner_start, inner_end = self._find_section_start_end(current_text)
|
||||
in_open_section = inner_start != -1 and inner_end == -1
|
||||
|
||||
# 2. Find all tool-call regions.
|
||||
# Was the section already open in previous_text?
|
||||
prev_inner_start, _ = self._find_section_start_end(previous_text)
|
||||
section_existed_before = prev_inner_start != -1
|
||||
|
||||
# ── Re-parse tool calls from current_text (MTP-safe) ──
|
||||
regions = self._extract_tool_call_regions(current_text)
|
||||
tool_call_deltas: list[DeltaToolCall] = []
|
||||
|
||||
@@ -567,30 +501,54 @@ class KimiK2ToolParser(ToolParser):
|
||||
if regions:
|
||||
self.current_tool_id = len(regions) - 1
|
||||
|
||||
# 3. Return a delta if we have content or tool-call updates.
|
||||
if content or tool_call_deltas:
|
||||
kwargs: dict[str, Any] = {}
|
||||
if content:
|
||||
kwargs["content"] = content
|
||||
if tool_call_deltas:
|
||||
kwargs["tool_calls"] = tool_call_deltas
|
||||
return DeltaMessage(**kwargs)
|
||||
# ── Emit results ──
|
||||
|
||||
# If we're inside an OPEN tool section (section start found but
|
||||
# not yet closed), return an empty delta to keep the stream alive
|
||||
# while waiting for tool call tokens to arrive. Returning None
|
||||
# here would cause the serving layer to think there's nothing to
|
||||
# send, potentially terminating the stream before the tool call
|
||||
# is complete.
|
||||
inner_start, inner_end = self._find_section_start_end(current_text)
|
||||
if inner_start != -1 and inner_end == -1:
|
||||
# Section is open — still generating tool calls
|
||||
# Case 1: We have tool call updates — emit them.
|
||||
if tool_call_deltas:
|
||||
return DeltaMessage(tool_calls=tool_call_deltas)
|
||||
|
||||
# Case 2: No tool section has started yet — forward delta_text
|
||||
# as content. The reasoning parser handles the reasoning/content
|
||||
# split; we just pass through whatever delta the serving layer
|
||||
# gave us.
|
||||
if inner_start == -1:
|
||||
return DeltaMessage(content=delta_text) if delta_text else None
|
||||
|
||||
# Case 3: The section just appeared in this delta. Extract any
|
||||
# content that came before the section marker in this delta
|
||||
# (e.g. "Let me check.<|tool_calls_section_begin|>").
|
||||
if not section_existed_before:
|
||||
section_start_in_text = self._find_section_start(current_text)
|
||||
pre_section = current_text[len(previous_text):section_start_in_text]
|
||||
if pre_section.strip():
|
||||
return DeltaMessage(content=pre_section)
|
||||
return DeltaMessage(content="")
|
||||
|
||||
# No tool section, no content — forward delta as content.
|
||||
# This handles the case where the tool parser is called during
|
||||
# normal content generation (before any tool calls).
|
||||
if delta_text:
|
||||
return DeltaMessage(content=delta_text)
|
||||
# Case 4: Inside an open tool section but tool calls aren't
|
||||
# parseable yet — emit empty delta to keep the stream alive.
|
||||
if in_open_section:
|
||||
return DeltaMessage(content="")
|
||||
|
||||
# Case 5: Section is closed and we're past it — forward any
|
||||
# new content that appeared after the section end marker.
|
||||
if inner_end != -1:
|
||||
for variant in self.tool_calls_section_end_variants:
|
||||
end_marker_pos = current_text.find(variant, inner_start)
|
||||
if end_marker_pos != -1:
|
||||
after_section = current_text[
|
||||
end_marker_pos + len(variant):
|
||||
]
|
||||
# Only emit what's new (not previously seen)
|
||||
prev_after_len = 0
|
||||
prev_end_pos = previous_text.find(variant)
|
||||
if prev_end_pos != -1:
|
||||
prev_after_len = len(
|
||||
previous_text[prev_end_pos + len(variant):]
|
||||
)
|
||||
new_after = after_section[prev_after_len:]
|
||||
if new_after:
|
||||
return DeltaMessage(content=new_after)
|
||||
break
|
||||
return DeltaMessage(content="") if delta_text else None
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user