[Bugfix] Support missing tool parameters in mistral tokenizer (#12884)
Signed-off-by: Florian Greinacher <florian.greinacher@siemens.com>
This commit is contained in:
committed by
GitHub
parent
2c0f58203c
commit
cb080f32e3
@@ -104,6 +104,42 @@ def find_tokenizer_file(files: List[str]):
|
||||
return matched_files[0]
|
||||
|
||||
|
||||
def make_mistral_chat_completion_request(
        messages: List["ChatCompletionMessageParam"],
        tools: Optional[List[Dict[str,
                                  Any]]] = None) -> "ChatCompletionRequest":
    """Build a mistral-common ``ChatCompletionRequest`` from chat input.

    The inputs are normalized **in place** before the request is built:

    * if the last message has role ``"assistant"``, it is tagged with
      ``prefix=True``;
    * assistant messages whose ``content`` is a list of chunks get that list
      replaced by the chunks' ``"text"`` values joined with newlines;
    * every tool of type ``"function"`` gets an empty ``"parameters"`` dict
      when it has none.

    Args:
        messages: Chat messages as dict-like objects. They are mutated as
            described above.
        tools: Optional tool definitions. Each entry must carry a ``"type"``
            key, and ``"function"`` tools must carry a ``"function"`` dict.
            They are mutated as described above.

    Returns:
        The ``ChatCompletionRequest`` built from the normalized messages and
        tools.

    Raises:
        KeyError: If the last message has no ``"role"``, or a tool lacks
            ``"type"`` / ``"function"``.
        TypeError: If a chunk of a list-valued assistant content has no
            ``"text"`` entry (``str.join`` rejects ``None``).
    """
    # Guard against an empty conversation so we do not fail with an opaque
    # IndexError here; the request constructor / encoder downstream is left
    # to judge whether an empty message list is acceptable.
    if messages:
        last_message = cast(Dict[str, Any], messages[-1])
        if last_message["role"] == "assistant":
            last_message["prefix"] = True

    # mistral-common requires AssistantMessage content to be string [1].
    #
    # [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
    for message in messages:
        if message.get("role") == "assistant":
            content = message.get("content")
            if isinstance(content, list):
                content = "\n".join(chunk.get("text") for chunk in content)
                message["content"] = content

    # The Mistral client, in comparison to the OpenAI client, requires the
    # "parameters" dict to be present, even if it's empty.
    if tools:
        for tool in tools:
            if tool["type"] == "function":
                tool["function"].setdefault("parameters", {})

    # Imported lazily, as in the original, so importing this module does not
    # require mistral-common to be installed.
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    return ChatCompletionRequest(messages=messages,
                                 tools=tools)  # type: ignore[type-var]
|
||||
|
||||
|
||||
class MistralTokenizer:
|
||||
|
||||
def __init__(self, tokenizer: "PublicMistralTokenizer") -> None:
|
||||
@@ -283,27 +319,10 @@ class MistralTokenizer:
|
||||
|
||||
def apply_chat_template(self,
|
||||
messages: List["ChatCompletionMessageParam"],
|
||||
tools: Optional[Dict[str, Any]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**kwargs) -> List[int]:
|
||||
|
||||
last_message = cast(Dict[str, Any], messages[-1])
|
||||
if last_message["role"] == "assistant":
|
||||
last_message["prefix"] = True
|
||||
|
||||
from mistral_common.protocol.instruct.request import (
|
||||
ChatCompletionRequest)
|
||||
|
||||
# mistral-common requires AssistantMessage content to be string [1].
|
||||
#
|
||||
# [1]: https://github.com/mistralai/mistral-common/blob/f4a06998b75ed78bbf5aaf569590b772ea26c9f6/src/mistral_common/protocol/instruct/messages.py#L80
|
||||
for message in messages:
|
||||
if message.get("role") == "assistant":
|
||||
content = message.get("content")
|
||||
if isinstance(content, list):
|
||||
content = "\n".join(chunk.get("text") for chunk in content)
|
||||
message["content"] = content
|
||||
request = ChatCompletionRequest(messages=messages,
|
||||
tools=tools) # type: ignore[type-var]
|
||||
request = make_mistral_chat_completion_request(messages, tools)
|
||||
encoded = self.mistral.encode_chat_completion(request)
|
||||
|
||||
# encode-decode to get clean prompt
|
||||
|
||||
Reference in New Issue
Block a user