[Tool Parser] Fix Qwen3Coder streaming parameter loss with speculative decode (#35615)
Signed-off-by: Martin Vit <martin@voipmonitor.org>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get the expected call based on partial JSON
|
# get the expected call based on partial JSON
|
||||||
# parsing which "autocompletes" the JSON
|
# parsing which "autocompletes" the JSON.
|
||||||
expected_call = json.dumps(
|
# Tool parsers (e.g. Qwen3Coder) store
|
||||||
tool_parser.prev_tool_call_arr[index].get(
|
# arguments as a JSON string in
|
||||||
"arguments", {}
|
# prev_tool_call_arr. Calling json.dumps()
|
||||||
),
|
# on an already-serialized string would
|
||||||
ensure_ascii=False,
|
# double-serialize it (e.g. '{"k":1}' becomes
|
||||||
|
# '"{\\"k\\":1}"'), which then causes the
|
||||||
|
# replace() below to fail and append the
|
||||||
|
# entire double-serialized string as a
|
||||||
|
# spurious final delta.
|
||||||
|
args = tool_parser.prev_tool_call_arr[index].get(
|
||||||
|
"arguments", {}
|
||||||
)
|
)
|
||||||
|
if isinstance(args, str):
|
||||||
|
expected_call = args
|
||||||
|
else:
|
||||||
|
expected_call = json.dumps(args, ensure_ascii=False)
|
||||||
|
|
||||||
# get what we've streamed so far for arguments
|
# get what we've streamed so far for arguments
|
||||||
# for the current tool
|
# for the current tool
|
||||||
|
|||||||
@@ -479,20 +479,22 @@ class Qwen3CoderToolParser(ToolParser):
|
|||||||
self.header_sent = True
|
self.header_sent = True
|
||||||
self.in_function = True
|
self.in_function = True
|
||||||
|
|
||||||
# IMPORTANT: Add to prev_tool_call_arr immediately when
|
# Always append — each tool call is a separate
|
||||||
# we detect a tool call. This ensures
|
# invocation even if the function name is the same
|
||||||
# finish_reason="tool_calls" even if parsing isn't complete
|
# (e.g. two consecutive "read" calls).
|
||||||
already_added = any(
|
self.prev_tool_call_arr.append(
|
||||||
tool.get("name") == self.current_function_name
|
{
|
||||||
for tool in self.prev_tool_call_arr
|
"name": self.current_function_name,
|
||||||
|
"arguments": "{}",
|
||||||
|
}
|
||||||
)
|
)
|
||||||
if not already_added:
|
|
||||||
self.prev_tool_call_arr.append(
|
# Initialize streamed args tracking for this tool.
|
||||||
{
|
# The serving layer reads streamed_args_for_tool to
|
||||||
"name": self.current_function_name,
|
# compute remaining arguments at stream end. Without
|
||||||
"arguments": "{}", # Placeholder, will be updated later
|
# this, IndexError occurs when the serving layer
|
||||||
}
|
# accesses streamed_args_for_tool[index].
|
||||||
)
|
self.streamed_args_for_tool.append("")
|
||||||
|
|
||||||
# Send header with function info
|
# Send header with function info
|
||||||
return DeltaMessage(
|
return DeltaMessage(
|
||||||
@@ -511,9 +513,14 @@ class Qwen3CoderToolParser(ToolParser):
|
|||||||
|
|
||||||
# We've sent header, now handle function body
|
# We've sent header, now handle function body
|
||||||
if self.in_function:
|
if self.in_function:
|
||||||
# Send opening brace if not sent yet
|
# Always send opening brace first, regardless of whether
|
||||||
if not self.json_started and self.parameter_prefix not in delta_text:
|
# parameter_prefix is in the current delta. With speculative
|
||||||
|
# decoding, a single delta may contain both the opening brace
|
||||||
|
# and parameter data; skipping "{" here would desync
|
||||||
|
# json_started from what was actually streamed.
|
||||||
|
if not self.json_started:
|
||||||
self.json_started = True
|
self.json_started = True
|
||||||
|
self.streamed_args_for_tool[self.current_tool_index] += "{"
|
||||||
return DeltaMessage(
|
return DeltaMessage(
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
DeltaToolCall(
|
DeltaToolCall(
|
||||||
@@ -523,25 +530,133 @@ class Qwen3CoderToolParser(ToolParser):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make sure json_started is set if we're processing parameters
|
# Find all parameter start positions in current tool_text
|
||||||
if not self.json_started:
|
param_starts = []
|
||||||
self.json_started = True
|
search_idx = 0
|
||||||
|
while True:
|
||||||
|
search_idx = tool_text.find(self.parameter_prefix, search_idx)
|
||||||
|
if search_idx == -1:
|
||||||
|
break
|
||||||
|
param_starts.append(search_idx)
|
||||||
|
search_idx += len(self.parameter_prefix)
|
||||||
|
|
||||||
# Check for function end in accumulated text
|
# Process ALL complete params in a loop (spec decode fix).
|
||||||
|
# With speculative decoding a single delta can deliver
|
||||||
|
# multiple complete parameters at once. The old single-pass
|
||||||
|
# code would process one and ``return None`` if the next was
|
||||||
|
# incomplete — skipping any already-complete params that
|
||||||
|
# preceded it. Using a loop with ``break`` instead ensures
|
||||||
|
# we emit every complete parameter before yielding control.
|
||||||
|
json_fragments = []
|
||||||
|
while not self.in_param and self.param_count < len(param_starts):
|
||||||
|
param_idx = param_starts[self.param_count]
|
||||||
|
param_start = param_idx + len(self.parameter_prefix)
|
||||||
|
remaining = tool_text[param_start:]
|
||||||
|
|
||||||
|
if ">" not in remaining:
|
||||||
|
break
|
||||||
|
|
||||||
|
name_end = remaining.find(">")
|
||||||
|
current_param_name = remaining[:name_end]
|
||||||
|
|
||||||
|
value_start = param_start + name_end + 1
|
||||||
|
value_text = tool_text[value_start:]
|
||||||
|
if value_text.startswith("\n"):
|
||||||
|
value_text = value_text[1:]
|
||||||
|
|
||||||
|
param_end_idx = value_text.find(self.parameter_end_token)
|
||||||
|
if param_end_idx == -1:
|
||||||
|
next_param_idx = value_text.find(self.parameter_prefix)
|
||||||
|
func_end_idx = value_text.find(self.function_end_token)
|
||||||
|
|
||||||
|
if next_param_idx != -1 and (
|
||||||
|
func_end_idx == -1 or next_param_idx < func_end_idx
|
||||||
|
):
|
||||||
|
param_end_idx = next_param_idx
|
||||||
|
elif func_end_idx != -1:
|
||||||
|
param_end_idx = func_end_idx
|
||||||
|
else:
|
||||||
|
# Fallback for malformed XML where </function>
|
||||||
|
# is missing. Use </tool_call> as a delimiter
|
||||||
|
# if present in the value so we don't include
|
||||||
|
# the closing tag as part of the param value.
|
||||||
|
tool_end_in_value = value_text.find(self.tool_call_end_token)
|
||||||
|
if tool_end_in_value != -1:
|
||||||
|
param_end_idx = tool_end_in_value
|
||||||
|
else:
|
||||||
|
# Parameter incomplete — break so we still
|
||||||
|
# emit any fragments accumulated by earlier
|
||||||
|
# loop iterations.
|
||||||
|
break
|
||||||
|
|
||||||
|
if param_end_idx == -1:
|
||||||
|
break
|
||||||
|
|
||||||
|
param_value = value_text[:param_end_idx]
|
||||||
|
if param_value.endswith("\n"):
|
||||||
|
param_value = param_value[:-1]
|
||||||
|
|
||||||
|
self.current_param_name = current_param_name
|
||||||
|
self.accumulated_params[current_param_name] = param_value
|
||||||
|
|
||||||
|
param_config = self._get_arguments_config(
|
||||||
|
self.current_function_name or "",
|
||||||
|
self.streaming_request.tools if self.streaming_request else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
converted_value = self._convert_param_value(
|
||||||
|
param_value,
|
||||||
|
current_param_name,
|
||||||
|
param_config,
|
||||||
|
self.current_function_name or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized_value = json.dumps(converted_value, ensure_ascii=False)
|
||||||
|
|
||||||
|
if self.param_count == 0:
|
||||||
|
json_fragment = f'"{current_param_name}": {serialized_value}'
|
||||||
|
else:
|
||||||
|
json_fragment = f', "{current_param_name}": {serialized_value}'
|
||||||
|
|
||||||
|
self.param_count += 1
|
||||||
|
json_fragments.append(json_fragment)
|
||||||
|
|
||||||
|
if json_fragments:
|
||||||
|
combined = "".join(json_fragments)
|
||||||
|
|
||||||
|
if self.current_tool_index < len(self.streamed_args_for_tool):
|
||||||
|
self.streamed_args_for_tool[self.current_tool_index] += combined
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"streamed_args_for_tool out of sync: index=%d len=%d",
|
||||||
|
self.current_tool_index,
|
||||||
|
len(self.streamed_args_for_tool),
|
||||||
|
)
|
||||||
|
|
||||||
|
return DeltaMessage(
|
||||||
|
tool_calls=[
|
||||||
|
DeltaToolCall(
|
||||||
|
index=self.current_tool_index,
|
||||||
|
function=DeltaFunctionCall(arguments=combined),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for function end AFTER processing parameters.
|
||||||
|
# This ordering is critical: with speculative decoding a
|
||||||
|
# burst can deliver the final parameter value together with
|
||||||
|
# </function>. If the close check ran first it would emit
|
||||||
|
# "}" and set in_function=False before the parameter loop
|
||||||
|
# ever ran, causing the parameter to be silently dropped.
|
||||||
if not self.json_closed and self.function_end_token in tool_text:
|
if not self.json_closed and self.function_end_token in tool_text:
|
||||||
# Close JSON
|
|
||||||
self.json_closed = True
|
self.json_closed = True
|
||||||
|
|
||||||
# Extract complete tool call to update
|
|
||||||
# prev_tool_call_arr with final arguments
|
|
||||||
# Find the function content
|
|
||||||
func_start = tool_text.find(self.tool_call_prefix) + len(
|
func_start = tool_text.find(self.tool_call_prefix) + len(
|
||||||
self.tool_call_prefix
|
self.tool_call_prefix
|
||||||
)
|
)
|
||||||
func_content_end = tool_text.find(self.function_end_token, func_start)
|
func_content_end = tool_text.find(self.function_end_token, func_start)
|
||||||
if func_content_end != -1:
|
if func_content_end != -1:
|
||||||
func_content = tool_text[func_start:func_content_end]
|
func_content = tool_text[func_start:func_content_end]
|
||||||
# Parse to get the complete arguments
|
|
||||||
try:
|
try:
|
||||||
parsed_tool = self._parse_xml_function_call(
|
parsed_tool = self._parse_xml_function_call(
|
||||||
func_content,
|
func_content,
|
||||||
@@ -549,16 +664,27 @@ class Qwen3CoderToolParser(ToolParser):
|
|||||||
if self.streaming_request
|
if self.streaming_request
|
||||||
else None,
|
else None,
|
||||||
)
|
)
|
||||||
if parsed_tool:
|
if parsed_tool and self.current_tool_index < len(
|
||||||
# Update existing entry in
|
self.prev_tool_call_arr
|
||||||
# prev_tool_call_arr with complete args
|
):
|
||||||
for i, tool in enumerate(self.prev_tool_call_arr):
|
self.prev_tool_call_arr[self.current_tool_index][
|
||||||
if tool.get("name") == parsed_tool.function.name:
|
"arguments"
|
||||||
args = parsed_tool.function.arguments
|
] = parsed_tool.function.arguments
|
||||||
self.prev_tool_call_arr[i]["arguments"] = args
|
|
||||||
break
|
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # Ignore parsing errors during streaming
|
logger.debug(
|
||||||
|
"Failed to parse tool call during streaming: %s",
|
||||||
|
tool_text,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.current_tool_index < len(self.streamed_args_for_tool):
|
||||||
|
self.streamed_args_for_tool[self.current_tool_index] += "}"
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"streamed_args_for_tool out of sync: index=%d len=%d",
|
||||||
|
self.current_tool_index,
|
||||||
|
len(self.streamed_args_for_tool),
|
||||||
|
)
|
||||||
|
|
||||||
result = DeltaMessage(
|
result = DeltaMessage(
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
@@ -569,215 +695,10 @@ class Qwen3CoderToolParser(ToolParser):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset state for next tool
|
|
||||||
self.in_function = False
|
self.in_function = False
|
||||||
self.json_closed = True
|
self.json_closed = True
|
||||||
self.accumulated_params = {}
|
self.accumulated_params = {}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# Look for parameters
|
|
||||||
# Find all parameter starts
|
|
||||||
param_starts = []
|
|
||||||
idx = 0
|
|
||||||
while True:
|
|
||||||
idx = tool_text.find(self.parameter_prefix, idx)
|
|
||||||
if idx == -1:
|
|
||||||
break
|
|
||||||
param_starts.append(idx)
|
|
||||||
idx += len(self.parameter_prefix)
|
|
||||||
|
|
||||||
# Check if we should start a new parameter
|
|
||||||
if (
|
|
||||||
not self.in_param
|
|
||||||
and self.param_count < len(param_starts)
|
|
||||||
and len(param_starts) > self.param_count
|
|
||||||
):
|
|
||||||
# Process the next parameter
|
|
||||||
param_idx = param_starts[self.param_count]
|
|
||||||
param_start = param_idx + len(self.parameter_prefix)
|
|
||||||
remaining = tool_text[param_start:]
|
|
||||||
|
|
||||||
if ">" in remaining:
|
|
||||||
# We have the complete parameter name
|
|
||||||
name_end = remaining.find(">")
|
|
||||||
self.current_param_name = remaining[:name_end]
|
|
||||||
|
|
||||||
# Find the parameter value
|
|
||||||
value_start = param_start + name_end + 1
|
|
||||||
value_text = tool_text[value_start:]
|
|
||||||
if value_text.startswith("\n"):
|
|
||||||
value_text = value_text[1:]
|
|
||||||
|
|
||||||
# Find where this parameter ends
|
|
||||||
param_end_idx = value_text.find(self.parameter_end_token)
|
|
||||||
if param_end_idx == -1:
|
|
||||||
# No closing tag, look for next parameter or
|
|
||||||
# function end
|
|
||||||
next_param_idx = value_text.find(self.parameter_prefix)
|
|
||||||
func_end_idx = value_text.find(self.function_end_token)
|
|
||||||
|
|
||||||
if next_param_idx != -1 and (
|
|
||||||
func_end_idx == -1 or next_param_idx < func_end_idx
|
|
||||||
):
|
|
||||||
param_end_idx = next_param_idx
|
|
||||||
elif func_end_idx != -1:
|
|
||||||
param_end_idx = func_end_idx
|
|
||||||
else:
|
|
||||||
# Neither found, check if tool call is complete
|
|
||||||
if self.tool_call_end_token in tool_text:
|
|
||||||
# Tool call is complete, so parameter
|
|
||||||
# must be complete too. Use all
|
|
||||||
# remaining text before function end
|
|
||||||
param_end_idx = len(value_text)
|
|
||||||
else:
|
|
||||||
# Still streaming, wait for more content
|
|
||||||
return None
|
|
||||||
|
|
||||||
if param_end_idx != -1:
|
|
||||||
# Complete parameter found
|
|
||||||
param_value = value_text[:param_end_idx]
|
|
||||||
if param_value.endswith("\n"):
|
|
||||||
param_value = param_value[:-1]
|
|
||||||
|
|
||||||
# Store raw value for later processing
|
|
||||||
self.accumulated_params[self.current_param_name] = param_value
|
|
||||||
|
|
||||||
# Get parameter configuration for type conversion
|
|
||||||
param_config = self._get_arguments_config(
|
|
||||||
self.current_function_name or "",
|
|
||||||
self.streaming_request.tools
|
|
||||||
if self.streaming_request
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert param value to appropriate type
|
|
||||||
converted_value = self._convert_param_value(
|
|
||||||
param_value,
|
|
||||||
self.current_param_name,
|
|
||||||
param_config,
|
|
||||||
self.current_function_name or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build JSON fragment based on the converted type
|
|
||||||
# Use json.dumps to properly serialize the value
|
|
||||||
serialized_value = json.dumps(
|
|
||||||
converted_value, ensure_ascii=False
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.param_count == 0:
|
|
||||||
json_fragment = (
|
|
||||||
f'"{self.current_param_name}": {serialized_value}'
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
json_fragment = (
|
|
||||||
f', "{self.current_param_name}": {serialized_value}'
|
|
||||||
)
|
|
||||||
|
|
||||||
self.param_count += 1
|
|
||||||
|
|
||||||
return DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
index=self.current_tool_index,
|
|
||||||
function=DeltaFunctionCall(arguments=json_fragment),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Continue parameter value - Not used in the current implementation
|
|
||||||
# since we process complete parameters above
|
|
||||||
if self.in_param:
|
|
||||||
if self.parameter_end_token in delta_text:
|
|
||||||
# End of parameter
|
|
||||||
end_idx = delta_text.find(self.parameter_end_token)
|
|
||||||
value_chunk = delta_text[:end_idx]
|
|
||||||
|
|
||||||
# Skip past > if at start
|
|
||||||
if not self.current_param_value and ">" in value_chunk:
|
|
||||||
gt_idx = value_chunk.find(">")
|
|
||||||
value_chunk = value_chunk[gt_idx + 1 :]
|
|
||||||
|
|
||||||
if not self.current_param_value and value_chunk.startswith("\n"):
|
|
||||||
value_chunk = value_chunk[1:]
|
|
||||||
|
|
||||||
# Store complete value
|
|
||||||
full_value = self.current_param_value + value_chunk
|
|
||||||
self.accumulated_params[self.current_param_name] = full_value
|
|
||||||
|
|
||||||
# Get parameter configuration for type conversion
|
|
||||||
param_config = self._get_arguments_config(
|
|
||||||
self.current_function_name or "",
|
|
||||||
self.streaming_request.tools
|
|
||||||
if self.streaming_request
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert the parameter value to the appropriate type
|
|
||||||
converted_value = self._convert_param_value(
|
|
||||||
full_value,
|
|
||||||
self.current_param_name or "",
|
|
||||||
param_config,
|
|
||||||
self.current_function_name or "",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Serialize the converted value
|
|
||||||
serialized_value = json.dumps(converted_value, ensure_ascii=False)
|
|
||||||
|
|
||||||
# Since we've been streaming the quoted version,
|
|
||||||
# we need to close it properly
|
|
||||||
# This is complex - for now just complete the value
|
|
||||||
self.in_param = False
|
|
||||||
self.current_param_value = ""
|
|
||||||
|
|
||||||
# Just close the current parameter string
|
|
||||||
return DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
index=self.current_tool_index,
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
arguments='"'
|
|
||||||
), # Close the string quote
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Continue accumulating value
|
|
||||||
value_chunk = delta_text
|
|
||||||
|
|
||||||
# Handle first chunk after param name
|
|
||||||
if not self.current_param_value and ">" in value_chunk:
|
|
||||||
gt_idx = value_chunk.find(">")
|
|
||||||
value_chunk = value_chunk[gt_idx + 1 :]
|
|
||||||
|
|
||||||
if not self.current_param_value and value_chunk.startswith("\n"):
|
|
||||||
value_chunk = value_chunk[1:]
|
|
||||||
|
|
||||||
if value_chunk:
|
|
||||||
# Stream the escaped delta
|
|
||||||
prev_escaped = (
|
|
||||||
json.dumps(self.current_param_value, ensure_ascii=False)[
|
|
||||||
1:-1
|
|
||||||
]
|
|
||||||
if self.current_param_value
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
self.current_param_value += value_chunk
|
|
||||||
full_escaped = json.dumps(
|
|
||||||
self.current_param_value, ensure_ascii=False
|
|
||||||
)[1:-1]
|
|
||||||
delta_escaped = full_escaped[len(prev_escaped) :]
|
|
||||||
|
|
||||||
if delta_escaped:
|
|
||||||
return DeltaMessage(
|
|
||||||
tool_calls=[
|
|
||||||
DeltaToolCall(
|
|
||||||
index=self.current_tool_index,
|
|
||||||
function=DeltaFunctionCall(
|
|
||||||
arguments=delta_escaped
|
|
||||||
),
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|||||||
Reference in New Issue
Block a user