[Frontend][Feature] support tool calling for internlm/internlm2_5-7b-chat model (#8405)

This commit is contained in:
代君
2024-10-04 10:36:39 +08:00
committed by GitHub
parent 2838d6b38e
commit 3dbb215b38
13 changed files with 533 additions and 46 deletions

View File

@@ -29,10 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
OpenAIServing,
PromptAdapterPath,
TextTokensPrompt)
from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser,
Llama3JsonToolParser,
MistralToolParser,
ToolParser)
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.inputs import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
@@ -82,15 +79,13 @@ class OpenAIServingChat(OpenAIServing):
self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None
if self.enable_auto_tools:
if tool_parser == "mistral":
self.tool_parser = MistralToolParser
elif tool_parser == "hermes":
self.tool_parser = Hermes2ProToolParser
elif tool_parser == "llama3_json":
self.tool_parser = Llama3JsonToolParser
else:
try:
self.tool_parser = ToolParserManager.get_tool_parser(
tool_parser)
except Exception as e:
raise TypeError("Error: --enable-auto-tool-choice requires "
"--tool-call-parser")
f"tool_parser:'{tool_parser}' which has not "
"been registered") from e
async def create_chat_completion(
self,
@@ -187,6 +182,10 @@ class OpenAIServingChat(OpenAIServing):
raw_request.state.request_metadata = request_metadata
try:
if self.enable_auto_tools and self.tool_parser:
request = self.tool_parser(tokenizer).adjust_request(
request=request)
if isinstance(prompt, str):
prompt_inputs = self._tokenize_prompt_input(
request,
@@ -282,11 +281,11 @@ class OpenAIServingChat(OpenAIServing):
num_choices = 1 if request.n is None else request.n
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
num_prompt_tokens = 0
tool_parser: Optional[ToolParser] = self.tool_parser(
tokenizer) if self.tool_parser else None
tool_parsers: List[Optional[ToolParser]] = [
self.tool_parser(tokenizer) if self.tool_parser else None
] * num_choices
if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam):
tool_choice_function_name = request.tool_choice.function.name
@@ -324,7 +323,7 @@ class OpenAIServingChat(OpenAIServing):
# NOTE num_choices defaults to 1 so this usually executes
# once per request
for i in range(num_choices):
tool_parser = tool_parsers[i]
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
@@ -399,6 +398,7 @@ class OpenAIServingChat(OpenAIServing):
for output in res.outputs:
i = output.index
tool_parser = tool_parsers[i]
if finish_reason_sent[i]:
continue
@@ -446,7 +446,8 @@ class OpenAIServingChat(OpenAIServing):
delta_text=delta_text,
previous_token_ids=previous_token_ids,
current_token_ids=current_token_ids,
delta_token_ids=output.token_ids))
delta_token_ids=output.token_ids,
request=request))
# update the previous values for the next iteration
previous_texts[i] = current_text
@@ -685,7 +686,8 @@ class OpenAIServingChat(OpenAIServing):
and self.tool_parser:
tool_parser = self.tool_parser(tokenizer)
tool_call_info = tool_parser.extract_tool_calls(output.text)
tool_call_info = tool_parser.extract_tool_calls(
output.text, request=request)
tools_called = tool_call_info.tools_called
if tool_call_info.tools_called:
message = ChatMessage(role=role,