[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model (#8405)
This commit is contained in:
@@ -53,6 +53,7 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
|
||||
from vllm.entrypoints.openai.serving_engine import BaseModelPath
|
||||
from vllm.entrypoints.openai.serving_tokenization import (
|
||||
OpenAIServingTokenization)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
|
||||
@@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
|
||||
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
|
||||
|
||||
valide_tool_parses = ToolParserManager.tool_parsers.keys()
|
||||
if args.enable_auto_tool_choice \
|
||||
and args.tool_call_parser not in valide_tool_parses:
|
||||
raise KeyError(f"invalid tool call parser: {args.tool_call_parser} "
|
||||
f"(chose from {{ {','.join(valide_tool_parses)} }})")
|
||||
|
||||
temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
temp_socket.bind(("", args.port))
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import List, Optional, Sequence, Union
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
|
||||
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
|
||||
PromptAdapterPath)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"Enable auto tool choice for supported models. Use --tool-call-parser"
|
||||
"to specify which parser to use")
|
||||
|
||||
valid_tool_parsers = ToolParserManager.tool_parsers.keys()
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["mistral", "hermes", "llama3_json"],
|
||||
metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in "
|
||||
"--tool-parser-plugin",
|
||||
default=None,
|
||||
help=
|
||||
"Select the tool call parser depending on the model that you're using."
|
||||
" This is used to parse the model-generated tool call into OpenAI API "
|
||||
"format. Required for --enable-auto-tool-choice.")
|
||||
|
||||
parser.add_argument(
|
||||
"--tool-parser-plugin",
|
||||
type=str,
|
||||
default="",
|
||||
help=
|
||||
"Special the tool parser plugin write to parse the model-generated tool"
|
||||
" into OpenAI API format, the name register in this plugin can be used "
|
||||
"in --tool-call-parser.")
|
||||
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
parser.add_argument('--max-log-len',
|
||||
|
||||
@@ -29,10 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
|
||||
OpenAIServing,
|
||||
PromptAdapterPath,
|
||||
TextTokensPrompt)
|
||||
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
|
||||
Llama3JsonToolParser,
|
||||
MistralToolParser,
|
||||
ToolParser)
|
||||
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
|
||||
from vllm.inputs import TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
@@ -82,15 +79,13 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
|
||||
if self.enable_auto_tools:
|
||||
if tool_parser == "mistral":
|
||||
self.tool_parser = MistralToolParser
|
||||
elif tool_parser == "hermes":
|
||||
self.tool_parser = Hermes2ProToolParser
|
||||
elif tool_parser == "llama3_json":
|
||||
self.tool_parser = Llama3JsonToolParser
|
||||
else:
|
||||
try:
|
||||
self.tool_parser = ToolParserManager.get_tool_parser(
|
||||
tool_parser)
|
||||
except Exception as e:
|
||||
raise TypeError("Error: --enable-auto-tool-choice requires "
|
||||
"--tool-call-parser")
|
||||
f"tool_parser:'{tool_parser}' which has not "
|
||||
"been registered") from e
|
||||
|
||||
async def create_chat_completion(
|
||||
self,
|
||||
@@ -187,6 +182,10 @@ class OpenAIServingChat(OpenAIServing):
|
||||
raw_request.state.request_metadata = request_metadata
|
||||
|
||||
try:
|
||||
if self.enable_auto_tools and self.tool_parser:
|
||||
request = self.tool_parser(tokenizer).adjust_request(
|
||||
request=request)
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt_inputs = self._tokenize_prompt_input(
|
||||
request,
|
||||
@@ -282,11 +281,11 @@ class OpenAIServingChat(OpenAIServing):
|
||||
num_choices = 1 if request.n is None else request.n
|
||||
previous_num_tokens = [0] * num_choices
|
||||
finish_reason_sent = [False] * num_choices
|
||||
|
||||
num_prompt_tokens = 0
|
||||
|
||||
tool_parser: Optional[ToolParser] = self.tool_parser(
|
||||
tokenizer) if self.tool_parser else None
|
||||
tool_parsers: List[Optional[ToolParser]] = [
|
||||
self.tool_parser(tokenizer) if self.tool_parser else None
|
||||
] * num_choices
|
||||
|
||||
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
|
||||
tool_choice_function_name = request.tool_choice.function.name
|
||||
@@ -324,7 +323,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
# NOTE num_choices defaults to 1 so this usually executes
|
||||
# once per request
|
||||
for i in range(num_choices):
|
||||
|
||||
tool_parser = tool_parsers[i]
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(
|
||||
@@ -399,6 +398,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
tool_parser = tool_parsers[i]
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
@@ -446,7 +446,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
delta_text=delta_text,
|
||||
previous_token_ids=previous_token_ids,
|
||||
current_token_ids=current_token_ids,
|
||||
delta_token_ids=output.token_ids))
|
||||
delta_token_ids=output.token_ids,
|
||||
request=request))
|
||||
|
||||
# update the previous values for the next iteration
|
||||
previous_texts[i] = current_text
|
||||
@@ -685,7 +686,8 @@ class OpenAIServingChat(OpenAIServing):
|
||||
and self.tool_parser:
|
||||
|
||||
tool_parser = self.tool_parser(tokenizer)
|
||||
tool_call_info = tool_parser.extract_tool_calls(output.text)
|
||||
tool_call_info = tool_parser.extract_tool_calls(
|
||||
output.text, request=request)
|
||||
tools_called = tool_call_info.tools_called
|
||||
if tool_call_info.tools_called:
|
||||
message = ChatMessage(role=role,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from .abstract_tool_parser import ToolParser
|
||||
from .abstract_tool_parser import ToolParser, ToolParserManager
|
||||
from .hermes_tool_parser import Hermes2ProToolParser
|
||||
from .internlm2_tool_parser import Internlm2ToolParser
|
||||
from .llama_tool_parser import Llama3JsonToolParser
|
||||
from .mistral_tool_parser import MistralToolParser
|
||||
|
||||
__all__ = [
|
||||
"ToolParser", "Hermes2ProToolParser", "MistralToolParser",
|
||||
"Llama3JsonToolParser"
|
||||
"ToolParser", "ToolParserManager", "Hermes2ProToolParser",
|
||||
"MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser"
|
||||
]
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from typing import Dict, List, Sequence, Union
|
||||
import importlib
|
||||
import importlib.util
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaMessage,
|
||||
ExtractedToolCallInformation)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -24,8 +29,16 @@ class ToolParser:
|
||||
|
||||
self.model_tokenizer = tokenizer
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def adjust_request(
|
||||
self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
"""
|
||||
Static method that used to adjust the request parameters.
|
||||
"""
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Static method that should be implemented for extracting tool calls from
|
||||
a complete model-generated string.
|
||||
@@ -44,6 +57,7 @@ class ToolParser:
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
"""
|
||||
Instance method that should be implemented for extracting tool calls
|
||||
@@ -55,3 +69,86 @@ class ToolParser:
|
||||
raise NotImplementedError(
|
||||
"AbstractToolParser.extract_tool_calls_streaming has not been "
|
||||
"implemented!")
|
||||
|
||||
|
||||
class ToolParserManager:
    """Registry that maps tool-parser names to ``ToolParser`` subclasses.

    Parsers are added with :meth:`register_module` (as a decorator or a plain
    call) and looked up with :meth:`get_tool_parser`. User-supplied parsers
    living in a standalone file are loaded with :meth:`import_tool_parser`.
    """

    # name -> registered parser class. This is a class attribute, so the
    # registry is shared by everything that uses ToolParserManager.
    tool_parsers: Dict[str, Type] = {}

    @classmethod
    def get_tool_parser(cls, name) -> Type:
        """
        Get tool parser by name which is registered by `register_module`.

        Raise a KeyError exception if the name is not registered.
        """
        if name in cls.tool_parsers:
            return cls.tool_parsers[name]

        raise KeyError(f"tool helper: '{name}' not found in tool_parsers")

    @classmethod
    def _register_module(cls,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = True) -> None:
        """Store ``module`` in the registry under one or more names.

        Args:
            module: The class to register; must be a ``ToolParser`` subclass.
            module_name: A name, a list of names, or ``None`` to use
                ``module.__name__``.
            force: When ``False``, refuse to replace an existing entry.

        Raises:
            TypeError: ``module`` is not a ``ToolParser`` subclass.
            KeyError: ``force`` is ``False`` and a name is already registered.
        """
        # Check isinstance first: issubclass() raises its own, less helpful
        # TypeError when handed something that is not a class at all.
        if not (isinstance(module, type) and issubclass(module, ToolParser)):
            # Report the object itself; type(module) is just <class 'type'>
            # for any class and would not identify the offender.
            raise TypeError(
                f'module must be subclass of ToolParser, but got {module!r}')
        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in cls.tool_parsers:
                existed_module = cls.tool_parsers[name]
                raise KeyError(f'{name} is already registered '
                               f'at {existed_module.__module__}')
            cls.tool_parsers[name] = module

    @classmethod
    def register_module(
            cls,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = True,
            module: Union[Type, None] = None) -> Union[type, Callable]:
        """
        Register module with the given name or name list. It can be used as a
        decorator (with module as None) or as a normal function (with module
        not None).

        Returns the registered class when ``module`` is given, otherwise a
        decorator that registers the class it is applied to and returns it
        unchanged.

        Raises a TypeError if ``force`` is not a bool or if ``name`` is not
        None, a str, or a list of str.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # raise the error ahead of time
        if not (name is None or isinstance(name, str)
                or is_list_of(name, str)):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            cls._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            cls._register_module(module=module, module_name=name, force=force)
            return module

        return _register

    @classmethod
    def import_tool_parser(cls, plugin_path: str) -> None:
        """
        Import a user defined tool parser by the path of the file that
        defines it.

        The file is executed as a module named after its basename (without
        extension). Nothing is registered here directly; the plugin is
        expected to register its parsers itself while it is being executed,
        e.g. through ``register_module``.

        If no import spec/loader can be built for the path, an error is logged
        and the method returns without raising. Exceptions raised while
        executing the plugin propagate to the caller.
        """
        import sys

        module_name = os.path.splitext(os.path.basename(plugin_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, plugin_path)
        if spec is None or spec.loader is None:
            logger.error("load %s from %s failed.", module_name, plugin_path)
            return
        module = importlib.util.module_from_spec(spec)

        # Make the module discoverable through sys.modules while it executes,
        # as the importlib recipe for importing a source file does. An
        # already-imported module of the same name is never replaced.
        registered = module_name not in sys.modules
        if registered:
            sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Do not leave a half-initialised module behind.
            if registered:
                sys.modules.pop(module_name, None)
            raise
|
||||
|
||||
@@ -5,12 +5,13 @@ from typing import Dict, List, Sequence, Union
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
@@ -20,6 +21,7 @@ from vllm.utils import random_uuid
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module("hermes")
|
||||
class Hermes2ProToolParser(ToolParser):
|
||||
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
@@ -57,8 +59,11 @@ class Hermes2ProToolParser(ToolParser):
|
||||
"Hermes 2 Pro Tool parser could not locate tool call start/end "
|
||||
"tokens in the tokenizer!")
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
|
||||
# sanity check; avoid unnecessary processing
|
||||
if self.tool_call_start_token not in model_output:
|
||||
@@ -114,6 +119,7 @@ class Hermes2ProToolParser(ToolParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
logger.debug("delta_text: %s", delta_text)
|
||||
|
||||
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
208
vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import json
|
||||
from typing import Dict, Sequence, Union
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.utils import random_uuid
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ToolParserManager.register_module(["internlm"])
class Internlm2ToolParser(ToolParser):
    """Tool-call parser registered under the name ``internlm``.

    The model output is expected to carry at most one tool call, written as a
    JSON object between ``<|action_start|><|plugin|>`` and ``<|action_end|>``.
    The object has a ``name`` and either ``parameters`` or ``arguments``.

    NOTE(review): the streaming state read here (``current_tool_id``,
    ``current_tool_name_sent``, ``prev_tool_call_arr``,
    ``streamed_args_for_tool``) is not initialised in this class; presumably
    ``ToolParser.__init__`` sets it up - confirm.
    """

    def __init__(self, tokenizer: AnyTokenizer):
        super().__init__(tokenizer)
        # Offset into the accumulated text up to which plain content has
        # already been emitted to the client.
        self.position = 0

    def adjust_request(
            self, request: ChatCompletionRequest) -> ChatCompletionRequest:
        """Keep special tokens in the output when tools may be called.

        Mutates and returns ``request``.
        """
        if request.tools and request.tool_choice != 'none':
            # do not skip special tokens because internlm uses the special
            # tokens to indicate the start and end of the tool call
            # information.
            request.skip_special_tokens = False
        return request

    def get_argments(self, obj):
        """Return ``obj["parameters"]``, else ``obj["arguments"]``, else None.

        The (misspelled) method name is kept as-is for existing callers.
        """
        if "parameters" in obj:
            return obj.get("parameters")
        elif "arguments" in obj:
            return obj.get("arguments")
        return None

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Turn the newest chunk of streamed output into a delta message.

        Returns plain content while no tool call has started, tool-call deltas
        (name first, then argument fragments) once the call is parseable, and
        ``None`` when there is nothing to send for this chunk or when handling
        the chunk failed (the error is logged).
        """
        if '<|action_start|>' not in current_text:
            self.position = len(current_text)
            return DeltaMessage(content=delta_text)
        # if the tool call has been sent, return an empty delta message
        # to make sure the finish_reason will be sent correctly.
        # NOTE(review): the counter is bumped by one when the name is sent;
        # whether this guard fires after the first call depends on its
        # initial value, which is set outside this class - confirm.
        if self.current_tool_id > 0:
            return DeltaMessage(content='')

        last_pos = self.position
        if '<|action_start|><|plugin|>' not in current_text[last_pos:]:
            return None

        new_delta = current_text[last_pos:]
        # maxsplit=1: a repeated marker must not break the two-name unpacking
        # (this statement is outside the try block below).
        text, action = new_delta.split('<|action_start|><|plugin|>', 1)

        # Flush any ordinary text that precedes the tool call first.
        if len(text) > 0:
            self.position = self.position + len(text)
            return DeltaMessage(content=text)

        action = action.strip()
        action = action.split('<|action_end|>')[0]

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent \
            else Allow.ALL & ~Allow.STR

        try:
            parsable_arr = action

            # tool calls are generated as a single object in internlm2;
            # parallel tool calls are not supported
            try:
                tool_call_arr: Dict = partial_json_parser.loads(
                    parsable_arr, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            if not self.current_tool_name_sent:
                function_name = tool_call_arr.get("name")
                if function_name:
                    self.current_tool_id = self.current_tool_id + 1
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      type="function",
                                      id=f"chatcmpl-tool-{random_uuid()}",
                                      function=DeltaFunctionCall(
                                          name=function_name).model_dump(
                                              exclude_none=True))
                    ])
                    self.current_tool_name_sent = True
                    self.streamed_args_for_tool.append("")
                else:
                    delta = None
            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                prev_arguments = self.get_argments(
                    self.prev_tool_call_arr[self.current_tool_id])
                cur_arguments = self.get_argments(tool_call_arr)

                # no arguments generated yet
                if not cur_arguments and not prev_arguments:
                    delta = None
                # will never happen
                elif not cur_arguments and prev_arguments:
                    logger.error(
                        "INVARIANT - impossible to have arguments reset "
                        "mid-arguments")
                    delta = None
                # first time arguments are available
                elif cur_arguments and not prev_arguments:
                    cur_arguments_json = json.dumps(cur_arguments)

                    # Send the serialised arguments up to and including the
                    # newest chunk of text.
                    arguments_delta = cur_arguments_json[:cur_arguments_json.
                                                         index(delta_text) +
                                                         len(delta_text)]
                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=arguments_delta).
                                      model_dump(exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += arguments_delta
                # both previous and current arguments: send only the increment
                elif cur_arguments and prev_arguments:
                    cur_args_json = json.dumps(cur_arguments)
                    prev_args_json = json.dumps(prev_arguments)

                    argument_diff = extract_intermediate_diff(
                        cur_args_json, prev_args_json)

                    delta = DeltaMessage(tool_calls=[
                        DeltaToolCall(index=self.current_tool_id,
                                      function=DeltaFunctionCall(
                                          arguments=argument_diff).model_dump(
                                              exclude_none=True))
                    ])
                    self.streamed_args_for_tool[
                        self.current_tool_id] += argument_diff

            # Remember this parse (with its arguments normalised under the
            # "arguments" key) so the next chunk can be diffed against it.
            tool_call_arr["arguments"] = self.get_argments(tool_call_arr)
            self.prev_tool_call_arr = [tool_call_arr]
            return delta
        except Exception as e:
            logger.error("Error trying to handle streaming tool call: %s", e)
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction "
                "error")
            return None

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract the tool call from a complete model response.

        Returns ``tools_called=True`` with a single tool call when the output
        contains an action block naming one of ``request.tools``. Otherwise
        returns ``tools_called=False`` with the text that precedes the action
        block (or the whole output when there is no block) as content.

        Raises:
            json.JSONDecodeError: the action block is not valid JSON.
            KeyError: the action object has no ``name``.
        """
        text = model_output
        tools = request.tools
        if '<|action_start|><|plugin|>' in text:
            # maxsplit=1 keeps the two-name unpacking valid even if the
            # marker shows up more than once; only the first block is used.
            text, action = text.split('<|action_start|><|plugin|>', 1)
            action = action.split('<|action_end|>')[0]
            action = action[action.find('{'):]
            action_dict = json.loads(action)
            name, parameters = action_dict['name'], json.dumps(
                action_dict.get('parameters', action_dict.get('arguments',
                                                              {})))

            # The model named a tool the request never offered: report that
            # no tool was called instead of forwarding a bogus call.
            if not tools or name not in [t.function.name for t in tools]:
                return ExtractedToolCallInformation(tools_called=False,
                                                    tool_calls=[],
                                                    content=text)

            tool_calls = [
                ToolCall(
                    function=FunctionCall(name=name, arguments=parameters))
            ]
            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=text if len(text) > 0 else None)

        return ExtractedToolCallInformation(tools_called=False,
                                            tool_calls=[],
                                            content=text)
|
||||
@@ -7,12 +7,13 @@ import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import random_uuid
|
||||
@@ -41,6 +42,7 @@ def is_complete_json(input_str):
|
||||
return False
|
||||
|
||||
|
||||
@ToolParserManager.register_module("llama3_json")
|
||||
class Llama3JsonToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Llama 3.1 models intended for use with the
|
||||
@@ -64,8 +66,9 @@ class Llama3JsonToolParser(ToolParser):
|
||||
add_special_tokens=False)[0]
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def extract_tool_calls(
|
||||
self, model_output: str,
|
||||
request: ChatCompletionRequest) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response.
|
||||
"""
|
||||
@@ -125,6 +128,7 @@ class Llama3JsonToolParser(ToolParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
if not (current_text.startswith(self.bot_token)
|
||||
|
||||
@@ -8,12 +8,13 @@ import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import Field
|
||||
|
||||
from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage,
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
DeltaFunctionCall, DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall, ToolCall)
|
||||
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
|
||||
ToolParser)
|
||||
ToolParser, ToolParserManager)
|
||||
from vllm.entrypoints.openai.tool_parsers.utils import (
|
||||
extract_intermediate_diff)
|
||||
from vllm.logger import init_logger
|
||||
@@ -36,6 +37,7 @@ class MistralToolCall(ToolCall):
|
||||
return "".join(choices(ALPHANUMERIC, k=9))
|
||||
|
||||
|
||||
@ToolParserManager.register_module("mistral")
|
||||
class MistralToolParser(ToolParser):
|
||||
"""
|
||||
Tool call parser for Mistral 7B Instruct v0.3, intended for use with the
|
||||
@@ -47,9 +49,7 @@ class MistralToolParser(ToolParser):
|
||||
def __init__(self, tokenizer: AnyTokenizer):
|
||||
super().__init__(tokenizer)
|
||||
|
||||
if isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
self.model_tokenizer = self.model_tokenizer.tokenizer
|
||||
else:
|
||||
if not isinstance(self.model_tokenizer, MistralTokenizer):
|
||||
logger.info("Non-Mistral tokenizer detected when using a Mistral "
|
||||
"model...")
|
||||
|
||||
@@ -61,11 +61,14 @@ class MistralToolParser(ToolParser):
|
||||
self.streamed_args_for_tool: List[str] = [
|
||||
] # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = "[TOOL_CALLS]"
|
||||
self.bot_token_id = self.model_tokenizer.vocab[self.bot_token]
|
||||
self.bot_token_id = self.model_tokenizer.get_vocab()[self.bot_token]
|
||||
self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL)
|
||||
|
||||
def extract_tool_calls(self,
|
||||
model_output: str) -> ExtractedToolCallInformation:
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
"""
|
||||
Extract the tool calls from a complete model response. Requires
|
||||
find-and-replacing single quotes with double quotes for JSON parsing,
|
||||
@@ -119,6 +122,7 @@ class MistralToolParser(ToolParser):
|
||||
previous_token_ids: Sequence[int],
|
||||
current_token_ids: Sequence[int],
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> Union[DeltaMessage, None]:
|
||||
|
||||
# if the tool call token is not in the tokens generated so far, append
|
||||
|
||||
Reference in New Issue
Block a user