Add gigachat 3.1 tool parser + fix gigachat3 tool parser (#36664)
Signed-off-by: Viacheslav Barinov <viacheslav.teh@gmail.com>
This commit is contained in:
@@ -13,6 +13,13 @@ from vllm.entrypoints.openai.engine.protocol import FunctionCall
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers import ToolParser, ToolParserManager
|
||||
|
||||
MSG_SEP_TOKEN = "<|message_sep|>\n\n"
|
||||
ROLE_SEP_TOKEN = "<|role_sep|>\n"
|
||||
EOS_TOKEN = "</s>"
|
||||
TOOL_HEADER_GIGACHAT3 = f"function call{ROLE_SEP_TOKEN}"
|
||||
TOOL_HEADER_GIGACHAT31 = "<|function_call|>"
|
||||
|
||||
|
||||
SIMPLE_ARGS_DICT = {
|
||||
"action": "create",
|
||||
"id": "preferences",
|
||||
@@ -24,7 +31,10 @@ SIMPLE_FUNCTION_JSON = json.dumps(
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
SIMPLE_FUNCTION_OUTPUT = "function call" + SIMPLE_FUNCTION_JSON
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT3 = (
|
||||
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{SIMPLE_FUNCTION_JSON}"
|
||||
)
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{SIMPLE_FUNCTION_JSON}"
|
||||
SIMPLE_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps(SIMPLE_ARGS_DICT, ensure_ascii=False),
|
||||
@@ -38,7 +48,12 @@ PARAMETERLESS_FUNCTION_JSON = json.dumps(
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT = "function call" + PARAMETERLESS_FUNCTION_JSON
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3 = (
|
||||
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{PARAMETERLESS_FUNCTION_JSON}"
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31 = (
|
||||
f"{TOOL_HEADER_GIGACHAT31}{PARAMETERLESS_FUNCTION_JSON}"
|
||||
)
|
||||
PARAMETERLESS_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps({}, ensure_ascii=False),
|
||||
@@ -62,17 +77,38 @@ COMPLEX_FUNCTION_JSON = json.dumps(
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
COMPLEX_FUNCTION_OUTPUT = "function call" + COMPLEX_FUNCTION_JSON
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT3 = (
|
||||
f"{MSG_SEP_TOKEN}{TOOL_HEADER_GIGACHAT3}{COMPLEX_FUNCTION_JSON}"
|
||||
)
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31 = f"{TOOL_HEADER_GIGACHAT31}{COMPLEX_FUNCTION_JSON}"
|
||||
COMPLEX_FUNCTION_CALL = FunctionCall(
|
||||
name="manage_user_memory",
|
||||
arguments=json.dumps(COMPLEX_ARGS_DICT, ensure_ascii=False),
|
||||
)
|
||||
|
||||
|
||||
CONTENT_TEXT = "I'll check that for you."
|
||||
MIXED_OUTPUT_GIGACHAT3 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT3}"
|
||||
MIXED_OUTPUT_GIGACHAT31 = f"{CONTENT_TEXT}{SIMPLE_FUNCTION_OUTPUT_GIGACHAT31}"
|
||||
|
||||
|
||||
@pytest.fixture(name="gigachat_tokenizer")
|
||||
def fixture_gigachat_tokenizer(default_tokenizer: TokenizerLike):
|
||||
default_tokenizer.add_tokens(
|
||||
[
|
||||
MSG_SEP_TOKEN,
|
||||
ROLE_SEP_TOKEN,
|
||||
TOOL_HEADER_GIGACHAT31,
|
||||
EOS_TOKEN,
|
||||
]
|
||||
)
|
||||
return default_tokenizer
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming", [True, False])
|
||||
def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
def test_no_tool_call(streaming: bool, gigachat_tokenizer: TokenizerLike):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
default_tokenizer
|
||||
gigachat_tokenizer
|
||||
)
|
||||
model_output = "How can I help you today?"
|
||||
content, tool_calls = run_tool_extraction(
|
||||
@@ -85,45 +121,143 @@ def test_no_tool_call(streaming: bool, default_tokenizer: TokenizerLike):
|
||||
TEST_CASES = [
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_FUNCTION_OUTPUT,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_streaming",
|
||||
id="simple_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_FUNCTION_OUTPUT,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_nonstreaming",
|
||||
id="simple_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_streaming",
|
||||
id="parameterless_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_nonstreaming",
|
||||
id="parameterless_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLEX_FUNCTION_OUTPUT,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_streaming",
|
||||
id="complex_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLEX_FUNCTION_OUTPUT,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT3,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_nonstreaming",
|
||||
id="complex_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT3,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_with_eos_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT3 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_with_eos_gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
SIMPLE_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
None,
|
||||
id="simple_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
PARAMETERLESS_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[PARAMETERLESS_FUNCTION_CALL],
|
||||
None,
|
||||
id="parameterless_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
COMPLEX_FUNCTION_OUTPUT_GIGACHAT31,
|
||||
[COMPLEX_FUNCTION_CALL],
|
||||
None,
|
||||
id="complex_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT31,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
True,
|
||||
MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_streaming_with_eos_gigachat31",
|
||||
),
|
||||
pytest.param(
|
||||
False,
|
||||
MIXED_OUTPUT_GIGACHAT31 + EOS_TOKEN,
|
||||
[SIMPLE_FUNCTION_CALL],
|
||||
CONTENT_TEXT,
|
||||
id="mixed_content_nonstreaming_with_eos_gigachat31",
|
||||
),
|
||||
]
|
||||
|
||||
@@ -136,14 +270,16 @@ def test_tool_call(
|
||||
model_output: str,
|
||||
expected_tool_calls: list[FunctionCall],
|
||||
expected_content: str | None,
|
||||
default_tokenizer: TokenizerLike,
|
||||
gigachat_tokenizer: TokenizerLike,
|
||||
):
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
default_tokenizer
|
||||
gigachat_tokenizer
|
||||
)
|
||||
content, tool_calls = run_tool_extraction(
|
||||
tool_parser, model_output, streaming=streaming
|
||||
)
|
||||
if content == "":
|
||||
content = None
|
||||
assert content == expected_content
|
||||
assert len(tool_calls) == len(expected_tool_calls)
|
||||
for actual, expected in zip(tool_calls, expected_tool_calls):
|
||||
@@ -154,15 +290,46 @@ def test_tool_call(
|
||||
assert actual_args == expected_args
|
||||
|
||||
|
||||
def test_streaming_tool_call_with_large_steps(default_tokenizer: TokenizerLike):
|
||||
@pytest.mark.parametrize(
|
||||
"model_output_deltas",
|
||||
[
|
||||
pytest.param(
|
||||
[
|
||||
CONTENT_TEXT[:3],
|
||||
CONTENT_TEXT[3:5],
|
||||
CONTENT_TEXT[5:],
|
||||
MSG_SEP_TOKEN,
|
||||
TOOL_HEADER_GIGACHAT3,
|
||||
COMPLEX_FUNCTION_JSON[:40],
|
||||
COMPLEX_FUNCTION_JSON[40:-1],
|
||||
COMPLEX_FUNCTION_JSON[-1],
|
||||
],
|
||||
id="gigachat3",
|
||||
),
|
||||
pytest.param(
|
||||
[
|
||||
CONTENT_TEXT[:3],
|
||||
CONTENT_TEXT[3:5],
|
||||
CONTENT_TEXT[5:],
|
||||
TOOL_HEADER_GIGACHAT31,
|
||||
COMPLEX_FUNCTION_JSON[:40],
|
||||
COMPLEX_FUNCTION_JSON[40:-1],
|
||||
COMPLEX_FUNCTION_JSON[-1],
|
||||
],
|
||||
id="gigachat31",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_streaming_tool_call_with_large_steps(
|
||||
model_output_deltas: list[str],
|
||||
gigachat_tokenizer: TokenizerLike,
|
||||
):
|
||||
"""
|
||||
Test that the closing braces are streamed correctly.
|
||||
"""
|
||||
tool_parser: ToolParser = ToolParserManager.get_tool_parser("gigachat3")(
|
||||
default_tokenizer
|
||||
gigachat_tokenizer
|
||||
)
|
||||
model_output_deltas = [
|
||||
"function call",
|
||||
COMPLEX_FUNCTION_JSON[:40],
|
||||
COMPLEX_FUNCTION_JSON[40:],
|
||||
]
|
||||
reconstructor = run_tool_extraction_streaming(
|
||||
tool_parser,
|
||||
model_output_deltas,
|
||||
|
||||
@@ -25,7 +25,12 @@ from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
logger = init_logger(__name__)
|
||||
|
||||
REGEX_FUNCTION_CALL = re.compile(
|
||||
r"function call(?:<\|role_sep\|>\n)?(\{.*)",
|
||||
r"(?:function call<\|role_sep\|>\n|<\|function_call\|>)(.*)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
REGEX_CONTENT_PATTERN = re.compile(
|
||||
r"^(.*?)(?:<\|message_sep\|>|<\|function_call\|>)",
|
||||
re.DOTALL,
|
||||
)
|
||||
|
||||
@@ -47,57 +52,67 @@ class GigaChat3ToolParser(ToolParser):
|
||||
self.tool_name_sent: bool = False
|
||||
self.tool_id: str | None = None
|
||||
self.prev_tool_call_arr: list[dict] = []
|
||||
self.content_buffer: str = ""
|
||||
self.trigger_start = "function call{"
|
||||
self.end_content: bool = False
|
||||
self.streamed_args_for_tool: list[str] = []
|
||||
|
||||
def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
|
||||
request = super().adjust_request(request)
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
return request
|
||||
|
||||
def extract_tool_calls(
|
||||
self,
|
||||
model_output: str,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ExtractedToolCallInformation:
|
||||
match = REGEX_FUNCTION_CALL.search(model_output)
|
||||
if not match:
|
||||
function_call = None
|
||||
content = None
|
||||
if model_output.rstrip().endswith("</s>"):
|
||||
model_output = model_output[: model_output.rfind("</s>")]
|
||||
m_func = REGEX_FUNCTION_CALL.search(model_output)
|
||||
if m_func:
|
||||
try:
|
||||
function_call = json.loads(m_func.group(1), strict=False)
|
||||
if (
|
||||
isinstance(function_call, dict)
|
||||
and "name" in function_call
|
||||
and "arguments" in function_call
|
||||
):
|
||||
if not isinstance(function_call["arguments"], dict):
|
||||
function_call = None
|
||||
else:
|
||||
function_call = None
|
||||
except json.JSONDecodeError:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
m_content = REGEX_CONTENT_PATTERN.search(model_output)
|
||||
content = m_content.group(1) if m_content else model_output
|
||||
if not function_call:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
content=content if content else None,
|
||||
)
|
||||
json_candidate = match.group(1).strip()
|
||||
try:
|
||||
data = json.loads(json_candidate)
|
||||
except json.JSONDecodeError:
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
if not (isinstance(data, dict) and "name" in data and "arguments" in data):
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=False,
|
||||
tool_calls=[],
|
||||
content=model_output,
|
||||
)
|
||||
name = data["name"]
|
||||
args = data["arguments"]
|
||||
name = function_call["name"]
|
||||
args = function_call["arguments"]
|
||||
if not isinstance(args, str):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=name,
|
||||
arguments=args,
|
||||
),
|
||||
)
|
||||
]
|
||||
prefix = model_output[: match.start()]
|
||||
content = prefix.rstrip() if prefix and prefix.strip() else None
|
||||
|
||||
args = json.dumps(function_call["arguments"], ensure_ascii=False)
|
||||
return ExtractedToolCallInformation(
|
||||
tools_called=True,
|
||||
tool_calls=tool_calls,
|
||||
content=content,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=name,
|
||||
arguments=args,
|
||||
),
|
||||
)
|
||||
],
|
||||
content=content if content else None,
|
||||
)
|
||||
|
||||
def extract_tool_calls_streaming(
|
||||
@@ -110,39 +125,37 @@ class GigaChat3ToolParser(ToolParser):
|
||||
delta_token_ids: Sequence[int],
|
||||
request: ChatCompletionRequest,
|
||||
) -> DeltaMessage | None:
|
||||
content = None
|
||||
func_name = None
|
||||
cur_args = None
|
||||
m_func = REGEX_FUNCTION_CALL.search(current_text)
|
||||
if not self.tool_started:
|
||||
match = REGEX_FUNCTION_CALL.search(current_text)
|
||||
if match:
|
||||
self.tool_started = True
|
||||
self.content_buffer = ""
|
||||
m_content = REGEX_CONTENT_PATTERN.search(delta_text)
|
||||
if m_content:
|
||||
content = m_content.group(1)
|
||||
self.end_content = True
|
||||
else:
|
||||
self.content_buffer += delta_text
|
||||
clean_buffer = self.content_buffer.lstrip()
|
||||
is_prefix = self.trigger_start.startswith(clean_buffer)
|
||||
starts_with_trigger = clean_buffer.startswith(self.trigger_start)
|
||||
if is_prefix or starts_with_trigger:
|
||||
return None
|
||||
else:
|
||||
flush_text = self.content_buffer
|
||||
self.content_buffer = ""
|
||||
return DeltaMessage(content=flush_text)
|
||||
|
||||
match = REGEX_FUNCTION_CALL.search(current_text)
|
||||
if not match:
|
||||
if not self.end_content:
|
||||
content = delta_text
|
||||
if m_func:
|
||||
self.tool_started = True
|
||||
if content:
|
||||
return DeltaMessage(content=content)
|
||||
if not m_func:
|
||||
return None
|
||||
json_tail = match.group(1).strip()
|
||||
json_tail = m_func.group(1).strip()
|
||||
name_match = NAME_REGEX.search(json_tail)
|
||||
if name_match:
|
||||
func_name = name_match.group(1)
|
||||
args_match = ARGS_REGEX.search(json_tail)
|
||||
if args_match:
|
||||
cur_args = args_match.group(1).strip()
|
||||
if cur_args.endswith("</s>"):
|
||||
cur_args = cur_args[: -len("</s>")]
|
||||
if cur_args.endswith("}"): # last '}' end of json
|
||||
try:
|
||||
candidate = cur_args[:-1].strip()
|
||||
json.loads(candidate)
|
||||
json.loads(candidate, strict=False)
|
||||
cur_args = candidate
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
@@ -165,11 +178,10 @@ class GigaChat3ToolParser(ToolParser):
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
if cur_args is None:
|
||||
return None
|
||||
prev_args = self.prev_tool_call_arr[0].get("arguments", "")
|
||||
prev_args = self.prev_tool_call_arr[0].get("arguments_str", "")
|
||||
if not prev_args:
|
||||
delta_args = cur_args
|
||||
elif cur_args.startswith(prev_args):
|
||||
@@ -178,7 +190,15 @@ class GigaChat3ToolParser(ToolParser):
|
||||
return None
|
||||
if not delta_args:
|
||||
return None
|
||||
self.prev_tool_call_arr[0]["arguments"] = cur_args
|
||||
self.prev_tool_call_arr[0]["arguments_str"] = cur_args
|
||||
try:
|
||||
args_dict = json.loads(cur_args, strict=False)
|
||||
self.prev_tool_call_arr[0]["arguments"] = args_dict
|
||||
except json.JSONDecodeError:
|
||||
self.prev_tool_call_arr[0]["arguments"] = {}
|
||||
if len(self.streamed_args_for_tool) <= 0:
|
||||
self.streamed_args_for_tool.append("")
|
||||
self.streamed_args_for_tool[0] = cur_args
|
||||
return DeltaMessage(
|
||||
tool_calls=[
|
||||
DeltaToolCall(
|
||||
@@ -188,5 +208,4 @@ class GigaChat3ToolParser(ToolParser):
|
||||
).model_dump(exclude_none=True),
|
||||
)
|
||||
],
|
||||
content=None,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user