[Bugfix]: Reasoning output bug according to the chat template change (#13025)

Signed-off-by: Ce Gao <cegao@tensorchord.ai>
This commit is contained in:
Ce Gao
2025-02-11 15:49:03 +08:00
committed by GitHub
parent 78a141d768
commit fc6485d277
3 changed files with 127 additions and 43 deletions

View File

@@ -67,6 +67,8 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
]):
return None
# Check if <think> is present in previous or delta.
# Keep compatibility with models that don't generate <think> tokens.
if self.think_start_token_id in previous_token_ids:
if self.think_end_token_id in delta_token_ids:
# <think> in previous, </think> in delta,
@@ -85,7 +87,6 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
elif self.think_start_token_id in delta_token_ids:
logger.info(delta_text)
if self.think_end_token_id in delta_token_ids:
# <think> in delta, </think> in delta, extract reasoning content
start_index = delta_text.find(self.think_start_token)
@@ -101,35 +102,46 @@ class DeepSeekR1ReasoningParser(ReasoningParser):
# reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
else:
# No <think> in previous or delta, reasoning content continues.
return DeltaMessage(content=delta_text)
# No <think> in previous or delta, also need to check for </think>.
# Because the model may have generated </think> without <think>
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token_id in delta_token_ids:
# </think> in delta with more tokens,
# extract reasoning content and content
end_index = delta_text.find(self.think_end_token)
reasoning_content = delta_text[:end_index]
content = delta_text[end_index + len(self.think_end_token):]
return DeltaMessage(reasoning_content=reasoning_content,
content=content if content else None)
elif self.think_end_token_id in previous_token_ids:
# </think> in previous, thinking content ends
return DeltaMessage(content=delta_text)
else:
# no </think> in previous or delta, reasoning content continues
return DeltaMessage(reasoning_content=delta_text)
def extract_reasoning_content(
self, model_output: str, request: ChatCompletionRequest
) -> Tuple[Optional[str], Optional[str]]:
# Check if the model output contains the <think> tokens.
if (self.think_start_token not in model_output
or self.think_end_token not in model_output):
# DeepSeek R1 doesn't generate <think> now.
# Thus we assume the reasoning content is always at the start.
# Ref https://huggingface.co/deepseek-ai/DeepSeek-R1/commit/8a58a132790c9935686eb97f042afa8013451c9f
if self.think_end_token not in model_output:
return None, model_output
else:
# Add a start token if it's missing to keep compatibility.
if self.think_start_token not in model_output:
model_output = f"{self.think_start_token}{model_output}"
# Use a regex to find the reasoning content
reasoning_content = self.reasoning_regex.findall(model_output)[0]
# Remove the reasoning content from the model output
# Although deepseek's <think> token is always at the
# beginning of the line, we cannot guarantee that the
# other models will follow this convention.
# Therefore, we need to add :start_index.
start_index = model_output.find(self.think_start_token)
if start_index != -1:
end_index = start_index + len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
)
model_output = model_output[:start_index] + \
model_output[end_index:]
end_index = len(
f"{self.think_start_token}{reasoning_content}{self.think_end_token}"
)
final_output = model_output[end_index:]
if len(model_output) == 0:
return reasoning_content, None
if len(final_output) == 0:
return reasoning_content, None
return reasoning_content, model_output
return reasoning_content, final_output