[Tool Parser] Fix Qwen3Coder streaming parameter loss with speculative decode (#35615)

Signed-off-by: Martin Vit <martin@voipmonitor.org>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Martin Vit
2026-03-03 02:40:37 +01:00
committed by GitHub
parent 168ee03e1c
commit 8ebd872f50
2 changed files with 175 additions and 244 deletions

View File

@@ -1249,13 +1249,23 @@ class OpenAIServingChat(OpenAIServing):
) )
# get the expected call based on partial JSON # get the expected call based on partial JSON
# parsing which "autocompletes" the JSON # parsing which "autocompletes" the JSON.
expected_call = json.dumps( # Tool parsers (e.g. Qwen3Coder) store
tool_parser.prev_tool_call_arr[index].get( # arguments as a JSON string in
"arguments", {} # prev_tool_call_arr. Calling json.dumps()
), # on an already-serialized string would
ensure_ascii=False, # double-serialize it (e.g. '{"k":1}' becomes
# '"{\\"k\\":1}"'), which then causes the
# replace() below to fail and append the
# entire double-serialized string as a
# spurious final delta.
args = tool_parser.prev_tool_call_arr[index].get(
"arguments", {}
) )
if isinstance(args, str):
expected_call = args
else:
expected_call = json.dumps(args, ensure_ascii=False)
# get what we've streamed so far for arguments # get what we've streamed so far for arguments
# for the current tool # for the current tool

View File

@@ -479,20 +479,22 @@ class Qwen3CoderToolParser(ToolParser):
self.header_sent = True self.header_sent = True
self.in_function = True self.in_function = True
# IMPORTANT: Add to prev_tool_call_arr immediately when # Always append — each tool call is a separate
# we detect a tool call. This ensures # invocation even if the function name is the same
# finish_reason="tool_calls" even if parsing isn't complete # (e.g. two consecutive "read" calls).
already_added = any( self.prev_tool_call_arr.append(
tool.get("name") == self.current_function_name {
for tool in self.prev_tool_call_arr "name": self.current_function_name,
"arguments": "{}",
}
) )
if not already_added:
self.prev_tool_call_arr.append( # Initialize streamed args tracking for this tool.
{ # The serving layer reads streamed_args_for_tool to
"name": self.current_function_name, # compute remaining arguments at stream end. Without
"arguments": "{}", # Placeholder, will be updated later # this, IndexError occurs when the serving layer
} # accesses streamed_args_for_tool[index].
) self.streamed_args_for_tool.append("")
# Send header with function info # Send header with function info
return DeltaMessage( return DeltaMessage(
@@ -511,9 +513,14 @@ class Qwen3CoderToolParser(ToolParser):
# We've sent header, now handle function body # We've sent header, now handle function body
if self.in_function: if self.in_function:
# Send opening brace if not sent yet # Always send opening brace first, regardless of whether
if not self.json_started and self.parameter_prefix not in delta_text: # parameter_prefix is in the current delta. With speculative
# decoding, a single delta may contain both the opening brace
# and parameter data; skipping "{" here would desync
# json_started from what was actually streamed.
if not self.json_started:
self.json_started = True self.json_started = True
self.streamed_args_for_tool[self.current_tool_index] += "{"
return DeltaMessage( return DeltaMessage(
tool_calls=[ tool_calls=[
DeltaToolCall( DeltaToolCall(
@@ -523,25 +530,133 @@ class Qwen3CoderToolParser(ToolParser):
] ]
) )
# Make sure json_started is set if we're processing parameters # Find all parameter start positions in current tool_text
if not self.json_started: param_starts = []
self.json_started = True search_idx = 0
while True:
search_idx = tool_text.find(self.parameter_prefix, search_idx)
if search_idx == -1:
break
param_starts.append(search_idx)
search_idx += len(self.parameter_prefix)
# Check for function end in accumulated text # Process ALL complete params in a loop (spec decode fix).
# With speculative decoding a single delta can deliver
# multiple complete parameters at once. The old single-pass
# code would process one and ``return None`` if the next was
# incomplete — skipping any already-complete params that
# preceded it. Using a loop with ``break`` instead ensures
# we emit every complete parameter before yielding control.
json_fragments = []
while not self.in_param and self.param_count < len(param_starts):
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" not in remaining:
break
name_end = remaining.find(">")
current_param_name = remaining[:name_end]
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.function_end_token)
if next_param_idx != -1 and (
func_end_idx == -1 or next_param_idx < func_end_idx
):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Fallback for malformed XML where </function>
# is missing. Use </tool_call> as a delimiter
# if present in the value so we don't include
# the closing tag as part of the param value.
tool_end_in_value = value_text.find(self.tool_call_end_token)
if tool_end_in_value != -1:
param_end_idx = tool_end_in_value
else:
# Parameter incomplete — break so we still
# emit any fragments accumulated by earlier
# loop iterations.
break
if param_end_idx == -1:
break
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
self.current_param_name = current_param_name
self.accumulated_params[current_param_name] = param_value
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools if self.streaming_request else None,
)
converted_value = self._convert_param_value(
param_value,
current_param_name,
param_config,
self.current_function_name or "",
)
serialized_value = json.dumps(converted_value, ensure_ascii=False)
if self.param_count == 0:
json_fragment = f'"{current_param_name}": {serialized_value}'
else:
json_fragment = f', "{current_param_name}": {serialized_value}'
self.param_count += 1
json_fragments.append(json_fragment)
if json_fragments:
combined = "".join(json_fragments)
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += combined
else:
logger.warning(
"streamed_args_for_tool out of sync: index=%d len=%d",
self.current_tool_index,
len(self.streamed_args_for_tool),
)
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=combined),
)
]
)
# Check for function end AFTER processing parameters.
# This ordering is critical: with speculative decoding a
# burst can deliver the final parameter value together with
# </function>. If the close check ran first it would emit
# "}" and set in_function=False before the parameter loop
# ever ran, causing the parameter to be silently dropped.
if not self.json_closed and self.function_end_token in tool_text: if not self.json_closed and self.function_end_token in tool_text:
# Close JSON
self.json_closed = True self.json_closed = True
# Extract complete tool call to update
# prev_tool_call_arr with final arguments
# Find the function content
func_start = tool_text.find(self.tool_call_prefix) + len( func_start = tool_text.find(self.tool_call_prefix) + len(
self.tool_call_prefix self.tool_call_prefix
) )
func_content_end = tool_text.find(self.function_end_token, func_start) func_content_end = tool_text.find(self.function_end_token, func_start)
if func_content_end != -1: if func_content_end != -1:
func_content = tool_text[func_start:func_content_end] func_content = tool_text[func_start:func_content_end]
# Parse to get the complete arguments
try: try:
parsed_tool = self._parse_xml_function_call( parsed_tool = self._parse_xml_function_call(
func_content, func_content,
@@ -549,16 +664,27 @@ class Qwen3CoderToolParser(ToolParser):
if self.streaming_request if self.streaming_request
else None, else None,
) )
if parsed_tool: if parsed_tool and self.current_tool_index < len(
# Update existing entry in self.prev_tool_call_arr
# prev_tool_call_arr with complete args ):
for i, tool in enumerate(self.prev_tool_call_arr): self.prev_tool_call_arr[self.current_tool_index][
if tool.get("name") == parsed_tool.function.name: "arguments"
args = parsed_tool.function.arguments ] = parsed_tool.function.arguments
self.prev_tool_call_arr[i]["arguments"] = args
break
except Exception: except Exception:
pass # Ignore parsing errors during streaming logger.debug(
"Failed to parse tool call during streaming: %s",
tool_text,
exc_info=True,
)
if self.current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self.current_tool_index] += "}"
else:
logger.warning(
"streamed_args_for_tool out of sync: index=%d len=%d",
self.current_tool_index,
len(self.streamed_args_for_tool),
)
result = DeltaMessage( result = DeltaMessage(
tool_calls=[ tool_calls=[
@@ -569,215 +695,10 @@ class Qwen3CoderToolParser(ToolParser):
] ]
) )
# Reset state for next tool
self.in_function = False self.in_function = False
self.json_closed = True self.json_closed = True
self.accumulated_params = {} self.accumulated_params = {}
return result return result
# Look for parameters
# Find all parameter starts
param_starts = []
idx = 0
while True:
idx = tool_text.find(self.parameter_prefix, idx)
if idx == -1:
break
param_starts.append(idx)
idx += len(self.parameter_prefix)
# Check if we should start a new parameter
if (
not self.in_param
and self.param_count < len(param_starts)
and len(param_starts) > self.param_count
):
# Process the next parameter
param_idx = param_starts[self.param_count]
param_start = param_idx + len(self.parameter_prefix)
remaining = tool_text[param_start:]
if ">" in remaining:
# We have the complete parameter name
name_end = remaining.find(">")
self.current_param_name = remaining[:name_end]
# Find the parameter value
value_start = param_start + name_end + 1
value_text = tool_text[value_start:]
if value_text.startswith("\n"):
value_text = value_text[1:]
# Find where this parameter ends
param_end_idx = value_text.find(self.parameter_end_token)
if param_end_idx == -1:
# No closing tag, look for next parameter or
# function end
next_param_idx = value_text.find(self.parameter_prefix)
func_end_idx = value_text.find(self.function_end_token)
if next_param_idx != -1 and (
func_end_idx == -1 or next_param_idx < func_end_idx
):
param_end_idx = next_param_idx
elif func_end_idx != -1:
param_end_idx = func_end_idx
else:
# Neither found, check if tool call is complete
if self.tool_call_end_token in tool_text:
# Tool call is complete, so parameter
# must be complete too. Use all
# remaining text before function end
param_end_idx = len(value_text)
else:
# Still streaming, wait for more content
return None
if param_end_idx != -1:
# Complete parameter found
param_value = value_text[:param_end_idx]
if param_value.endswith("\n"):
param_value = param_value[:-1]
# Store raw value for later processing
self.accumulated_params[self.current_param_name] = param_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools
if self.streaming_request
else None,
)
# Convert param value to appropriate type
converted_value = self._convert_param_value(
param_value,
self.current_param_name,
param_config,
self.current_function_name or "",
)
# Build JSON fragment based on the converted type
# Use json.dumps to properly serialize the value
serialized_value = json.dumps(
converted_value, ensure_ascii=False
)
if self.param_count == 0:
json_fragment = (
f'"{self.current_param_name}": {serialized_value}'
)
else:
json_fragment = (
f', "{self.current_param_name}": {serialized_value}'
)
self.param_count += 1
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(arguments=json_fragment),
)
]
)
# Continue parameter value - Not used in the current implementation
# since we process complete parameters above
if self.in_param:
if self.parameter_end_token in delta_text:
# End of parameter
end_idx = delta_text.find(self.parameter_end_token)
value_chunk = delta_text[:end_idx]
# Skip past > if at start
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]
if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]
# Store complete value
full_value = self.current_param_value + value_chunk
self.accumulated_params[self.current_param_name] = full_value
# Get parameter configuration for type conversion
param_config = self._get_arguments_config(
self.current_function_name or "",
self.streaming_request.tools
if self.streaming_request
else None,
)
# Convert the parameter value to the appropriate type
converted_value = self._convert_param_value(
full_value,
self.current_param_name or "",
param_config,
self.current_function_name or "",
)
# Serialize the converted value
serialized_value = json.dumps(converted_value, ensure_ascii=False)
# Since we've been streaming the quoted version,
# we need to close it properly
# This is complex - for now just complete the value
self.in_param = False
self.current_param_value = ""
# Just close the current parameter string
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments='"'
), # Close the string quote
)
]
)
else:
# Continue accumulating value
value_chunk = delta_text
# Handle first chunk after param name
if not self.current_param_value and ">" in value_chunk:
gt_idx = value_chunk.find(">")
value_chunk = value_chunk[gt_idx + 1 :]
if not self.current_param_value and value_chunk.startswith("\n"):
value_chunk = value_chunk[1:]
if value_chunk:
# Stream the escaped delta
prev_escaped = (
json.dumps(self.current_param_value, ensure_ascii=False)[
1:-1
]
if self.current_param_value
else ""
)
self.current_param_value += value_chunk
full_escaped = json.dumps(
self.current_param_value, ensure_ascii=False
)[1:-1]
delta_escaped = full_escaped[len(prev_escaped) :]
if delta_escaped:
return DeltaMessage(
tool_calls=[
DeltaToolCall(
index=self.current_tool_index,
function=DeltaFunctionCall(
arguments=delta_escaped
),
)
]
)
return None return None