[Tool Parser] Fix Qwen3Coder streaming parameter loss with speculative decode (#35615)

Signed-off-by: Martin Vit <martin@voipmonitor.org>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Martin Vit
2026-03-03 02:40:37 +01:00
committed by GitHub
parent 168ee03e1c
commit 8ebd872f50
2 changed files with 175 additions and 244 deletions

View File

@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
) )
# get the expected call based on partial JSON # get the expected call based on partial JSON
# parsing which "autocompletes" the JSON # parsing which "autocompletes" the JSON.
expected_call = json.dumps( # Tool parsers (e.g. Qwen3Coder) store
tool_parser.prev_tool_call_arr[index].get( # arguments as a JSON string in
# prev_tool_call_arr. Calling json.dumps()
# on an already-serialized string would
# double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# spurious final delta.
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {} "arguments", {}
),
ensure_ascii=False,
) )
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments # get what we've streamed so far for arguments
# for the current tool # for the current tool

View File

@@ -479,21 +479,23 @@ class Qwen3CoderToolParser(ToolParser):
self.header_sent = True self.header_sent = True
self.in_function = True self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when # Always append — each tool call is a separate
# we detect a tool call. This ensures # invocation even if the function name is the same
# finish_reason="tool_calls" even if parsing isn't complete # (e.g. two consecutive "read" calls).
already_added = any(
tool.get("name") == self.current_function_name
for tool in self.prev_tool_call_arr
)
if not already_added:
self.prev_tool_call_arr.append( self.prev_tool_call_arr.append(
{ {
"name": self.current_function_name, "name": self.current_function_name,
"arguments": "{}", # Placeholder, will be updated later "arguments": "{}",
} }
) )
# Initialize streamed args tracking for this tool.
# The serving layer reads streamed_args_for_tool to
# compute remaining arguments at stream end. Without
# this, IndexError occurs when the serving layer
# accesses streamed_args_for_tool[index].
self.streamed_args_for_tool.append("")
# Send header with function info # Send header with function info
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
@@ -511,9 +513,14 @@ class Qwen3CoderToolParser(ToolParser):
# We've sent header, now handle function body # We've sent header, now handle function body
if self.in_function: if self.in_function:
# Send opening brace if not sent yet # Always send opening brace first, regardless of whether
if not self.json_started and self.parameter_prefix not in delta_text: # parameter_prefix is in the current delta. With speculative
# decoding, a single delta may contain both the opening brace
# and parameter data; skipping "{" here would desync
# json_started from what was actually streamed.
if not self.json_started:
self.json_started = True self.json_started = True
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
@@ -523,97 +530,42 @@ class Qwen3CoderToolParser(ToolParser):
] ]
) )
# Make sure json_started is set if we're processing parameters # Find all parameter start positions in current tool_text
if not self.json_started:
self.json_started = True
# Check for function end in accumulated text
if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True
# Extract complete tool call to update
# prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix
)
func_content_end = tool_text.find(self.function_end_token, func_start)
if func_content_end != -1:
func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try:
parsed_tool = self._parse_xml_function_call(
func_content,
self.streaming_request.tools
if self.streaming_request
else None,
)
if parsed_tool:
# Update existing entry in
# prev_tool_call_arr with complete args
for i, tool in enumerate(self.prev_tool_call_arr):
if tool.get("name") == parsed_tool.function.name:
args = parsed_tool.function.arguments
self.prev_tool_call_arr[i]["arguments"] = args
break
except Exception:
pass # Ignore parsing errors during streaming
result = DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments="}"),
)
]
)
# Reset state for next tool
self.in_function = False
self.json_closed = True
self.accumulated_params = {}
return result
# Look for parameters
# Find all parameter starts
param_starts = [] param_starts = []
idx = 0 search_idx = 0
while True: while True:
idx = tool_text.find(self.parameter_prefix, idx) search_idx = tool_text.find(self.parameter_prefix, search_idx)
if idx == -1: if search_idx == -1:
break break
param_starts.append(idx) param_starts.append(search_idx)
idx += len(self.parameter_prefix) search_idx += len(self.parameter_prefix)
# Check if we should start a new parameter # Process ALL complete params in a loop (spec decode fix).
if ( # With speculative decoding a single delta can deliver
not self.in_param # multiple complete parameters at once. The old single-pass
and self.param_count < len(param_starts) # code would process one and ``return None`` if the next was
and len(param_starts) > self.param_count # incomplete — skipping any already-complete params that
): # preceded it. Using a loop with ``break`` instead ensures
# Process the next parameter # we emit every complete parameter before yielding control.
json_fragments = []
while not self.in_param and self.param_count < len(param_starts):
param_idx = param_starts[self.param_count] param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix) param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:] remaining = tool_text[param_start:]
if ">" in remaining: if ">" not in remaining:
# We have the complete parameter name break
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end] name_end = remaining.find(">")
current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1 value_start = param_start + name_end + 1
value_text = tool_text[value_start:] value_text = tool_text[value_start:]
if value_text.startswith("\n"): if value_text.startswith("\n"):
value_text = value_text[1:] value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token) param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1: if param_end_idx == -1:
# No closing tag, look for next parameter or
# function end
next_param_idx = value_text.find(self.parameter_prefix) next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.function_end_token) func_end_idx = value_text.find(self.function_end_token)
@@ -624,160 +576,129 @@ class Qwen3CoderToolParser(ToolParser):
elif func_end_idx != -1: elif func_end_idx != -1:
param_end_idx = func_end_idx param_end_idx = func_end_idx
else: else:
# Neither found, check if tool call is complete # Fallback for malformed XML where </function>
if self.tool_call_end_token in tool_text: # is missing. Use </tool_call> as a delimiter
# Tool call is complete, so parameter # if present in the value so we don't include
# must be complete too. Use all # the closing tag as part of the param value.
# remaining text before function end tool_end_in_value = value_text.find(self.tool_call_end_token)
param_end_idx = len(value_text) if tool_end_in_value != -1:
param_end_idx = tool_end_in_value
else: else:
# Still streaming, wait for more content # Parameter incomplete — break so we still
return None # emit any fragments accumulated by earlier
# loop iterations.
break
if param_end_idx == -1:
break
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx] param_value = value_text[:param_end_idx]
if param_value.endswith("\n"): if param_value.endswith("\n"):
param_value = param_value[:-1] param_value = param_value[:-1]
# Store raw value for later processing self.current_param_name = current_param_name
self.accumulated_params[self.current_param_name] = param_value self.accumulated_params[current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config( param_config = self._get_arguments_config(
self.current_function_name or "", self.current_function_name or "",
self.streaming_request.tools self.streaming_request.tools if self.streaming_request else None,
if self.streaming_request
else None,
) )
# Convert param value to appropriate type
converted_value = self._convert_param_value( converted_value = self._convert_param_value(
param_value, param_value,
self.current_param_name, current_param_name,
param_config, param_config,
self.current_function_name or "", self.current_function_name or "",
) )
# Build JSON fragment based on the converted type serialized_value = json.dumps(converted_value, ensure_ascii=False)
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(
converted_value, ensure_ascii=False
)
if self.param_count == 0: if self.param_count == 0:
json_fragment = ( json_fragment = f'"{current_param_name}": {serialized_value}'
f'"{self.current_param_name}": {serialized_value}'
)
else: else:
json_fragment = ( json_fragment = f', "{current_param_name}": {serialized_value}'
f', "{self.current_param_name}": {serialized_value}'
)
self.param_count += 1 self.param_count += 1
json_fragments.append(json_fragment)
if json_fragments:
combined = "".join(json_fragments)
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += combined
else:
logger.warning(
"streamed_args_for_tool out of sync: index=%d len=%d",
self.current_tool_index,
len(self.streamed_args_for_tool),
)
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
index=self.current_tool_index, index=self.current_tool_index,
function=DeltaFunctionCall(arguments=json_fragment), function=DeltaFunctionCall(arguments=combined),
) )
] ]
) )
# Continue parameter value - Not used in the current implementation # Check for function end AFTER processing parameters.
# since we process complete parameters above # This ordering is critical: with speculative decoding a
if self.in_param: # burst can deliver the final parameter value together with
if self.parameter_end_token in delta_text: # </function>. If the close check ran first it would emit
# End of parameter # "}" and set in_function=False before the parameter loop
end_idx = delta_text.find(self.parameter_end_token) # ever ran, causing the parameter to be silently dropped.
value_chunk = delta_text[:end_idx] if not self.json_closed and self.function_end_token in tool_text:
self.json_closed = True
# Skip past > if at start func_start = tool_text.find(self.tool_call_prefix) + len(
if not self.current_param_value and ">" in value_chunk: self.tool_call_prefix
gt_idx = value_chunk.find(">") )
value_chunk = value_chunk[gt_idx + 1 :] func_content_end = tool_text.find(self.function_end_token, func_start)
if func_content_end != -1:
if not self.current_param_value and value_chunk.startswith("\n"): func_content = tool_text[func_start:func_content_end]
value_chunk = value_chunk[1:] try:
parsed_tool = self._parse_xml_function_call(
# Store complete value func_content,
full_value = self.current_param_value + value_chunk
self.accumulated_params[self.current_param_name] = full_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools self.streaming_request.tools
if self.streaming_request if self.streaming_request
else None, else None,
) )
if parsed_tool and self.current_tool_index < len(
# Convert the parameter value to the appropriate type self.prev_tool_call_arr
converted_value = self._convert_param_value( ):
full_value, self.prev_tool_call_arr[self.current_tool_index][
self.current_param_name or "", "arguments"
param_config, ] = parsed_tool.function.arguments
self.current_function_name or "", except Exception:
logger.debug(
"Failed to parse tool call during streaming: %s",
tool_text,
exc_info=True,
) )
# Serialize the converted value if self.current_tool_index < len(self.streamed_args_for_tool):
serialized_value = json.dumps(converted_value, ensure_ascii=False) self.streamed_args_for_tool[self.current_tool_index] += "}"
# Since we've been streaming the quoted version,
# we need to close it properly
# This is complex - for now just complete the value
self.in_param = False
self.current_param_value = ""
# Just close the current parameter string
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments='"'
), # Close the string quote
)
]
)
else: else:
# Continue accumulating value logger.warning(
value_chunk = delta_text "streamed_args_for_tool out of sync: index=%d len=%d",
self.current_tool_index,
# Handle first chunk after param name len(self.streamed_args_for_tool),
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]
if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (
json.dumps(self.current_param_value, ensure_ascii=False)[
1:-1
]
if self.current_param_value
else ""
) )
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value, ensure_ascii=False
)[1:-1]
delta_escaped = full_escaped[len(prev_escaped) :]
if delta_escaped: result = DeltaMessage(
return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
index=self.current_tool_index, index=self.current_tool_index,
function=DeltaFunctionCall( function=DeltaFunctionCall(arguments="}"),
arguments=delta_escaped
),
) )
] ]
) )
self.in_function = False
self.json_closed = True
self.accumulated_params = {}
return result
return None return None