[mistral_common] Add v11 tokenizer (#19193)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
committed by
GitHub
parent
9bc8bb07cf
commit
f20f9f063b
@@ -44,11 +44,17 @@ class MistralToolCall(ToolCall):
|
|||||||
return id.isalnum() and len(id) == 9
|
return id.isalnum() and len(id) == 9
|
||||||
|
|
||||||
|
|
||||||
|
def _is_fn_name_regex_support(model_tokenizer: AnyTokenizer) -> bool:
    """Return True when the function-name regex parsing path applies.

    The check passes only for a ``MistralTokenizer`` instance whose
    ``version`` attribute is 11 or greater; any other tokenizer object
    yields False.
    """
    # Non-Mistral tokenizers are rejected before ``version`` is read.
    if not isinstance(model_tokenizer, MistralTokenizer):
        return False
    return model_tokenizer.version >= 11
|
||||||
|
|
||||||
|
|
||||||
@ToolParserManager.register_module("mistral")
|
@ToolParserManager.register_module("mistral")
|
||||||
class MistralToolParser(ToolParser):
|
class MistralToolParser(ToolParser):
|
||||||
"""
|
"""
|
||||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
|
Tool call parser for Mistral 7B Instruct v0.3, intended for use with
|
||||||
examples/tool_chat_template_mistral.jinja template.
|
- [`mistral_common`](https://github.com/mistralai/mistral-common/)
|
||||||
|
- the examples/tool_chat_template_mistral.jinja template.
|
||||||
|
|
||||||
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
Used when --enable-auto-tool-choice --tool-call-parser mistral are all set
|
||||||
"""
|
"""
|
||||||
@@ -70,6 +76,12 @@ class MistralToolParser(ToolParser):
|
|||||||
self.bot_token = "[TOOL_CALLS]"
|
self.bot_token = "[TOOL_CALLS]"
|
||||||
self.bot_token_id = self.vocab.get(self.bot_token)
|
self.bot_token_id = self.vocab.get(self.bot_token)
|
||||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||||
|
if _is_fn_name_regex_support(self.model_tokenizer):
|
||||||
|
self.fn_name_regex = re.compile(r'([a-zA-Z0-9_-]+)(\{.*?\})',
|
||||||
|
re.DOTALL)
|
||||||
|
else:
|
||||||
|
self.fn_name_regex = None
|
||||||
|
|
||||||
if self.bot_token_id is None:
|
if self.bot_token_id is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Mistral Tool Parser could not locate the tool call token in "
|
"Mistral Tool Parser could not locate the tool call token in "
|
||||||
@@ -109,11 +121,25 @@ class MistralToolParser(ToolParser):
|
|||||||
tool_content = model_output.replace(self.bot_token, "").strip()
|
tool_content = model_output.replace(self.bot_token, "").strip()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# we first try to directly load the json as parsing very nested
|
# we first try to directly load the json as parsing very nested
|
||||||
# jsons is difficult
|
# jsons is difficult
|
||||||
try:
|
try:
|
||||||
function_call_arr = json.loads(tool_content)
|
if self.fn_name_regex:
|
||||||
|
matches = self.fn_name_regex.findall(tool_content)
|
||||||
|
|
||||||
|
function_call_arr = []
|
||||||
|
for match in matches:
|
||||||
|
fn_name = match[0]
|
||||||
|
args = match[1]
|
||||||
|
|
||||||
|
# fn_name is encoded outside serialized json dump
|
||||||
|
# only arguments are serialized
|
||||||
|
function_call_arr.append({
|
||||||
|
"name": fn_name,
|
||||||
|
"arguments": json.loads(args)
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
function_call_arr = json.loads(tool_content)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# use a regex to find the part corresponding to the tool call.
|
# use a regex to find the part corresponding to the tool call.
|
||||||
# NOTE: This use case should not happen if the model is trained
|
# NOTE: This use case should not happen if the model is trained
|
||||||
|
|||||||
@@ -187,6 +187,8 @@ class MistralTokenizer(TokenizerBase):
|
|||||||
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
||||||
self.mistral = tokenizer
|
self.mistral = tokenizer
|
||||||
self.instruct = tokenizer.instruct_tokenizer
|
self.instruct = tokenizer.instruct_tokenizer
|
||||||
|
_mistral_version_str = self.instruct.tokenizer.version.value
|
||||||
|
self.version: int = int(_mistral_version_str.split("v")[-1])
|
||||||
|
|
||||||
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
|
tokenizer_ = tokenizer.instruct_tokenizer.tokenizer
|
||||||
from mistral_common.tokens.tokenizers.tekken import (
|
from mistral_common.tokens.tokenizers.tekken import (
|
||||||
|
|||||||
Reference in New Issue
Block a user