379 lines
14 KiB
Python
379 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from types import NoneType
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from tests.tool_parsers.utils import run_tool_extraction
|
|
from vllm.tokenizers import TokenizerLike
|
|
from vllm.tool_parsers import ToolParserManager
|
|
|
|
|
|
@dataclass
|
|
class ToolParserTestConfig:
|
|
"""Configuration for a tool parser's common tests.
|
|
|
|
This dataclass contains all the test data and expected results needed
|
|
to run the common test suite for a parser. Each parser test file
|
|
creates one instance of this config with parser-specific values.
|
|
|
|
Attributes:
|
|
parser_name: Name used with ToolParserManager (e.g., "mistral")
|
|
|
|
Test data (model outputs):
|
|
no_tool_calls_output: Plain text without any tool syntax
|
|
single_tool_call_output: One tool call with simple arguments
|
|
parallel_tool_calls_output: Multiple tool calls in one response
|
|
various_data_types_output: Tool with various data types
|
|
empty_arguments_output: Tool call with no parameters
|
|
surrounding_text_output: Tool call mixed with regular text
|
|
escaped_strings_output: Tool call with escaped chars
|
|
malformed_input_outputs: List of invalid inputs
|
|
|
|
Expected results:
|
|
single_tool_call_expected_name: Expected function name
|
|
single_tool_call_expected_args: Expected arguments dict
|
|
parallel_tool_calls_count: Number of tools in parallel test
|
|
parallel_tool_calls_names: Function names in order
|
|
single_tool_call_expected_content: Content field when tool called
|
|
parallel_tool_calls_expected_content: Content for parallel test
|
|
|
|
xfail markers:
|
|
xfail_streaming: Mapping test name to xfail reason (streaming only)
|
|
xfail_nonstreaming: Mapping test name to xfail reason (non-streaming)
|
|
|
|
Special flags:
|
|
allow_empty_or_json_empty_args: True if "" or "{}" both valid for empty args
|
|
supports_typed_arguments: True if the parser supports typed function arguments
|
|
"""
|
|
|
|
# Parser identification
|
|
parser_name: str
|
|
|
|
# Test data - model outputs for each common test
|
|
no_tool_calls_output: str
|
|
single_tool_call_output: str
|
|
parallel_tool_calls_output: str
|
|
various_data_types_output: str
|
|
empty_arguments_output: str
|
|
surrounding_text_output: str
|
|
escaped_strings_output: str
|
|
malformed_input_outputs: list[str]
|
|
|
|
# Expected results for specific tests (optional overrides)
|
|
single_tool_call_expected_name: str = "get_weather"
|
|
single_tool_call_expected_args: dict[str, Any] = field(
|
|
default_factory=lambda: {"city": "Tokyo"}
|
|
)
|
|
parallel_tool_calls_count: int = 2
|
|
parallel_tool_calls_names: list[str] = field(
|
|
default_factory=lambda: ["get_weather", "get_time"]
|
|
)
|
|
|
|
# xfail configuration - maps test name to xfail reason
|
|
xfail_streaming: dict[str, str] = field(default_factory=dict)
|
|
xfail_nonstreaming: dict[str, str] = field(default_factory=dict)
|
|
|
|
# Content expectations (some parsers strip content, others don't)
|
|
single_tool_call_expected_content: str | None = None
|
|
parallel_tool_calls_expected_content: str | None = None
|
|
|
|
# Special assertions for edge cases
|
|
allow_empty_or_json_empty_args: bool = True # "{}" or "" for empty args
|
|
supports_typed_arguments: bool = True
|
|
|
|
|
|
class ToolParserTests:
    """Mixin class providing common test suite for tool parsers.

    To use this mixin in a parser test file:

    1. Create a test_config fixture that returns a ToolParserTestConfig instance
    2. Inherit from this class
    3. Add parser-specific tests as additional methods

    Every common test except test_streaming_reconstruction is parametrized
    over the ``streaming`` fixture, so it runs once in streaming mode and
    once in non-streaming mode.

    Example:
        class TestMistralToolParser(ToolParserTests):
            @pytest.fixture
            def test_config(self) -> ToolParserTestConfig:
                return ToolParserTestConfig(
                    parser_name="mistral",
                    no_tool_calls_output="Plain text...",
                    # ... other config ...
                )

            # Parser-specific tests
            def test_mistral_specific_feature(self, tool_parser):
                # Custom test logic
                pass
    """

    @pytest.fixture
    def test_config(self) -> ToolParserTestConfig:
        """Override this to provide parser-specific configuration."""
        raise NotImplementedError(
            "Subclass must provide test_config fixture returning ToolParserTestConfig"
        )

    @pytest.fixture
    def tokenizer(self, default_tokenizer: TokenizerLike) -> TokenizerLike:
        """Override this to provide parser-specific tokenizer."""
        return default_tokenizer

    @pytest.fixture
    def tool_parser(self, test_config: ToolParserTestConfig, tokenizer: TokenizerLike):
        """Build the parser under test.

        Looks up ``test_config.parser_name`` through ToolParserManager and
        calls the result with the tokenizer fixture.
        """
        return ToolParserManager.get_tool_parser(test_config.parser_name)(tokenizer)

    @pytest.fixture(params=[True, False])
    def streaming(self, request: pytest.FixtureRequest) -> bool:
        """Parametrize a test over streaming (True) and non-streaming (False)."""
        return request.param

    def test_no_tool_calls(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles plain text without tool syntax."""
        # Apply xfail markers if configured
        test_name = "test_no_tool_calls"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser, test_config.no_tool_calls_output, streaming=streaming
        )
        assert content == test_config.no_tool_calls_output, (
            f"Expected content to match input, got {content}"
        )
        assert len(tool_calls) == 0, f"Expected no tool calls, got {len(tool_calls)}"

    def test_single_tool_call_simple_args(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser extracts one tool with simple arguments."""
        # Apply xfail markers if configured
        test_name = "test_single_tool_call_simple_args"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser, test_config.single_tool_call_output, streaming=streaming
        )

        # Content check (some parsers strip it); skipped when no expectation set
        if test_config.single_tool_call_expected_content is not None:
            assert content == test_config.single_tool_call_expected_content

        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"
        assert tool_calls[0].type == "function"
        assert tool_calls[0].function.name == test_config.single_tool_call_expected_name

        # Only the expected keys are compared; extra parsed arguments are allowed
        args = json.loads(tool_calls[0].function.arguments)
        for key, value in test_config.single_tool_call_expected_args.items():
            assert args.get(key) == value, (
                f"Expected {key}={value}, got {args.get(key)}"
            )

    def test_parallel_tool_calls(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles multiple tools in one response."""
        # Apply xfail markers if configured
        test_name = "test_parallel_tool_calls"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        content, tool_calls = run_tool_extraction(
            tool_parser,
            test_config.parallel_tool_calls_output,
            streaming=streaming,
        )

        # Content check (some parsers strip it); skipped when no expectation
        # set, mirroring test_single_tool_call_simple_args
        if test_config.parallel_tool_calls_expected_content is not None:
            assert content == test_config.parallel_tool_calls_expected_content

        assert len(tool_calls) == test_config.parallel_tool_calls_count, (
            f"Expected {test_config.parallel_tool_calls_count} "
            f"tool calls, got {len(tool_calls)}"
        )

        # Verify tool names match expected
        for i, expected_name in enumerate(test_config.parallel_tool_calls_names):
            assert tool_calls[i].type == "function"
            assert tool_calls[i].function.name == expected_name

        # Verify unique IDs
        ids = [tc.id for tc in tool_calls]
        assert len(ids) == len(set(ids)), "Tool call IDs should be unique"

    def test_various_data_types(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles all JSON types in arguments."""
        # Apply xfail markers if configured
        test_name = "test_various_data_types"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        _content, tool_calls = run_tool_extraction(
            tool_parser,
            test_config.various_data_types_output,
            streaming=streaming,
        )
        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"

        args = json.loads(tool_calls[0].function.arguments)
        # Verify all expected fields present
        required_fields_types = {
            "string_field": str,
            "int_field": int,
            "float_field": float,
            "bool_field": bool,
            "null_field": NoneType,
            "array_field": list,
            "object_field": dict,
        }
        for required_field, expected_type in required_fields_types.items():
            assert required_field in args, (
                f"Expected field '{required_field}' in arguments"
            )
            if test_config.supports_typed_arguments:
                # Exact type identity on purpose: isinstance() would accept
                # a bool where an int is expected.
                found_type = type(args[required_field])
                assert found_type is expected_type, (
                    f"Expected field '{required_field}' to have type {expected_type}, "
                    f"got {found_type}"
                )

    def test_empty_arguments(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles parameterless tool calls."""
        # Apply xfail markers if configured
        test_name = "test_empty_arguments"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        _content, tool_calls = run_tool_extraction(
            tool_parser, test_config.empty_arguments_output, streaming=streaming
        )
        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"

        # The raw argument string is compared here, not the parsed JSON
        args = tool_calls[0].function.arguments
        if test_config.allow_empty_or_json_empty_args:
            assert args in ["{}", ""], f"Expected empty args, got {args}"
        else:
            assert args == "{}", f"Expected {{}}, got {args}"

    def test_surrounding_text(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser extracts tools from mixed content."""
        # Apply xfail markers if configured
        test_name = "test_surrounding_text"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        _content, tool_calls = run_tool_extraction(
            tool_parser, test_config.surrounding_text_output, streaming=streaming
        )
        assert len(tool_calls) >= 1, (
            f"Expected at least 1 tool call, got {len(tool_calls)}"
        )

    def test_escaped_strings(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser handles escaped characters in arguments."""
        # Apply xfail markers if configured
        test_name = "test_escaped_strings"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        _content, tool_calls = run_tool_extraction(
            tool_parser, test_config.escaped_strings_output, streaming=streaming
        )
        assert len(tool_calls) == 1, f"Expected 1 tool call, got {len(tool_calls)}"

        args = json.loads(tool_calls[0].function.arguments)
        # At minimum, verify we can parse and have expected fields
        # Exact escaping behavior varies by parser
        assert len(args) > 0, "Expected some arguments with escaped strings"

    def test_malformed_input(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
        streaming: bool,
    ):
        """Verify parser gracefully handles invalid syntax."""
        # Apply xfail markers if configured
        test_name = "test_malformed_input"
        self.apply_xfail_mark(request, test_config, test_name, streaming)

        for malformed_input in test_config.malformed_input_outputs:
            # Should not raise exception. The result is deliberately not
            # inspected: parser should handle gracefully (exact behavior varies)
            run_tool_extraction(tool_parser, malformed_input, streaming=streaming)

    def test_streaming_reconstruction(
        self,
        request: pytest.FixtureRequest,
        tool_parser: Any,
        test_config: ToolParserTestConfig,
    ):
        """Verify streaming produces same result as non-streaming."""
        test_name = "test_streaming_reconstruction"
        # Not parametrized over `streaming`; only xfail_streaming is consulted
        self.apply_xfail_mark(request, test_config, test_name, True)

        test_output = test_config.single_tool_call_output

        # Non-streaming result
        content_non, tools_non = run_tool_extraction(
            tool_parser, test_output, streaming=False
        )

        # Streaming result
        content_stream, tools_stream = run_tool_extraction(
            tool_parser, test_output, streaming=True
        )

        # Compare results
        assert content_non == content_stream, "Content should match between modes"
        assert len(tools_non) == len(tools_stream), "Tool count should match"
        # Lengths are equal at this point, so zip pairs up every tool call
        for tool_non, tool_stream in zip(tools_non, tools_stream):
            assert tool_non.function.name == tool_stream.function.name
            assert tool_non.function.arguments == tool_stream.function.arguments

    def apply_xfail_mark(
        self,
        request: pytest.FixtureRequest,
        test_config: ToolParserTestConfig,
        test_name: str,
        streaming: bool,
    ) -> None:
        """Attach a strict xfail marker to the running test if one is configured.

        Args:
            request: Fixture request whose ``node`` receives the marker.
            test_config: Config holding the ``xfail_streaming`` and
                ``xfail_nonstreaming`` mappings.
            test_name: Key looked up in the selected mapping.
            streaming: Selects ``xfail_streaming`` when truthy, otherwise
                ``xfail_nonstreaming``.
        """
        xfail_reasons = (
            test_config.xfail_streaming if streaming else test_config.xfail_nonstreaming
        )
        reason = xfail_reasons.get(test_name)
        if reason is not None:
            # strict=True: a test listed as xfail that passes is reported as
            # a failure, so stale entries get noticed.
            request.node.add_marker(pytest.mark.xfail(reason=reason, strict=True))
|