fix: update kimi k2 tool parser logic (#31207)
Signed-off-by: wangln19 <wanglinian@dev.wanglinian.msh-dev.svc.cluster.local> Signed-off-by: Wang Linian <wanglinian@stu.pku.edu.cn> Co-authored-by: wangln19 <wanglinian@dev.wanglinian.msh-dev.svc.cluster.local> Co-authored-by: Chauncey <chaunceyjiang@gmail.com>
This commit is contained in:
@@ -44,6 +44,33 @@ def assert_tool_calls(
|
||||
)
|
||||
|
||||
|
||||
def run_streaming_sequence(parser, deltas):
    """Helper to simulate a streaming sequence and return results.

    Each delta is fed to ``parser.extract_tool_calls_streaming`` together
    with the cumulative text and token ids seen before and after that delta.

    Args:
        parser: Object exposing ``extract_tool_calls_streaming`` that accepts
            the keyword arguments passed below.
        deltas: Iterable of ``(delta_text, delta_token_ids)`` pairs in
            streaming order; ``delta_token_ids`` is concatenated onto a list
            with ``+``, so it is expected to be a list.

    Returns:
        A list with the parser's return value for every delta, in order.
    """
    previous_text = ""
    previous_token_ids: list[int] = []
    results = []

    for delta_text, delta_token_ids in deltas:
        # Build new cumulative snapshots rather than mutating in place, so
        # the "previous" values handed to the parser stay untouched.
        current_text = previous_text + delta_text
        current_token_ids = previous_token_ids + delta_token_ids

        results.append(
            parser.extract_tool_calls_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
                previous_token_ids=previous_token_ids,
                current_token_ids=current_token_ids,
                delta_token_ids=delta_token_ids,
                request=None,
            )
        )

        # This step's "current" state is the next step's "previous" state.
        previous_text, previous_token_ids = current_text, current_token_ids

    return results
|
||||
|
||||
|
||||
def test_extract_tool_calls_no_tools(kimi_k2_tool_parser):
|
||||
model_output = "This is a test"
|
||||
extracted_tool_calls = kimi_k2_tool_parser.extract_tool_calls(
|
||||
@@ -346,61 +373,32 @@ def test_token_leak_between_section_and_tool_begin(kimi_k2_tool_parser):
|
||||
tool_call_begin_token_id = kimi_k2_tool_parser.vocab.get("<|tool_call_begin|>")
|
||||
|
||||
# Simulate streaming sequence:
|
||||
deltas = [
|
||||
("I'll help you with that. ", [1, 2, 3]),
|
||||
("<|tool_calls_section_begin|>", [section_begin_token_id]),
|
||||
(" spurious text ", [4, 5]),
|
||||
("<|tool_call_begin|>", [tool_call_begin_token_id]),
|
||||
]
|
||||
|
||||
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
|
||||
|
||||
# Delta 1: "I'll help you with that. "
|
||||
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="I'll help you with that. ",
|
||||
delta_text="I'll help you with that. ",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[1, 2, 3], # Regular tokens
|
||||
delta_token_ids=[1, 2, 3],
|
||||
request=None,
|
||||
)
|
||||
assert result1 is not None
|
||||
assert result1.content == "I'll help you with that. "
|
||||
assert results[0] is not None
|
||||
assert results[0].content == "I'll help you with that. "
|
||||
|
||||
# Delta 2: "<|tool_calls_section_begin|>"
|
||||
prev_ids = [1, 2, 3]
|
||||
curr_ids = prev_ids + [section_begin_token_id]
|
||||
result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="I'll help you with that. ",
|
||||
current_text="I'll help you with that. <|tool_calls_section_begin|>",
|
||||
delta_text="<|tool_calls_section_begin|>",
|
||||
previous_token_ids=prev_ids,
|
||||
current_token_ids=curr_ids,
|
||||
delta_token_ids=[section_begin_token_id],
|
||||
request=None,
|
||||
)
|
||||
# Section marker should be stripped and suppressed
|
||||
assert result2 is None or (result2.content is None or result2.content == "")
|
||||
assert results[1] is None or (
|
||||
results[1].content is None or results[1].content == ""
|
||||
)
|
||||
|
||||
# Delta 3: " spurious text or tokens " (THE LEAK SCENARIO)
|
||||
prev_ids = curr_ids
|
||||
curr_ids = curr_ids + [4, 5]
|
||||
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="I'll help you with that. <|tool_calls_section_begin|>",
|
||||
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
|
||||
delta_text=" spurious text ",
|
||||
previous_token_ids=prev_ids,
|
||||
current_token_ids=curr_ids,
|
||||
delta_token_ids=[4, 5],
|
||||
request=None,
|
||||
)
|
||||
# CRITICAL: This text should be suppressed, NOT returned as reasoning_delta
|
||||
assert result3 is None or (result3.content is None or result3.content == "")
|
||||
assert results[2] is None or (
|
||||
results[2].content is None or results[2].content == ""
|
||||
)
|
||||
|
||||
# Delta 4: "<|tool_call_begin|>..."
|
||||
prev_ids = curr_ids
|
||||
curr_ids = curr_ids + [tool_call_begin_token_id]
|
||||
_result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="I'll help you with that. <|tool_calls_section_begin|> spurious text ",
|
||||
current_text="I'll help you with that. <|tool_calls_section_begin|> spurious text <|tool_call_begin|>",
|
||||
delta_text="<|tool_call_begin|>",
|
||||
previous_token_ids=prev_ids,
|
||||
current_token_ids=curr_ids,
|
||||
delta_token_ids=[tool_call_begin_token_id],
|
||||
request=None,
|
||||
)
|
||||
# Now we're in tool call mode, result depends on internal state
|
||||
# The key is that the spurious text from Delta 3 was not leaked
|
||||
|
||||
@@ -416,31 +414,15 @@ def test_split_markers_across_deltas(kimi_k2_tool_parser):
|
||||
"<|tool_calls_section_begin|>"
|
||||
)
|
||||
|
||||
# Delta 1: "...reasoning<|tool_calls_sec"
|
||||
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Some reasoning",
|
||||
current_text="Some reasoning<|tool_calls_sec",
|
||||
delta_text="<|tool_calls_sec",
|
||||
previous_token_ids=[1, 2],
|
||||
current_token_ids=[1, 2, 3], # Partial token
|
||||
delta_token_ids=[3],
|
||||
request=None,
|
||||
)
|
||||
# Partial token not recognized yet, might be buffered
|
||||
# Should return as content or None (depends on implementation)
|
||||
# Delta 1: partial token, Delta 2: complete marker
|
||||
deltas = [
|
||||
("<|tool_calls_sec", [3]),
|
||||
("tion_begin|> ", [section_begin_token_id, 4]),
|
||||
]
|
||||
|
||||
_results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
|
||||
|
||||
# Delta 2: "tion_begin|> " (completes the marker)
|
||||
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Some reasoning<|tool_calls_sec",
|
||||
current_text="Some reasoning<|tool_calls_section_begin|> ",
|
||||
delta_text="tion_begin|> ",
|
||||
previous_token_ids=[1, 2, 3],
|
||||
current_token_ids=[1, 2, section_begin_token_id, 4],
|
||||
delta_token_ids=[section_begin_token_id, 4],
|
||||
request=None,
|
||||
)
|
||||
# Now the complete marker should be detected via buffer
|
||||
# The parser should enter tool section mode
|
||||
assert kimi_k2_tool_parser.in_tool_section is True
|
||||
|
||||
|
||||
@@ -475,42 +457,17 @@ def test_reentry_to_reasoning_after_tool_section(kimi_k2_tool_parser):
|
||||
section_begin_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_begin|>")
|
||||
section_end_id = kimi_k2_tool_parser.vocab.get("<|tool_calls_section_end|>")
|
||||
|
||||
# Enter tool section
|
||||
_result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="<|tool_calls_section_begin|>",
|
||||
delta_text="<|tool_calls_section_begin|>",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[section_begin_id],
|
||||
delta_token_ids=[section_begin_id],
|
||||
request=None,
|
||||
)
|
||||
assert kimi_k2_tool_parser.in_tool_section is True
|
||||
deltas = [
|
||||
("<|tool_calls_section_begin|>", [section_begin_id]),
|
||||
("<|tool_calls_section_end|>", [section_end_id]),
|
||||
(" More reasoning", [10, 11]),
|
||||
]
|
||||
|
||||
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
|
||||
|
||||
# Exit tool section
|
||||
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="<|tool_calls_section_begin|>",
|
||||
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
|
||||
delta_text="<|tool_calls_section_end|>",
|
||||
previous_token_ids=[section_begin_id],
|
||||
current_token_ids=[section_begin_id, section_end_id],
|
||||
delta_token_ids=[section_end_id],
|
||||
request=None,
|
||||
)
|
||||
assert kimi_k2_tool_parser.in_tool_section is False
|
||||
|
||||
# Subsequent reasoning text should be returned normally
|
||||
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="<|tool_calls_section_begin|><|tool_calls_section_end|>",
|
||||
current_text="<|tool_calls_section_begin|><|tool_calls_section_end|> More reasoning",
|
||||
delta_text=" More reasoning",
|
||||
previous_token_ids=[section_begin_id, section_end_id],
|
||||
current_token_ids=[section_begin_id, section_end_id, 10, 11],
|
||||
delta_token_ids=[10, 11],
|
||||
request=None,
|
||||
)
|
||||
assert result3 is not None
|
||||
assert result3.content == " More reasoning"
|
||||
assert results[2] is not None
|
||||
assert results[2].content == " More reasoning"
|
||||
|
||||
|
||||
def test_empty_tool_section(kimi_k2_tool_parser):
|
||||
@@ -819,106 +776,150 @@ def test_tool_call_end_and_section_end_same_chunk(kimi_k2_tool_parser):
|
||||
tool_end_id = kimi_k2_tool_parser.vocab.get("<|tool_call_end|>")
|
||||
|
||||
# Simulate a streaming sequence for a SHORT tool call (all in one chunk):
|
||||
# 1. Reasoning text
|
||||
result1 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="",
|
||||
current_text="Let me help. ",
|
||||
delta_text="Let me help. ",
|
||||
previous_token_ids=[],
|
||||
current_token_ids=[1, 2],
|
||||
delta_token_ids=[1, 2],
|
||||
request=None,
|
||||
)
|
||||
assert result1 is not None
|
||||
assert result1.content == "Let me help. "
|
||||
|
||||
# 2. Section begin
|
||||
_result2 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text="Let me help. ",
|
||||
current_text="Let me help. <|tool_calls_section_begin|>",
|
||||
delta_text="<|tool_calls_section_begin|>",
|
||||
previous_token_ids=[1, 2],
|
||||
current_token_ids=[1, 2, section_begin_id],
|
||||
delta_token_ids=[section_begin_id],
|
||||
request=None,
|
||||
)
|
||||
assert kimi_k2_tool_parser.in_tool_section is True
|
||||
|
||||
# 3. Tool call begin + full content + tool_end + section_end ALL IN ONE CHUNK
|
||||
# This is the critical scenario for short tool calls
|
||||
combined = (
|
||||
'<|tool_call_begin|>get_weather:0 <|tool_call_argument_begin|> {"city": "Paris"} '
|
||||
"<|tool_call_end|><|tool_calls_section_end|>"
|
||||
)
|
||||
|
||||
# Build up the previous text gradually to simulate realistic streaming
|
||||
prev_text = "Let me help. <|tool_calls_section_begin|>"
|
||||
curr_text = prev_text + combined
|
||||
deltas = [
|
||||
("Let me help. ", [1, 2]),
|
||||
("<|tool_calls_section_begin|>", [section_begin_id]),
|
||||
(combined, [tool_begin_id, 10, 11, 12, tool_end_id, section_end_id]),
|
||||
(" Done", [20]),
|
||||
]
|
||||
|
||||
result3 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=prev_text,
|
||||
current_text=curr_text,
|
||||
delta_text=combined,
|
||||
previous_token_ids=[1, 2, section_begin_id],
|
||||
current_token_ids=[
|
||||
1,
|
||||
2,
|
||||
section_begin_id,
|
||||
tool_begin_id,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
tool_end_id,
|
||||
section_end_id,
|
||||
],
|
||||
delta_token_ids=[tool_begin_id, 10, 11, 12, tool_end_id, section_end_id],
|
||||
request=None,
|
||||
)
|
||||
results = run_streaming_sequence(kimi_k2_tool_parser, deltas)
|
||||
|
||||
# CRITICAL: Parser should have exited section AFTER processing tool
|
||||
assert kimi_k2_tool_parser.in_tool_section is False
|
||||
|
||||
# Tool call should have been emitted (not dropped)
|
||||
# The result might be the tool name or None depending on state, but
|
||||
# importantly, it shouldn't be returning the literal tokens as content
|
||||
|
||||
if result3 is not None and result3.content is not None:
|
||||
if results[2] is not None and results[2].content is not None:
|
||||
# Verify no special tokens leaked into content
|
||||
assert "<|tool_call_end|>" not in result3.content
|
||||
assert "<|tool_calls_section_end|>" not in result3.content
|
||||
|
||||
# 4. Verify subsequent content streams normally
|
||||
result4 = kimi_k2_tool_parser.extract_tool_calls_streaming(
|
||||
previous_text=curr_text,
|
||||
current_text=curr_text + " Done",
|
||||
delta_text=" Done",
|
||||
previous_token_ids=[
|
||||
1,
|
||||
2,
|
||||
section_begin_id,
|
||||
tool_begin_id,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
tool_end_id,
|
||||
section_end_id,
|
||||
],
|
||||
current_token_ids=[
|
||||
1,
|
||||
2,
|
||||
section_begin_id,
|
||||
tool_begin_id,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
tool_end_id,
|
||||
section_end_id,
|
||||
20,
|
||||
],
|
||||
delta_token_ids=[20],
|
||||
request=None,
|
||||
)
|
||||
assert "<|tool_call_end|>" not in results[2].content
|
||||
assert "<|tool_calls_section_end|>" not in results[2].content
|
||||
|
||||
# Content after tool section should stream normally
|
||||
assert result4 is not None
|
||||
assert result4.content == " Done"
|
||||
assert results[3] is not None
|
||||
assert results[3].content == " Done"
|
||||
|
||||
|
||||
def test_streaming_tool_call_markers_not_leaked(kimi_k2_tool_parser):
    """
    CRITICAL TEST: Verify that tool call markers (<|tool_call_begin|>,
    <|tool_call_end|>, <|tool_call_argument_begin|>) are NOT leaked
    into the content field during streaming.

    This reproduces the AWS Bedrock bug where tool call markers appeared
    in the 'text' field of responses.

    The stream is: plain text, section begin, one complete tool call,
    section end, trailing plain text. All non-empty ``content`` values
    returned by the parser are joined and inspected.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    vocab = kimi_k2_tool_parser.vocab
    section_begin_id = vocab.get("<|tool_calls_section_begin|>")
    section_end_id = vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = vocab.get("<|tool_call_begin|>")
    tool_end_id = vocab.get("<|tool_call_end|>")

    # List of markers that should NEVER appear in content
    forbidden_markers = [
        "<|tool_call_begin|>",
        "<|tool_call_end|>",
        "<|tool_call_argument_begin|>",
        "<|tool_calls_section_begin|>",
        "<|tool_calls_section_end|>",
    ]

    # Steps: reasoning, section begin, tool call, section end, more reasoning
    tool_chunk = (
        "<|tool_call_begin|> functions.get_weather:0 "
        '<|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
    )
    deltas = [
        ("I'll check the weather. ", [1, 2, 3]),
        ("<|tool_calls_section_begin|>", [section_begin_id]),
        (tool_chunk, [tool_begin_id, 10, 11, tool_end_id]),
        ("<|tool_calls_section_end|>", [section_end_id]),
        (" Here's the result.", [20, 21]),
    ]

    results = run_streaming_sequence(kimi_k2_tool_parser, deltas)

    # Keep only steps that produced a non-empty content string.
    all_content = [res.content for res in results if res and res.content]

    # CRITICAL ASSERTIONS: No forbidden markers in any content
    full_content = "".join(all_content)
    for marker in forbidden_markers:
        assert marker not in full_content, (
            f"MARKER LEAK DETECTED: '{marker}' found in content. "
            f"Full content: {repr(full_content)}"
        )

    # Also check that tool call content (function name, arguments) is not leaked
    assert "get_weather" not in full_content, (
        f"TOOL CALL CONTENT LEAKED: 'get_weather' found in content. "
        f"Full content: {repr(full_content)}"
    )
    assert "Tokyo" not in full_content, (
        f"TOOL CALL CONTENT LEAKED: 'Tokyo' found in content. "
        f"Full content: {repr(full_content)}"
    )

    # Verify that legitimate content was preserved
    # NOTE(review): the `or len(all_content) > 0` clause lets this pass as
    # soon as any content was emitted, even if the leading text is missing --
    # confirm whether the stricter substring check alone is intended.
    assert "I'll check the weather." in full_content or len(all_content) > 0
|
||||
|
||||
|
||||
def test_streaming_multiple_tool_calls_not_leaked(kimi_k2_tool_parser):
    """
    Test that MULTIPLE tool calls in streaming mode do not leak into content.
    This reproduces the AWS Bedrock scenario: "Compare weather in Tokyo and NYC".

    The stream is: plain text, section begin, two complete tool calls (one
    per delta), section end, trailing plain text. All non-empty ``content``
    values returned by the parser are joined and inspected.
    """
    kimi_k2_tool_parser.reset_streaming_state()

    vocab = kimi_k2_tool_parser.vocab
    section_begin_id = vocab.get("<|tool_calls_section_begin|>")
    section_end_id = vocab.get("<|tool_calls_section_end|>")
    tool_begin_id = vocab.get("<|tool_call_begin|>")
    tool_end_id = vocab.get("<|tool_call_end|>")

    tool1 = '<|tool_call_begin|> get_weather:0 <|tool_call_argument_begin|> {"city": "Tokyo"} <|tool_call_end|>'
    tool2 = ' <|tool_call_begin|> get_weather:1 <|tool_call_argument_begin|> {"city": "New York"} <|tool_call_end|>'

    deltas = [
        ("I'll compare the weather. ", [1, 2, 3]),
        ("<|tool_calls_section_begin|>", [section_begin_id]),
        (tool1, [tool_begin_id, 10, tool_end_id]),
        (tool2, [tool_begin_id, 20, tool_end_id]),
        ("<|tool_calls_section_end|>", [section_end_id]),
        (" Here's the comparison.", [30]),
    ]

    results = run_streaming_sequence(kimi_k2_tool_parser, deltas)

    # Keep only steps that produced a non-empty content string.
    all_content = [res.content for res in results if res and res.content]

    # Assertions
    full_content = "".join(all_content)

    # Check no markers leaked; the prefixes cover every begin/end variant.
    forbidden = ["<|tool_call", "<|tool_calls_section"]
    for marker in forbidden:
        assert marker not in full_content, (
            f"MARKER LEAKED: {marker} in {repr(full_content)}"
        )

    # Check no tool call content leaked (both tools)
    assert "get_weather" not in full_content, f"TOOL NAME LEAKED: {repr(full_content)}"
    assert "Tokyo" not in full_content, f"TOOL ARG LEAKED (Tokyo): {repr(full_content)}"
    assert "New York" not in full_content, (
        f"TOOL ARG LEAKED (NYC): {repr(full_content)}"
    )

    # Legitimate content preserved
    # NOTE(review): the `or len(all_content) > 0` clause lets this pass as
    # soon as any content was emitted -- confirm whether the substring check
    # alone is intended.
    assert "compare" in full_content.lower() or len(all_content) > 0
|
||||
|
||||
Reference in New Issue
Block a user