we need to forward the context using the old way
This commit is contained in:
@@ -16,12 +16,14 @@ what was previously sent.
|
|||||||
Key changes vs. the upstream token-count parser:
|
Key changes vs. the upstream token-count parser:
|
||||||
1. No token-count state machine — the parser is stateless w.r.t.
|
1. No token-count state machine — the parser is stateless w.r.t.
|
||||||
how many tokens arrived per step.
|
how many tokens arrived per step.
|
||||||
2. _extract_content() uses partial_tag_overlap to safely handle
|
2. Content forwarding uses delta_text (not re-parsed current_text)
|
||||||
section-start tags split across chunk boundaries.
|
so reasoning text is never re-emitted as content.
|
||||||
3. _extract_tool_call_regions() finds both complete and incomplete
|
3. _extract_tool_call_regions() finds both complete and incomplete
|
||||||
tool-call blocks, enabling argument streaming.
|
tool-call blocks, enabling argument streaming.
|
||||||
4. _compute_args_diff() emits only newly-added characters.
|
4. _compute_args_diff() emits only newly-added characters.
|
||||||
5. Handles singular/plural section marker variants.
|
5. Handles singular/plural section marker variants.
|
||||||
|
6. Returns empty deltas inside open sections to keep the stream
|
||||||
|
alive while tool call tokens are still arriving.
|
||||||
|
|
||||||
Drop-in replacement: same class name, same interface.
|
Drop-in replacement: same class name, same interface.
|
||||||
|
|
||||||
@@ -133,7 +135,6 @@ class KimiK2ToolParser(ToolParser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# ----- Streaming state (reset per request) -----
|
# ----- Streaming state (reset per request) -----
|
||||||
self._sent_content_idx: int = 0
|
|
||||||
self._tool_call_ids: list[str] = []
|
self._tool_call_ids: list[str] = []
|
||||||
self.streamed_args_for_tool: list[str] = []
|
self.streamed_args_for_tool: list[str] = []
|
||||||
self.prev_tool_call_arr: list[dict[str, Any]] = []
|
self.prev_tool_call_arr: list[dict[str, Any]] = []
|
||||||
@@ -288,82 +289,11 @@ class KimiK2ToolParser(ToolParser):
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
def _reset_streaming_state(self) -> None:
|
def _reset_streaming_state(self) -> None:
|
||||||
self._sent_content_idx = 0
|
|
||||||
self._tool_call_ids.clear()
|
self._tool_call_ids.clear()
|
||||||
self.streamed_args_for_tool.clear()
|
self.streamed_args_for_tool.clear()
|
||||||
self.prev_tool_call_arr.clear()
|
self.prev_tool_call_arr.clear()
|
||||||
self.current_tool_id = -1
|
self.current_tool_id = -1
|
||||||
|
|
||||||
def _get_earliest_section_start_tag(self) -> str:
|
|
||||||
"""Return the shortest section-start variant (used for overlap
|
|
||||||
checking — the shortest tag has the smallest possible overlap,
|
|
||||||
so we check against the longest to be safe)."""
|
|
||||||
return max(
|
|
||||||
self.tool_calls_section_start_variants, key=len
|
|
||||||
)
|
|
||||||
|
|
||||||
def _extract_content(self, current_text: str) -> str | None:
|
|
||||||
"""Return any non-tool-section text that hasn't been sent yet.
|
|
||||||
|
|
||||||
Walks *current_text* from ``_sent_content_idx``, collecting
|
|
||||||
text outside ``<|tool_calls_section_begin|>`` …
|
|
||||||
``<|tool_calls_section_end|>`` regions. Uses
|
|
||||||
``partial_tag_overlap`` to avoid emitting bytes that might be
|
|
||||||
the start of a section tag split across chunks.
|
|
||||||
"""
|
|
||||||
content_segments: list[str] = []
|
|
||||||
pos = self._sent_content_idx
|
|
||||||
overlap_tag = self._get_earliest_section_start_tag()
|
|
||||||
|
|
||||||
while pos < len(current_text):
|
|
||||||
# Find next section-start marker from pos.
|
|
||||||
best_start = -1
|
|
||||||
best_variant_len = 0
|
|
||||||
for variant in self.tool_calls_section_start_variants:
|
|
||||||
idx = current_text.find(variant, pos)
|
|
||||||
if idx != -1 and (best_start == -1 or idx < best_start):
|
|
||||||
best_start = idx
|
|
||||||
best_variant_len = len(variant)
|
|
||||||
|
|
||||||
if best_start == -1:
|
|
||||||
# No more section regions — emit tail minus overlap.
|
|
||||||
tail = current_text[pos:]
|
|
||||||
overlap = partial_tag_overlap(tail, overlap_tag)
|
|
||||||
sendable = tail[: len(tail) - overlap] if overlap else tail
|
|
||||||
if sendable:
|
|
||||||
content_segments.append(sendable)
|
|
||||||
pos = len(current_text) - overlap
|
|
||||||
break
|
|
||||||
|
|
||||||
# Text before the section start is content.
|
|
||||||
if best_start > pos:
|
|
||||||
content_segments.append(current_text[pos:best_start])
|
|
||||||
|
|
||||||
# Skip past the section region.
|
|
||||||
inner_start = best_start + best_variant_len
|
|
||||||
# Find matching section-end.
|
|
||||||
best_end = -1
|
|
||||||
best_end_variant_len = 0
|
|
||||||
for variant in self.tool_calls_section_end_variants:
|
|
||||||
idx = current_text.find(variant, inner_start)
|
|
||||||
if idx != -1 and (best_end == -1 or idx < best_end):
|
|
||||||
best_end = idx
|
|
||||||
best_end_variant_len = len(variant)
|
|
||||||
|
|
||||||
if best_end != -1:
|
|
||||||
pos = best_end + best_end_variant_len
|
|
||||||
else:
|
|
||||||
# Section still open — park cursor, stop.
|
|
||||||
pos = best_start
|
|
||||||
break
|
|
||||||
|
|
||||||
if content_segments:
|
|
||||||
self._sent_content_idx = pos
|
|
||||||
return "".join(content_segments)
|
|
||||||
if pos > self._sent_content_idx:
|
|
||||||
self._sent_content_idx = pos
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _extract_tool_call_regions(
|
def _extract_tool_call_regions(
|
||||||
self, text: str
|
self, text: str
|
||||||
) -> list[tuple[str, bool]]:
|
) -> list[tuple[str, bool]]:
|
||||||
@@ -499,16 +429,15 @@ class KimiK2ToolParser(ToolParser):
|
|||||||
delta_token_ids: Sequence[int],
|
delta_token_ids: Sequence[int],
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> DeltaMessage | None:
|
) -> DeltaMessage | None:
|
||||||
"""Extract tool calls from streaming output using re-parse-and-diff.
|
"""Extract tool calls from streaming output.
|
||||||
|
|
||||||
On every call we:
|
Hybrid approach:
|
||||||
1. Re-scan *current_text* for content outside tool-call sections.
|
- **Content forwarding** uses ``delta_text`` (same as the
|
||||||
2. Find all ``<|tool_call_begin|>`` regions (complete + partial).
|
original parser) so we never re-emit text that the reasoning
|
||||||
3. Parse each region for tool ID and arguments.
|
parser already handled.
|
||||||
4. Diff arguments against previous state, emit deltas.
|
- **Tool call detection** re-parses ``current_text`` on every
|
||||||
|
call (the re-parse-and-diff approach) so it's agnostic to
|
||||||
Because the entire text is re-parsed each time, the result is
|
how many tokens arrived per step — robust against MTP.
|
||||||
correct regardless of how many tokens arrived in this step.
|
|
||||||
"""
|
"""
|
||||||
logger.debug("delta_text: %s", delta_text)
|
logger.debug("delta_text: %s", delta_text)
|
||||||
logger.debug("delta_token_ids: %s", delta_token_ids)
|
logger.debug("delta_token_ids: %s", delta_token_ids)
|
||||||
@@ -521,10 +450,15 @@ class KimiK2ToolParser(ToolParser):
|
|||||||
if not self._tools_enabled(request):
|
if not self._tools_enabled(request):
|
||||||
return DeltaMessage(content=delta_text) if delta_text else None
|
return DeltaMessage(content=delta_text) if delta_text else None
|
||||||
|
|
||||||
# 1. Extract any content outside tool-call sections.
|
# ── Determine section state from full text (MTP-safe) ──
|
||||||
content = self._extract_content(current_text)
|
inner_start, inner_end = self._find_section_start_end(current_text)
|
||||||
|
in_open_section = inner_start != -1 and inner_end == -1
|
||||||
|
|
||||||
# 2. Find all tool-call regions.
|
# Was the section already open in previous_text?
|
||||||
|
prev_inner_start, _ = self._find_section_start_end(previous_text)
|
||||||
|
section_existed_before = prev_inner_start != -1
|
||||||
|
|
||||||
|
# ── Re-parse tool calls from current_text (MTP-safe) ──
|
||||||
regions = self._extract_tool_call_regions(current_text)
|
regions = self._extract_tool_call_regions(current_text)
|
||||||
tool_call_deltas: list[DeltaToolCall] = []
|
tool_call_deltas: list[DeltaToolCall] = []
|
||||||
|
|
||||||
@@ -567,30 +501,54 @@ class KimiK2ToolParser(ToolParser):
|
|||||||
if regions:
|
if regions:
|
||||||
self.current_tool_id = len(regions) - 1
|
self.current_tool_id = len(regions) - 1
|
||||||
|
|
||||||
# 3. Return a delta if we have content or tool-call updates.
|
# ── Emit results ──
|
||||||
if content or tool_call_deltas:
|
|
||||||
kwargs: dict[str, Any] = {}
|
|
||||||
if content:
|
|
||||||
kwargs["content"] = content
|
|
||||||
if tool_call_deltas:
|
|
||||||
kwargs["tool_calls"] = tool_call_deltas
|
|
||||||
return DeltaMessage(**kwargs)
|
|
||||||
|
|
||||||
# If we're inside an OPEN tool section (section start found but
|
# Case 1: We have tool call updates — emit them.
|
||||||
# not yet closed), return an empty delta to keep the stream alive
|
if tool_call_deltas:
|
||||||
# while waiting for tool call tokens to arrive. Returning None
|
return DeltaMessage(tool_calls=tool_call_deltas)
|
||||||
# here would cause the serving layer to think there's nothing to
|
|
||||||
# send, potentially terminating the stream before the tool call
|
# Case 2: No tool section has started yet — forward delta_text
|
||||||
# is complete.
|
# as content. The reasoning parser handles the reasoning/content
|
||||||
inner_start, inner_end = self._find_section_start_end(current_text)
|
# split; we just pass through whatever delta the serving layer
|
||||||
if inner_start != -1 and inner_end == -1:
|
# gave us.
|
||||||
# Section is open — still generating tool calls
|
if inner_start == -1:
|
||||||
|
return DeltaMessage(content=delta_text) if delta_text else None
|
||||||
|
|
||||||
|
# Case 3: The section just appeared in this delta. Extract any
|
||||||
|
# content that came before the section marker in this delta
|
||||||
|
# (e.g. "Let me check.<|tool_calls_section_begin|>").
|
||||||
|
if not section_existed_before:
|
||||||
|
section_start_in_text = self._find_section_start(current_text)
|
||||||
|
pre_section = current_text[len(previous_text):section_start_in_text]
|
||||||
|
if pre_section.strip():
|
||||||
|
return DeltaMessage(content=pre_section)
|
||||||
return DeltaMessage(content="")
|
return DeltaMessage(content="")
|
||||||
|
|
||||||
# No tool section, no content — forward delta as content.
|
# Case 4: Inside an open tool section but tool calls aren't
|
||||||
# This handles the case where the tool parser is called during
|
# parseable yet — emit empty delta to keep the stream alive.
|
||||||
# normal content generation (before any tool calls).
|
if in_open_section:
|
||||||
if delta_text:
|
return DeltaMessage(content="")
|
||||||
return DeltaMessage(content=delta_text)
|
|
||||||
|
# Case 5: Section is closed and we're past it — forward any
|
||||||
|
# new content that appeared after the section end marker.
|
||||||
|
if inner_end != -1:
|
||||||
|
for variant in self.tool_calls_section_end_variants:
|
||||||
|
end_marker_pos = current_text.find(variant, inner_start)
|
||||||
|
if end_marker_pos != -1:
|
||||||
|
after_section = current_text[
|
||||||
|
end_marker_pos + len(variant):
|
||||||
|
]
|
||||||
|
# Only emit what's new (not previously seen)
|
||||||
|
prev_after_len = 0
|
||||||
|
prev_end_pos = previous_text.find(variant)
|
||||||
|
if prev_end_pos != -1:
|
||||||
|
prev_after_len = len(
|
||||||
|
previous_text[prev_end_pos + len(variant):]
|
||||||
|
)
|
||||||
|
new_after = after_section[prev_after_len:]
|
||||||
|
if new_after:
|
||||||
|
return DeltaMessage(content=new_after)
|
||||||
|
break
|
||||||
|
return DeltaMessage(content="") if delta_text else None
|
||||||
|
|
||||||
return None
|
return None
|
||||||
Reference in New Issue
Block a user