[Bugfix]: Fix structured output in multi-turn gpt-oss (#34454)

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Ben Browning
2026-02-13 14:12:48 -05:00
committed by GitHub
parent bfaa559305
commit fd267bc7b7
4 changed files with 29 additions and 1 deletions

View File

@@ -23,6 +23,7 @@ class TestGptOssStructuralTagsIntegration:
"""Create a mock tokenizer."""
tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
tokenizer.vocab = {"<|end|>": 6}
return tokenizer
@pytest.fixture

View File

@@ -17,7 +17,9 @@ def gpt_oss_tokenizer():
USER_MESSAGE_START = "<|start|>user<|message|>"
REASONING_SECTION_START = "<|end|><|start|>assistant<|channel|>analysis<|message|>"
ASSISTANT_CONTENT_START_PREFIX = "<|end|><|start|>assistant<|channel|>final"
END = "<|end|>"
ASSISTANT_START = "<|start|>assistant"
ASSISTANT_CONTENT_START_PREFIX = END + ASSISTANT_START + "<|channel|>final"
ASSISTANT_CONTENT_START_SUFFIX = "<|message|>"
ASSISTANT_CONTENT_START = (
ASSISTANT_CONTENT_START_PREFIX + ASSISTANT_CONTENT_START_SUFFIX
@@ -97,6 +99,20 @@ COMPLEX_CONTENT_2 = {
"is_reasoning_end": True,
}
# A two-turn conversation: the first turn is fully completed (reasoning,
# final response, <|end|>), and the second turn stops right after
# "<|start|>assistant" — i.e. the current message has produced no
# reasoning-end marker yet, so is_reasoning_end must be False.
MULTI_TURN_CONTENT = {
    "output": "".join(
        [
            USER_MESSAGE_START,
            "1st turn user message",
            REASONING_SECTION_START,
            "1st turn reasoning",
            ASSISTANT_CONTENT_START,
            "1st turn response",
            END,
            USER_MESSAGE_START,
            "2nd turn user message",
            END,
            ASSISTANT_START,
        ]
    ),
    "is_reasoning_end": False,
}
TEST_CASES = [
BASIC_CONTENT,
BASIC_REASONING_ONLY,
@@ -106,6 +122,7 @@ TEST_CASES = [
COMPLEX_CONTENT_1,
COMPLEX_CONTENT_1_WITH_CONTENT,
COMPLEX_CONTENT_2,
MULTI_TURN_CONTENT,
]

View File

@@ -25,6 +25,7 @@ class TestGptOssReasoningParser:
"""Create a mock tokenizer for testing."""
tokenizer = Mock()
tokenizer.encode = Mock(return_value=[1, 2, 3, 4, 5])
tokenizer.vocab = {"<|end|>": 6}
return tokenizer
@pytest.fixture

View File

@@ -76,6 +76,9 @@ class GptOssReasoningParser(ReasoningParser):
"<|channel|>final"
)
self.reasoning_end_token_ids_suffix = self.model_tokenizer.encode("<|message|>")
# We also need to check for the <|end|> token to avoid false positives from
# previous messages in multi-turn conversations.
self.eom_token_id = self.model_tokenizer.vocab["<|end|>"]
self.reasoning_max_num_between_tokens = 20
def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
@@ -86,6 +89,12 @@ class GptOssReasoningParser(ReasoningParser):
# Check if the end sequence is present in the input_ids.
# We search from the end of input_ids to find the last match.
for i in range(len(input_ids) - len(end_token_ids_prefix), -1, -1):
if input_ids[i] == self.eom_token_id:
# We looped backwards far enough to find the end of a previous message,
# which means we have searched the entirety of the current message
# and can exit early without searching further back into prior
# messages of the conversation.
return False
if input_ids[i : i + len(end_token_ids_prefix)] == end_token_ids_prefix:
# We have found the prefix, now we look for the suffix after the prefix.
suffix_start = i + len(end_token_ids_prefix)