[Model] IBM Granite 3.1 (#11307)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
This commit is contained in:
Travis Johnson
2024-12-18 20:27:24 -07:00
committed by GitHub
parent 5a9da2e6e9
commit 17ca964273
4 changed files with 27 additions and 6 deletions

View File

@@ -35,13 +35,18 @@ class GraniteToolParser(ToolParser):
def __init__(self, tokenizer: AnyTokenizer):
super().__init__(tokenizer)
# for granite 3.0, the token `<|tool_call|>`
self.bot_token = "<|tool_call|>"
# for granite 3.1, the string `<tool_call>`
self.bot_string = "<tool_call>"
def extract_tool_calls(
self, model_output: str,
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
# remove whitespace and the BOT token if it exists
stripped = model_output.strip().removeprefix(self.bot_token).lstrip()
stripped = model_output.strip()\
.removeprefix(self.bot_token)\
.removeprefix(self.bot_string)\
.lstrip()
if not stripped or stripped[0] != '[':
return ExtractedToolCallInformation(tools_called=False,
tool_calls=[],
@@ -91,6 +96,9 @@ class GraniteToolParser(ToolParser):
if current_text[start_idx:].startswith(self.bot_token):
start_idx = consume_space(start_idx + len(self.bot_token),
current_text)
if current_text[start_idx:].startswith(self.bot_string):
start_idx = consume_space(start_idx + len(self.bot_string),
current_text)
if not current_text or start_idx >= len(current_text)\
or current_text[start_idx] != '[':
return DeltaMessage(content=delta_text)