feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
(cherry picked from commit 08ed2b9688)
This commit is contained in:
@@ -394,6 +394,22 @@ VLM_TEST_SETTINGS = {
|
||||
vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
|
||||
patch_hf_runner=model_utils.gemma3_patch_hf_runner,
|
||||
),
|
||||
"gemma4": VLMTestInfo(
|
||||
models=["google/gemma-4-E2B-it"],
|
||||
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
|
||||
prompt_formatter=lambda img_prompt: f"<bos><start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n", # noqa: E501
|
||||
single_image_prompts=IMAGE_ASSETS.prompts(
|
||||
{
|
||||
"stop_sign": "What's the content in the center of the image?",
|
||||
"cherry_blossom": "What is the season?",
|
||||
}
|
||||
),
|
||||
multi_image_prompt="Describe the two images in detail.",
|
||||
max_model_len=4096,
|
||||
max_num_seqs=2,
|
||||
auto_cls=AutoModelForImageTextToText,
|
||||
vllm_runner_kwargs={"limit_mm_per_prompt": {"image": 4}},
|
||||
),
|
||||
"granite_vision": VLMTestInfo(
|
||||
models=["ibm-granite/granite-vision-3.3-2b"],
|
||||
test_type=(VLMTestType.IMAGE),
|
||||
|
||||
44
tests/models/multimodal/processing/test_gemma4.py
Normal file
44
tests/models/multimodal/processing/test_gemma4.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
# TODO: to be updated to "google/gemma-4-e2b-it" once the models are available
|
||||
GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
|
||||
def test_limit_mm_per_prompt(
|
||||
image_assets: ImageTestAssets,
|
||||
model_id: str,
|
||||
):
|
||||
"""Test that limit_mm_per_prompt accurately restricts multiple images."""
|
||||
# We only allow 1 image
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
mm_processor_kwargs={},
|
||||
limit_mm_per_prompt={"image": 1},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
|
||||
# Provide 2 images in the prompt
|
||||
prompt = "<image><image>"
|
||||
# image_assets usually has multiple images
|
||||
images = [asset.pil_image for asset in image_assets][:2]
|
||||
if len(images) < 2:
|
||||
images = [images[0], images[0]]
|
||||
|
||||
mm_data = {"image": images}
|
||||
|
||||
# Expect ValueError when exceeding limit
|
||||
with pytest.raises(ValueError, match="At most 1 image"):
|
||||
processor(
|
||||
prompt,
|
||||
mm_items=processor.info.parse_mm_data(mm_data),
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
@@ -277,6 +277,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
|
||||
"google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
|
||||
),
|
||||
"Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
|
||||
"Gemma4ForCausalLM": _HfExamplesInfo(
|
||||
"google/gemma-4-E2B-it",
|
||||
min_transformers_version="5.0.0",
|
||||
),
|
||||
"Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
|
||||
"GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
|
||||
"Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
|
||||
@@ -805,6 +809,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
|
||||
),
|
||||
"FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
|
||||
"Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
|
||||
"Gemma4ForConditionalGeneration": _HfExamplesInfo(
|
||||
"google/gemma-4-E2B-it",
|
||||
min_transformers_version="5.5.0",
|
||||
),
|
||||
"Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"),
|
||||
"GlmAsrForConditionalGeneration": _HfExamplesInfo(
|
||||
"zai-org/GLM-ASR-Nano-2512",
|
||||
|
||||
196
tests/reasoning/test_gemma4_reasoning_parser.py
Normal file
196
tests/reasoning/test_gemma4_reasoning_parser.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.reasoning.utils import run_reasoning_extraction
|
||||
from vllm.reasoning import ReasoningParser, ReasoningParserManager
|
||||
|
||||
# Using mistral tokenizer as a generic mock since the actual model is not on HF
|
||||
from vllm.tokenizers.registry import get_tokenizer
|
||||
|
||||
parser_name = "gemma4"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def generic_tokenizer():
|
||||
return get_tokenizer("google/gemma-4-E2B-it")
|
||||
|
||||
|
||||
INVALID_SIMPLE_NONSTREAMING = {
|
||||
"output": "This is a reasoning section<channel|>This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
INVALID_SIMPLE_STREAMING = {
|
||||
"output": "This is a reasoning section<channel|>This is the rest",
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning sectionThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
INVALID_COMPLETE_NONSTREAMING = {
|
||||
"output": "This is a reasoning section<channel|>",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
INVALID_COMPLETE_STREAMING = {
|
||||
"output": "This is a reasoning section<channel|>",
|
||||
"reasoning": None,
|
||||
"content": "This is a reasoning section",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
NO_CONTENT = {
|
||||
"output": "<|channel>This is reasoning",
|
||||
"reasoning": "This is reasoning",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NO_REASONING = {
|
||||
"output": "This is content",
|
||||
"reasoning": None,
|
||||
"content": "This is content",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
REASONING_WITH_CHANNEL = {
|
||||
"output": "<|channel>This is a reasoning section<channel|>This is the rest",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "This is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
COMPLETE_REASONING_WITH_CHANNEL = {
|
||||
"output": "<|channel>This is a reasoning section<channel|>",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
MULTIPLE_LINES_WITH_CHANNEL = {
|
||||
"output": "<|channel>This\nThat<channel|>This is the rest\nThat",
|
||||
"reasoning": "This\nThat",
|
||||
"content": "This is the rest\nThat",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
CHANNEL_NO_END = {
|
||||
"output": "<|channel>This is a reasoning section",
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": None,
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
EMPTY = {
|
||||
"output": "",
|
||||
"reasoning": None,
|
||||
"content": "",
|
||||
"is_reasoning_end": False,
|
||||
}
|
||||
NEW_LINE_NONSTREAMING = {
|
||||
"output": (
|
||||
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
|
||||
),
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
NEW_LINE_STREAMING = {
|
||||
"output": (
|
||||
"Before\n<|channel>This is a reasoning section<channel|>\nThis is the rest"
|
||||
),
|
||||
"reasoning": "This is a reasoning section",
|
||||
"content": "Before\n\nThis is the rest",
|
||||
"is_reasoning_end": True,
|
||||
}
|
||||
|
||||
TEST_CASES = [
|
||||
pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"),
|
||||
pytest.param(True, INVALID_SIMPLE_STREAMING, id="invalid_simple_streaming"),
|
||||
pytest.param(False, INVALID_COMPLETE_NONSTREAMING, id="invalid_complete"),
|
||||
pytest.param(True, INVALID_COMPLETE_STREAMING, id="invalid_complete_streaming"),
|
||||
pytest.param(False, NO_CONTENT, id="no_content"),
|
||||
pytest.param(False, NO_REASONING, id="no_reasoning"),
|
||||
pytest.param(False, REASONING_WITH_CHANNEL, id="reasoning"),
|
||||
pytest.param(True, REASONING_WITH_CHANNEL, id="reasoning_streaming"),
|
||||
pytest.param(False, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning"),
|
||||
pytest.param(
|
||||
True, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning_streaming"
|
||||
),
|
||||
pytest.param(False, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines"),
|
||||
pytest.param(True, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines_streaming"),
|
||||
pytest.param(False, CHANNEL_NO_END, id="no_end"),
|
||||
pytest.param(True, CHANNEL_NO_END, id="no_end_streaming"),
|
||||
pytest.param(False, EMPTY, id="empty"),
|
||||
pytest.param(False, NEW_LINE_NONSTREAMING, id="new_line"),
|
||||
pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("streaming, param_dict", TEST_CASES)
|
||||
def test_gemma4_reasoning(
|
||||
streaming: bool,
|
||||
param_dict: dict,
|
||||
generic_tokenizer,
|
||||
):
|
||||
output = param_dict["output"]
|
||||
|
||||
# Resolve token IDs dynamically from the real tokenizer
|
||||
vocab = generic_tokenizer.get_vocab()
|
||||
start_token_id = vocab["<|channel>"]
|
||||
end_token_id = vocab["<channel|>"]
|
||||
|
||||
index_start = output.find("<|channel>")
|
||||
len_start = len("<|channel>")
|
||||
index_end = output.find("<channel|>")
|
||||
len_end = len("<channel|>")
|
||||
|
||||
output_tokens = []
|
||||
|
||||
def _encode(text: str) -> list[int]:
|
||||
if not text:
|
||||
return []
|
||||
# Handle both raw transformers and vLLM wrappers
|
||||
enc = getattr(generic_tokenizer, "tokenizer", generic_tokenizer)
|
||||
try:
|
||||
return enc.encode(text, add_special_tokens=False)
|
||||
except TypeError:
|
||||
return enc.encode(text)
|
||||
|
||||
if index_start != -1:
|
||||
output_before = output[:index_start]
|
||||
output_tokens += _encode(output_before)
|
||||
output_tokens += [start_token_id]
|
||||
|
||||
if index_end != -1:
|
||||
output_middle = output[index_start + len_start : index_end]
|
||||
output_after = output[index_end + len_end :]
|
||||
output_tokens += _encode(output_middle)
|
||||
output_tokens += [end_token_id]
|
||||
output_tokens += _encode(output_after)
|
||||
else:
|
||||
output_middle = output[index_start + len_start :]
|
||||
output_tokens += _encode(output_middle)
|
||||
elif index_end != -1:
|
||||
output_before = output[:index_end]
|
||||
output_after = output[index_end + len_end :]
|
||||
output_tokens += _encode(output_before)
|
||||
output_tokens += [end_token_id]
|
||||
output_tokens += _encode(output_after)
|
||||
else:
|
||||
output_tokens += _encode(output)
|
||||
|
||||
parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)(
|
||||
generic_tokenizer
|
||||
)
|
||||
|
||||
# We use the generic run_reasoning_extraction from utils
|
||||
# Use decode per token to get standard spaces instead of
|
||||
# SentencePiece space characters
|
||||
output_token_strings = [generic_tokenizer.decode([t]) for t in output_tokens]
|
||||
reasoning, content = run_reasoning_extraction(
|
||||
parser, output_token_strings, streaming=streaming
|
||||
)
|
||||
|
||||
assert reasoning == param_dict["reasoning"]
|
||||
assert content == param_dict["content"]
|
||||
|
||||
# Test is_reasoning_end
|
||||
is_reasoning_end = parser.is_reasoning_end(output_tokens)
|
||||
assert is_reasoning_end == param_dict["is_reasoning_end"]
|
||||
504
tests/tool_parsers/test_gemma4_tool_parser.py
Normal file
504
tests/tool_parsers/test_gemma4_tool_parser.py
Normal file
@@ -0,0 +1,504 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
|
||||
from vllm.tool_parsers.gemma4_tool_parser import (
|
||||
TOOL_CALL_END,
|
||||
TOOL_CALL_START,
|
||||
Gemma4ToolParser,
|
||||
_parse_gemma4_args,
|
||||
_parse_gemma4_array,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tokenizer():
|
||||
tokenizer = MagicMock()
|
||||
tokenizer.encode.return_value = [1, 2, 3]
|
||||
# Include the tool call start token in the vocab for the parser
|
||||
tokenizer.get_vocab.return_value = {TOOL_CALL_START: 48, TOOL_CALL_END: 49}
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def parser(mock_tokenizer):
|
||||
return Gemma4ToolParser(mock_tokenizer)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request():
|
||||
request = MagicMock(spec=ChatCompletionRequest)
|
||||
request.tools = []
|
||||
request.tool_choice = "auto"
|
||||
return request
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests for _parse_gemma4_args (shared parser logic)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseGemma4Args:
|
||||
def test_empty_string(self):
|
||||
assert _parse_gemma4_args("") == {}
|
||||
|
||||
def test_whitespace_only(self):
|
||||
assert _parse_gemma4_args(" ") == {}
|
||||
|
||||
def test_single_string_value(self):
|
||||
result = _parse_gemma4_args('location:<|"|>Paris<|"|>')
|
||||
assert result == {"location": "Paris"}
|
||||
|
||||
def test_string_value_with_comma(self):
|
||||
result = _parse_gemma4_args('location:<|"|>Paris, France<|"|>')
|
||||
assert result == {"location": "Paris, France"}
|
||||
|
||||
def test_multiple_string_values(self):
|
||||
result = _parse_gemma4_args(
|
||||
'location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>'
|
||||
)
|
||||
assert result == {"location": "San Francisco", "unit": "celsius"}
|
||||
|
||||
def test_integer_value(self):
|
||||
result = _parse_gemma4_args("count:42")
|
||||
assert result == {"count": 42}
|
||||
|
||||
def test_float_value(self):
|
||||
result = _parse_gemma4_args("score:3.14")
|
||||
assert result == {"score": 3.14}
|
||||
|
||||
def test_boolean_true(self):
|
||||
result = _parse_gemma4_args("flag:true")
|
||||
assert result == {"flag": True}
|
||||
|
||||
def test_boolean_false(self):
|
||||
result = _parse_gemma4_args("flag:false")
|
||||
assert result == {"flag": False}
|
||||
|
||||
def test_mixed_types(self):
|
||||
result = _parse_gemma4_args(
|
||||
'name:<|"|>test<|"|>,count:42,active:true,score:3.14'
|
||||
)
|
||||
assert result == {
|
||||
"name": "test",
|
||||
"count": 42,
|
||||
"active": True,
|
||||
"score": 3.14,
|
||||
}
|
||||
|
||||
def test_nested_object(self):
|
||||
result = _parse_gemma4_args('nested:{inner:<|"|>value<|"|>}')
|
||||
assert result == {"nested": {"inner": "value"}}
|
||||
|
||||
def test_array_of_strings(self):
|
||||
result = _parse_gemma4_args('items:[<|"|>a<|"|>,<|"|>b<|"|>]')
|
||||
assert result == {"items": ["a", "b"]}
|
||||
|
||||
def test_unterminated_string(self):
|
||||
"""Unterminated strings should take everything after the delimiter."""
|
||||
result = _parse_gemma4_args('key:<|"|>unterminated')
|
||||
assert result == {"key": "unterminated"}
|
||||
|
||||
def test_empty_value(self):
|
||||
"""Key with no value after colon."""
|
||||
result = _parse_gemma4_args("key:")
|
||||
assert result == {"key": ""}
|
||||
|
||||
|
||||
class TestParseGemma4Array:
|
||||
def test_string_array(self):
|
||||
result = _parse_gemma4_array('<|"|>a<|"|>,<|"|>b<|"|>')
|
||||
assert result == ["a", "b"]
|
||||
|
||||
def test_empty_array(self):
|
||||
result = _parse_gemma4_array("")
|
||||
assert result == []
|
||||
|
||||
def test_bare_values(self):
|
||||
result = _parse_gemma4_array("42,true,3.14")
|
||||
assert result == [42, True, 3.14]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-streaming extraction tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_no_tool_calls(self, parser, mock_request):
|
||||
model_output = "Hello, how can I help you today?"
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is False
|
||||
assert result.tool_calls == []
|
||||
assert result.content == model_output
|
||||
|
||||
def test_single_tool_call(self, parser, mock_request):
|
||||
model_output = (
|
||||
'<|tool_call>call:get_weather{location:<|"|>London<|"|>}<tool_call|>'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"location": "London"}
|
||||
|
||||
def test_multiple_arguments(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<|tool_call>call:get_weather{"
|
||||
'location:<|"|>San Francisco<|"|>,'
|
||||
'unit:<|"|>celsius<|"|>}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"location": "San Francisco", "unit": "celsius"}
|
||||
|
||||
def test_text_before_tool_call(self, parser, mock_request):
|
||||
model_output = (
|
||||
"Let me check the weather for you. "
|
||||
'<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.content == "Let me check the weather for you."
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
|
||||
def test_multiple_tool_calls(self, parser, mock_request):
|
||||
model_output = (
|
||||
'<|tool_call>call:get_weather{location:<|"|>London<|"|>}'
|
||||
"<tool_call|>"
|
||||
'<|tool_call>call:get_time{location:<|"|>London<|"|>}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 2
|
||||
assert result.tool_calls[0].function.name == "get_weather"
|
||||
assert result.tool_calls[1].function.name == "get_time"
|
||||
|
||||
def test_nested_arguments(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<|tool_call>call:complex_function{"
|
||||
'nested:{inner:<|"|>value<|"|>},'
|
||||
'list:[<|"|>a<|"|>,<|"|>b<|"|>]}'
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "complex_function"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"nested": {"inner": "value"}, "list": ["a", "b"]}
|
||||
|
||||
def test_tool_call_with_number_and_boolean(self, parser, mock_request):
|
||||
model_output = (
|
||||
"<|tool_call>call:set_status{"
|
||||
"is_active:true,"
|
||||
"count:42,"
|
||||
"score:3.14}"
|
||||
"<tool_call|>"
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert len(result.tool_calls) == 1
|
||||
assert result.tool_calls[0].function.name == "set_status"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {"is_active": True, "count": 42, "score": 3.14}
|
||||
|
||||
def test_incomplete_tool_call(self, parser, mock_request):
|
||||
model_output = '<|tool_call>call:get_weather{location:<|"|>London'
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
# Incomplete — no <tool_call|> end marker, regex won't match
|
||||
assert result.tools_called is False
|
||||
assert result.content == model_output
|
||||
|
||||
def test_hyphenated_function_name(self, parser, mock_request):
|
||||
"""Ensure function names with hyphens are parsed correctly."""
|
||||
model_output = (
|
||||
'<|tool_call>call:get-weather{location:<|"|>London<|"|>}<tool_call|>'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.tool_calls[0].function.name == "get-weather"
|
||||
|
||||
def test_dotted_function_name(self, parser, mock_request):
|
||||
"""Ensure function names with dots are parsed correctly."""
|
||||
model_output = (
|
||||
'<|tool_call>call:weather.get{location:<|"|>London<|"|>}<tool_call|>'
|
||||
)
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.tool_calls[0].function.name == "weather.get"
|
||||
|
||||
def test_no_arguments(self, parser, mock_request):
|
||||
"""Tool calls with empty arguments."""
|
||||
model_output = "<|tool_call>call:get_status{}<tool_call|>"
|
||||
result = parser.extract_tool_calls(model_output, mock_request)
|
||||
|
||||
assert result.tools_called is True
|
||||
assert result.tool_calls[0].function.name == "get_status"
|
||||
args = json.loads(result.tool_calls[0].function.arguments)
|
||||
assert args == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Streaming extraction tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStreamingExtraction:
|
||||
"""Tests for the streaming tool call extraction.
|
||||
|
||||
These simulate the token-by-token streaming that vLLM performs,
|
||||
feeding incremental text to extract_tool_calls_streaming() and
|
||||
verifying that the accumulated argument deltas form valid JSON.
|
||||
"""
|
||||
|
||||
def _simulate_streaming(
|
||||
self, parser: Gemma4ToolParser, mock_request: Any, chunks: list[str]
|
||||
) -> list[tuple[Any, str]]:
|
||||
"""Feed chunks through the streaming parser and collect results.
|
||||
|
||||
Returns a list of (delta_message, accumulated_text) tuples.
|
||||
"""
|
||||
results: list[tuple[Any, str]] = []
|
||||
previous_text: str = ""
|
||||
previous_token_ids: list[int] = []
|
||||
|
||||
for chunk in chunks:
|
||||
current_text = previous_text + chunk
|
||||
# Use token ID 48 for tool_call start, 49 for end, 0 otherwise
|
||||
delta_token_ids: list[int] = []
|
||||
if TOOL_CALL_START in chunk:
|
||||
delta_token_ids.append(48)
|
||||
elif TOOL_CALL_END in chunk:
|
||||
delta_token_ids.append(49)
|
||||
else:
|
||||
delta_token_ids.append(0)
|
||||
|
||||
current_token_ids = previous_token_ids + delta_token_ids
|
||||
|
||||
delta = parser.extract_tool_calls_streaming(
|
||||
previous_text=previous_text,
|
||||
current_text=current_text,
|
||||
delta_text=chunk,
|
||||
previous_token_ids=tuple(previous_token_ids),
|
||||
current_token_ids=tuple(current_token_ids),
|
||||
delta_token_ids=tuple(delta_token_ids),
|
||||
request=mock_request,
|
||||
)
|
||||
results.append((delta, current_text))
|
||||
previous_text = current_text
|
||||
previous_token_ids = list(current_token_ids)
|
||||
|
||||
return results
|
||||
|
||||
def _collect_arguments(self, results):
|
||||
"""Collect all argument deltas from streaming results into one string."""
|
||||
args_text = ""
|
||||
for delta, _ in results:
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
func = tc.function if isinstance(tc.function, dict) else tc.function
|
||||
if isinstance(func, dict):
|
||||
arg = func.get("arguments", "")
|
||||
else:
|
||||
arg = getattr(func, "arguments", "") or ""
|
||||
if arg:
|
||||
args_text += arg
|
||||
return args_text
|
||||
|
||||
def _collect_function_name(self, results):
|
||||
"""Extract the function name from streaming results."""
|
||||
for delta, _ in results:
|
||||
if delta and delta.tool_calls:
|
||||
for tc in delta.tool_calls:
|
||||
func = tc.function if isinstance(tc.function, dict) else tc.function
|
||||
if isinstance(func, dict):
|
||||
name = func.get("name")
|
||||
else:
|
||||
name = getattr(func, "name", None)
|
||||
if name:
|
||||
return name
|
||||
return None
|
||||
|
||||
def test_basic_streaming_single_tool(self, parser, mock_request):
|
||||
"""Simulate the exact streaming scenario from the bug report.
|
||||
|
||||
Model generates:
|
||||
<|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>}<tool_call|>
|
||||
|
||||
Expected: arguments should be valid JSON {"location": "Paris, France"}
|
||||
"""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Paris',
|
||||
", France",
|
||||
'<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
|
||||
# Verify function name
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_weather", f"Expected 'get_weather', got '{name}'"
|
||||
|
||||
# Verify arguments form valid JSON
|
||||
args_text = self._collect_arguments(results)
|
||||
assert args_text, "No arguments were streamed"
|
||||
parsed_args = json.loads(args_text)
|
||||
assert parsed_args == {"location": "Paris, France"}
|
||||
|
||||
def test_streaming_multi_arg(self, parser, mock_request):
|
||||
"""Streaming with multiple arguments."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Tokyo<|"|>,',
|
||||
'unit:<|"|>celsius<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_weather"
|
||||
|
||||
args_text = self._collect_arguments(results)
|
||||
assert args_text
|
||||
parsed_args = json.loads(args_text)
|
||||
assert parsed_args == {"location": "Tokyo", "unit": "celsius"}
|
||||
|
||||
def test_streaming_no_extra_brace(self, parser, mock_request):
|
||||
"""Verify the closing } is NOT leaked into arguments (Bug #2)."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>London<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
args_text = self._collect_arguments(results)
|
||||
assert args_text
|
||||
|
||||
# The args text must be valid JSON (no extra })
|
||||
parsed = json.loads(args_text)
|
||||
assert parsed == {"location": "London"}
|
||||
|
||||
# Specifically assert no double-brace
|
||||
assert args_text.count("}") <= 1, (
|
||||
f"Arguments contain extra closing brace: {args_text!r}"
|
||||
)
|
||||
|
||||
def test_streaming_no_unquoted_keys(self, parser, mock_request):
|
||||
"""Verify keys are properly quoted in JSON (Bug #1)."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Paris<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
args_text = self._collect_arguments(results)
|
||||
|
||||
# Must start with { and contain quoted key
|
||||
assert args_text.lstrip().startswith("{"), (
|
||||
f"Arguments don't start with '{{': {args_text!r}"
|
||||
)
|
||||
assert '"location"' in args_text, (
|
||||
f"Key 'location' not properly quoted: {args_text!r}"
|
||||
)
|
||||
|
||||
def test_streaming_name_no_call_prefix(self, parser, mock_request):
|
||||
"""Verify function name has no 'call:' prefix."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>Paris<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_weather"
|
||||
assert not name.startswith("call:"), f"Name has 'call:' prefix: {name!r}"
|
||||
|
||||
def test_streaming_text_before_tool_call(self, parser, mock_request):
|
||||
"""Text before tool call should be emitted as content."""
|
||||
chunks = [
|
||||
"Let me check ",
|
||||
"the weather. ",
|
||||
"<|tool_call>",
|
||||
"call:get_weather{",
|
||||
'location:<|"|>London<|"|>}',
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
|
||||
# First chunks should be content
|
||||
content_parts = []
|
||||
for delta, _ in results:
|
||||
if delta and delta.content:
|
||||
content_parts.append(delta.content)
|
||||
|
||||
assert "".join(content_parts).strip().startswith("Let me check")
|
||||
|
||||
def test_streaming_numeric_args(self, parser, mock_request):
|
||||
"""Streaming with numeric and boolean argument values."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:set_config{",
|
||||
"count:42,",
|
||||
"active:true}",
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
args_text = self._collect_arguments(results)
|
||||
if args_text:
|
||||
parsed_args = json.loads(args_text)
|
||||
assert parsed_args["count"] == 42
|
||||
assert parsed_args["active"] is True
|
||||
|
||||
def test_streaming_empty_args(self, parser, mock_request):
|
||||
"""Tool call with no arguments."""
|
||||
chunks = [
|
||||
"<|tool_call>",
|
||||
"call:get_status{}",
|
||||
"<tool_call|>",
|
||||
]
|
||||
|
||||
results = self._simulate_streaming(parser, mock_request, chunks)
|
||||
name = self._collect_function_name(results)
|
||||
assert name == "get_status"
|
||||
@@ -12,6 +12,7 @@ from .dual_chunk_rope import DualChunkRotaryEmbedding
|
||||
from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding
|
||||
from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding
|
||||
from .fope import FourierRotaryEmbedding
|
||||
from .gemma4_rope import Gemma4RotaryEmbedding
|
||||
from .linear_scaling_rope import LinearScalingRotaryEmbedding
|
||||
from .llama3_rope import Llama3RotaryEmbedding
|
||||
from .llama4_vision_rope import Llama4VisionRotaryEmbedding
|
||||
@@ -134,6 +135,17 @@ def get_rope(
|
||||
is_neox_style,
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "proportional":
|
||||
# Proportional RoPE is used by Gemma4 for global (full) attention.
|
||||
# Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves.
|
||||
rotary_emb = Gemma4RotaryEmbedding(
|
||||
head_size,
|
||||
rotary_dim,
|
||||
max_position,
|
||||
base,
|
||||
is_neox_style,
|
||||
dtype,
|
||||
)
|
||||
elif scaling_type == "llama3":
|
||||
scaling_factor = rope_parameters["factor"]
|
||||
low_freq_factor = rope_parameters["low_freq_factor"]
|
||||
|
||||
84
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
Normal file
84
vllm/model_executor/layers/rotary_embedding/gemma4_rope.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Gemma4-specific Rotary Positional Embeddings (proportional scaling).
|
||||
|
||||
Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled
|
||||
by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when
|
||||
partial_rotary_factor < 1. The actual rotation uses standard neox-style
|
||||
rotate_half, matching HF transformers' apply_rotary_pos_emb.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from .base import RotaryEmbedding
|
||||
|
||||
|
||||
class Gemma4RotaryEmbedding(RotaryEmbedding):
    """Gemma4 "proportional" rotary positional embedding.

    Reuses the standard neox-style rotation from :class:`RotaryEmbedding`
    but overrides the inverse-frequency computation to match HF's
    ``_compute_proportional_rope_parameters``:

    - frequency exponents are divided by ``head_dim`` rather than
      ``rotary_dim``;
    - frequencies for the non-rotated dimensions are zero, which turns
      their rotation into the identity (cos=1, sin=0).

    With ``partial_rotary_factor == 1.0`` every dimension is rotated and
    the result is equivalent to a standard ``RotaryEmbedding`` whose
    frequencies are scaled by ``head_dim``.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        # Rotated angle pairs (derived from partial_rotary_factor) and the
        # remaining identity ("nope") pairs per half. These must exist
        # before super().__init__, which builds the cos/sin cache through
        # _compute_inv_freq.
        self.rope_angles = rotary_dim // 2
        self.nope_angles = (head_size // 2) - self.rope_angles

        # Report rotary_dim == head_size to the base class so rotation is
        # applied across the full head; the zero-padded frequencies make
        # the non-rotated dims identity rotations.
        super().__init__(
            head_size,
            head_size,  # rotary_dim = head_size (full application)
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
        )

    def _compute_inv_freq(self, base: float) -> torch.Tensor:
        """Inverse frequencies for proportional RoPE.

        Unlike the base class, the exponent denominator is ``head_size``
        (not ``rotary_dim``) and the tail corresponding to non-rotated
        dims is zero-filled.
        """
        # HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim)
        exponents = torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float)
        inv_freq = 1.0 / (base ** (exponents / self.head_size))

        if self.nope_angles > 0:
            # Zero frequency -> cos=1, sin=0 -> identity rotation.
            padding = torch.zeros(self.nope_angles, dtype=torch.float)
            inv_freq = torch.cat([inv_freq, padding])
        return inv_freq

    def extra_repr(self) -> str:
        fields = [
            f"head_size={self.head_size}",
            f"rotary_dim={self.rotary_dim}",
            f"rope_angles={self.rope_angles}",
            f"nope_angles={self.nope_angles}",
            f"max_position_embeddings={self.max_position_embeddings}",
            f"base={self.base}",
            f"is_neox_style={self.is_neox_style}",
        ]
        return ", ".join(fields)
|
||||
@@ -14,6 +14,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentio
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@@ -57,6 +58,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
|
||||
hf_config.is_causal = not hf_config.use_bidirectional_attention
|
||||
|
||||
|
||||
class Gemma4Config(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Force a single attention backend for heterogeneous head dims.

        Some Gemma4 variants use different head dimensions for the
        sliding-window (``head_dim``) and full-attention
        (``global_head_dim``) layers. When the larger dimension exceeds
        256, FlashAttention rejects those layers (head_size <= 256 kernel
        limit) and vLLM would otherwise select a different backend per
        layer type — mixed-backend execution that produces numerical
        divergence and output corruption.

        If the dimensions genuinely differ, the larger one exceeds
        FlashAttention's limit, and the user has not explicitly chosen a
        backend, force TRITON_ATTN (which has no head_size ceiling) for
        every layer.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        text_config = vllm_config.model_config.hf_text_config
        sliding_dim = getattr(text_config, "head_dim", None)
        global_dim = getattr(text_config, "global_head_dim", None)

        # Guard clauses: act only when both dims are present, they truly
        # differ, the larger one exceeds FlashAttention's limit, and the
        # backend choice was left to vLLM. This avoids unnecessary
        # forcing on smaller models where all layers can share the FA
        # backend.
        if sliding_dim is None or global_dim is None:
            return
        if sliding_dim == global_dim:
            return
        if max(sliding_dim, global_dim) <= 256:
            return
        if vllm_config.attention_config.backend is not None:
            return

        from vllm.v1.attention.backends.registry import (
            AttentionBackendEnum,
        )

        vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
        logger.info(
            "Gemma4 model has heterogeneous head dimensions "
            "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
            "backend to prevent mixed-backend numerical divergence.",
            sliding_dim,
            global_dim,
        )
|
||||
|
||||
|
||||
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
|
||||
@staticmethod
|
||||
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
|
||||
@@ -668,6 +721,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
|
||||
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
|
||||
"FalconMambaForCausalLM": MambaModelConfig,
|
||||
"Gemma3TextModel": Gemma3TextModelConfig,
|
||||
"Gemma4ForCausalLM": Gemma4Config,
|
||||
"Gemma4ForConditionalGeneration": Gemma4Config,
|
||||
"GptOssForCausalLM": GptOssForCausalLMConfig,
|
||||
"GteModel": SnowflakeGteNewModelConfig,
|
||||
"GteNewForSequenceClassification": GteNewModelConfig,
|
||||
|
||||
1239
vllm/model_executor/models/gemma4.py
Normal file
1239
vllm/model_executor/models/gemma4.py
Normal file
File diff suppressed because it is too large
Load Diff
1341
vllm/model_executor/models/gemma4_mm.py
Normal file
1341
vllm/model_executor/models/gemma4_mm.py
Normal file
File diff suppressed because it is too large
Load Diff
292
vllm/model_executor/models/gemma4_utils.py
Normal file
292
vllm/model_executor/models/gemma4_utils.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
"""Gemma4 output parsing utilities for offline inference.
|
||||
|
||||
Standalone functions that parse decoded model text to extract structured
|
||||
thinking content and tool calls from Gemma4 models. These are pure-Python
|
||||
utilities with zero heavy dependencies — they work on raw decoded strings
|
||||
from any inference backend (vLLM, HuggingFace, TGI, etc.).
|
||||
|
||||
Usage with vLLM offline inference::
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.model_executor.models.gemma4_utils import (
|
||||
parse_output,
|
||||
parse_tool_calls,
|
||||
)
|
||||
|
||||
llm = LLM(model="google/gemma-4-it")
|
||||
outputs = llm.generate(prompt, SamplingParams(...))
|
||||
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
|
||||
|
||||
# Extract thinking / answer (works with or without enable_thinking)
|
||||
result = parse_output(text)
|
||||
print(result["thinking"]) # chain-of-thought or None
|
||||
print(result["answer"]) # final answer
|
||||
|
||||
# Extract tool calls
|
||||
tool_calls = parse_tool_calls(text)
|
||||
for tc in tool_calls:
|
||||
print(f"{tc['name']}({tc['arguments']})")
|
||||
|
||||
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
|
||||
do not need a transformers dependency for output parsing.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import regex as re
|
||||
|
||||
# ---- Thinking Mode Utility ----
|
||||
|
||||
# Thinking delimiter tokens as they appear in decoded text.
|
||||
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
|
||||
_THINKING_START_TAG = "<|channel>"
|
||||
_THINKING_END_TAG = "<channel|>"
|
||||
|
||||
# Sentinel tokens that may appear in decoded output.
|
||||
_TURN_END_TAG = "<turn|>"
|
||||
|
||||
|
||||
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Split decoded Gemma4 output into thinking content and final answer.

    Works whether or not thinking mode was enabled:

    1. **Delimiters present** — the text before ``<channel|>`` (minus the
       ``<|channel>`` tag and the ``thought\\n`` role label) becomes the
       thinking content; the remainder is the answer.
    2. **No delimiters, spurious label** — a bare leading ``thought\\n``
       is stripped and the whole text is treated as the answer.
    3. **Clean output** — returned unchanged as the answer.

    Trailing sentinel tokens (``<turn|>``, ``<eos>``) are always removed
    from the answer.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys ``"thinking"`` (chain-of-thought text or
        ``None``) and ``"answer"`` (final answer text).

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])    # final answer
    """
    if _THINKING_END_TAG not in text:
        # No thinking delimiters: drop the spurious role label some
        # Gemma4 models emit even without thinking mode, then clean
        # trailing sentinel tokens.
        return {
            "thinking": None,
            "answer": _clean_answer(_strip_thought_label(text)),
        }

    before_end, _, after_end = text.partition(_THINKING_END_TAG)

    # Everything after the start tag (when present) is the thinking
    # channel; otherwise the whole head of the string is.
    if _THINKING_START_TAG in before_end:
        raw_thinking = before_end.split(_THINKING_START_TAG, 1)[1]
    else:
        raw_thinking = before_end

    # Drop the "thought\n" channel role label the model emits inside
    # <|channel>thought\n...<channel|> (analogous to "user\n" in
    # <|turn>user\n...<turn|>).
    thinking = _strip_thought_label(raw_thinking.strip()).strip()

    return {"thinking": thinking, "answer": _clean_answer(after_end)}
|
||||
|
||||
|
||||
def _strip_thought_label(text: str) -> str:
|
||||
"""Strip the spurious ``thought\\n`` label from the start of text.
|
||||
|
||||
Only strips when ``thought`` appears as the very first word followed by
|
||||
a newline — preserving the word ``thought`` in any other context.
|
||||
"""
|
||||
if text.startswith("thought\n"):
|
||||
return text[len("thought\n") :]
|
||||
return text
|
||||
|
||||
|
||||
def _clean_answer(text: str) -> str:
|
||||
"""Clean trailing sentinel tokens from the answer text.
|
||||
|
||||
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
|
||||
model appends at the end of its response.
|
||||
"""
|
||||
text = text.strip()
|
||||
# Strip trailing <turn|> (Gemma4 turn-end marker)
|
||||
if text.endswith(_TURN_END_TAG):
|
||||
text = text[: -len(_TURN_END_TAG)].rstrip()
|
||||
# Strip trailing <eos> if present
|
||||
if text.endswith("<eos>"):
|
||||
text = text[:-5].rstrip()
|
||||
return text
|
||||
|
||||
|
||||
# ---- Tool Call Parsing Utility ----
|
||||
#
|
||||
# NOTE: For the OpenAI-compatible API server tool parser (streaming +
|
||||
# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
|
||||
# This module provides offline inference utilities for direct user import.
|
||||
|
||||
# Tool call delimiter tokens as they appear in decoded text.
|
||||
# Standard format: <|tool_call>call:name{args}<tool_call|>
|
||||
_TOOL_CALL_START_TAG = "<|tool_call>"
|
||||
_TOOL_CALL_END_TAG = "<tool_call|>"
|
||||
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
|
||||
|
||||
# Gemma4 escape token as it appears in decoded text.
|
||||
_ESCAPE_TOKEN = '<|"|>'
|
||||
|
||||
|
||||
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
|
||||
"""Parse tool call arguments from the Gemma4 compact format.
|
||||
|
||||
Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
|
||||
to heuristic key-value extraction. Also tolerates the slightly different
|
||||
``key: "value"`` format (space + plain quotes) that some chat templates
|
||||
produce.
|
||||
|
||||
Args:
|
||||
args_str: Raw argument string from inside ``call:name{...}``.
|
||||
|
||||
Returns:
|
||||
Dictionary of argument name → value.
|
||||
"""
|
||||
if not args_str or not args_str.strip():
|
||||
return {}
|
||||
|
||||
# Replace Gemma4 escape tokens with standard quotes.
|
||||
cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
|
||||
|
||||
# Try JSON parsing first (handles nested values, arrays, etc.).
|
||||
try:
|
||||
parsed = json.loads("{" + cleaned + "}")
|
||||
# Ensure all values are strings for consistency.
|
||||
return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# Fallback: extract key:"value" pairs (allow optional space after colon).
|
||||
arguments = {}
|
||||
for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
|
||||
arguments[key] = value
|
||||
|
||||
if not arguments:
|
||||
# Last resort: extract key:value pairs (unquoted).
|
||||
for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
|
||||
arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
|
||||
|
||||
return arguments
|
||||
|
||||
|
||||
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Parse tool calls from decoded Gemma4 model output.

    A tiered strategy handles known output variations in Gemma4 models,
    which may emit non-standard tool call formats:

    1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
       (special token IDs 48/49 in decoded text; some models terminate
       with ``<turn|>`` instead of ``<tool_call|>``).
    2. **Fallback** (only when ``strict=False`` and tier 1 matched
       nothing): bare ``call:name{args}`` patterns, including
       ``<call>name{args}`` fragments seen with multimodal inputs.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...,
            skip_special_tokens=False)``).
        strict: If ``True``, only match the standard ``<|tool_call>``
            format.

    Returns:
        A list of dicts, each with ``"name"`` (the tool function name)
        and ``"arguments"`` (a dict of argument name → value), in the
        order the calls appear in ``text``.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     parse_tool_calls
        ... )
        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> tool_calls = parse_tool_calls(output)
        >>> for tc in tool_calls:
        ...     print(f"Call: {tc['name']}({tc['arguments']})")
    """

    def _collect(pattern: str) -> list[dict]:
        # One dict per regex match: group 1 is the function name,
        # group 2 the raw argument payload.
        return [
            {
                "name": m.group(1),
                "arguments": _parse_tool_arguments(m.group(2)),
            }
            for m in re.finditer(pattern, text, re.DOTALL)
        ]

    # Tier 1: standard special-token format.
    calls = _collect(r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)")
    if calls or strict:
        return calls

    # Tier 2: fallback for known Gemma4 output variations.
    # Matches: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
    return _collect(r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}")
|
||||
|
||||
|
||||
def has_tool_response_tag(text: str) -> bool:
    """Return ``True`` when output terminates with ``<|tool_response>``.

    Some Gemma4 models emit ``<eos>`` instead of ``<|tool_response>``
    after a tool call. Callers can use this check to decide whether to
    inject ``<|tool_response>`` into the next prompt themselves.

    Args:
        text: Decoded model output text.

    Returns:
        ``True`` if the right-stripped text ends with
        ``<|tool_response>`` (proper behavior), ``False`` otherwise.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     has_tool_response_tag
        ... )
        >>> if not has_tool_response_tag(model_output):
        ...     # Model used <eos> instead — inject <|tool_response> manually
        ...     next_prompt = "<|tool_response>" + tool_result
    """
    return text.rstrip().endswith("<|tool_response>")
|
||||
@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
|
||||
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
|
||||
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
|
||||
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
|
||||
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
|
||||
@@ -381,6 +382,7 @@ _MULTIMODAL_MODELS = {
|
||||
"gemma3n_mm",
|
||||
"Gemma3nForConditionalGeneration",
|
||||
),
|
||||
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
|
||||
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
|
||||
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
|
||||
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
|
||||
|
||||
@@ -233,8 +233,15 @@ class AutoWeightsLoader:
|
||||
):
|
||||
"""
|
||||
Add tensor names that are not in the model params that may be in the
|
||||
safetensors, e.g., batch normalization stats.
|
||||
safetensors, e.g., batch normalization stats and registered buffers.
|
||||
"""
|
||||
# Add persistent registered buffers.
|
||||
# Non-persistent buffers are excluded, matching PyTorch state_dict().
|
||||
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
|
||||
for buf_name, buf in module.named_buffers(recurse=False):
|
||||
if buf_name not in child_params and buf_name not in non_persistent:
|
||||
child_params[buf_name] = buf
|
||||
|
||||
if isinstance(
|
||||
module,
|
||||
(
|
||||
|
||||
@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
|
||||
"ernie45_reasoning_parser",
|
||||
"Ernie45ReasoningParser",
|
||||
),
|
||||
"gemma4": (
|
||||
"gemma4_reasoning_parser",
|
||||
"Gemma4ReasoningParser",
|
||||
),
|
||||
"glm45": (
|
||||
"deepseek_v3_reasoning_parser",
|
||||
"DeepSeekV3ReasoningWithThinkingParser",
|
||||
|
||||
193
vllm/reasoning/gemma4_reasoning_parser.py
Normal file
193
vllm/reasoning/gemma4_reasoning_parser.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
|
||||
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
|
||||
|
||||
# Role label that Gemma4 emits at the start of the thinking channel.
|
||||
# The model generates: <|channel>thought\n...reasoning...<channel|>
|
||||
# This prefix must be stripped to expose only the actual reasoning content.
|
||||
_THOUGHT_PREFIX = "thought\n"
|
||||
|
||||
|
||||
class Gemma4ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for Google Gemma4 thinking models.

    Gemma4 uses <|channel>...<channel|> tokens to delimit reasoning/thinking
    content within its output. Thinking mode is activated by passing
    ``enable_thinking=True`` in the chat template kwargs, which injects a
    system turn containing <|think|> (token 98) to trigger chain-of-thought
    reasoning.

    Output pattern when thinking is enabled::

        <|channel>thought
        ...chain of thought reasoning...<channel|>
        Final answer text here.

    The ``thought\\n`` role label inside the channel delimiters is a
    structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
    This parser strips it so that downstream consumers see only the
    actual reasoning text, consistent with the offline parser
    (``vllm.reasoning.gemma4_utils._strip_thought_label``).
    """

    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        # Instance state for streaming prefix stripping.
        # Tracks only the reasoning text received from the base parser,
        # independent of current_text (which may contain pre-reasoning
        # content and lacks special token text due to
        # skip_special_tokens=True).
        self._reasoning_text: str = ""
        self._prefix_stripped: bool = False
        # NOTE(review): this state is never reset between generations;
        # presumably a fresh parser instance is created per request —
        # TODO confirm against the serving layer's parser lifecycle.

    @property
    def start_token(self) -> str:
        """The token that starts reasoning content."""
        return "<|channel>"

    @property
    def end_token(self) -> str:
        """The token that ends reasoning content."""
        return "<channel|>"

    # ------------------------------------------------------------------
    # Non-streaming path
    # ------------------------------------------------------------------

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Extract reasoning, stripping the ``thought\\n`` role label.

        Returns a ``(reasoning, content)`` pair; when neither channel
        delimiter appears in ``model_output``, the whole output is
        treated as content.
        """
        if self.start_token not in model_output and self.end_token not in model_output:
            # Default to content history if no tags are present
            # (or if they were stripped)
            return None, model_output

        reasoning, content = super().extract_reasoning(model_output, request)
        if reasoning is not None:
            reasoning = _strip_thought_label(reasoning)
        return reasoning, content

    # ------------------------------------------------------------------
    # Streaming path
    # ------------------------------------------------------------------

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> DeltaMessage | None:
        """Extract streaming reasoning, stripping ``thought\\n`` from the
        first reasoning delta(s).

        The ``thought\\n`` prefix may arrive as a single delta or split
        across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). We
        buffer early reasoning tokens until we can determine whether the
        prefix is present, then emit the buffered content minus the
        prefix.

        Unlike the previous implementation which reconstructed accumulated
        reasoning from ``current_text``, this uses instance state
        (``_reasoning_text``) to track only the reasoning content returned
        by the base parser. This is necessary because
        ``skip_special_tokens=True`` (the vLLM default) causes the
        ``<|channel>`` delimiter to be invisible in ``current_text``,
        making it impossible to separate pre-reasoning content from
        reasoning content via string matching.
        """
        # Delegate tag detection/splitting to the base parser first.
        result = super().extract_reasoning_streaming(
            previous_text,
            current_text,
            delta_text,
            previous_token_ids,
            current_token_ids,
            delta_token_ids,
        )
        if result is None:
            return None

        if result.reasoning is None:
            # Non-reasoning delta (e.g. answer content): pass through.
            return result

        # Accumulate ONLY the reasoning text from base parser results.
        # This is immune to pre-reasoning content pollution.
        self._reasoning_text += result.reasoning

        # Once the prefix has been handled, all subsequent reasoning
        # deltas pass through unchanged.
        if self._prefix_stripped:
            return result

        # ---- Prefix stripping logic ----

        # Case 1: We've accumulated enough to confirm the prefix is
        # present. Strip it and pass through the remainder.
        if self._reasoning_text.startswith(_THOUGHT_PREFIX):
            prefix_len = len(_THOUGHT_PREFIX)
            # How much reasoning was accumulated before this delta?
            prev_reasoning_len = len(self._reasoning_text) - len(result.reasoning)
            if prev_reasoning_len >= prefix_len:
                # Prefix was already consumed by prior deltas; this
                # delta is entirely real content — pass through.
                self._prefix_stripped = True
                return result
            else:
                # Part or all of the prefix is in this delta.
                chars_of_prefix_in_delta = prefix_len - prev_reasoning_len
                stripped = result.reasoning[chars_of_prefix_in_delta:]
                if stripped:
                    self._prefix_stripped = True
                    result.reasoning = stripped
                    return result
                else:
                    # This entire delta was prefix — suppress it.
                    # Don't set _prefix_stripped yet; there may be more
                    # prefix chars to consume in the next delta.
                    if len(self._reasoning_text) >= prefix_len:
                        self._prefix_stripped = True
                    return None

        # Case 2: Accumulated text is a strict prefix of
        # _THOUGHT_PREFIX (e.g. we've only seen "thou" so far).
        # Buffer by suppressing — we can't yet tell if this will
        # become the full prefix or diverge.
        if _THOUGHT_PREFIX.startswith(self._reasoning_text):
            return None

        # Case 3: Accumulated text doesn't match the thought prefix
        # at all. This means prior deltas were buffered (suppressed
        # by Case 2) but the text diverged. Re-emit the full
        # accumulated text to avoid data loss.
        self._prefix_stripped = True
        result.reasoning = self._reasoning_text
        return result
|
||||
|
||||
|
||||
def _strip_thought_label(text: str) -> str:
|
||||
"""Remove the ``thought\\n`` role label from the beginning of text.
|
||||
|
||||
Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
|
||||
offline parser.
|
||||
"""
|
||||
if text.startswith(_THOUGHT_PREFIX):
|
||||
return text[len(_THOUGHT_PREFIX) :]
|
||||
return text
|
||||
130
vllm/reasoning/gemma4_utils.py
Normal file
130
vllm/reasoning/gemma4_utils.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
|
||||
|
||||
Standalone functions that parse decoded model text to extract structured
|
||||
thinking content from Gemma4 models. These are pure-Python utilities with
|
||||
zero heavy dependencies — they work on raw decoded strings from any
|
||||
inference backend (vLLM, HuggingFace, TGI, etc.).
|
||||
|
||||
For the OpenAI-compatible API reasoning parser (streaming +
|
||||
non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
|
||||
For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
|
||||
|
||||
Usage with vLLM offline inference::
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.reasoning.gemma4_utils import parse_thinking_output
|
||||
|
||||
llm = LLM(model="google/gemma-4-it")
|
||||
outputs = llm.generate(prompt, SamplingParams(...))
|
||||
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
|
||||
|
||||
# Extract thinking / answer (works with or without enable_thinking)
|
||||
result = parse_thinking_output(text)
|
||||
print(result["thinking"]) # chain-of-thought or None
|
||||
print(result["answer"]) # final answer
|
||||
|
||||
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
|
||||
do not need a transformers dependency for output parsing.
|
||||
"""
|
||||
|
||||
# ---- Thinking Mode Utility ----
|
||||
|
||||
# Thinking delimiter tokens as they appear in decoded text.
|
||||
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
|
||||
_THINKING_START_TAG = "<|channel>"
|
||||
_THINKING_END_TAG = "<channel|>"
|
||||
|
||||
# Sentinel tokens that may appear in decoded output.
|
||||
_TURN_END_TAG = "<turn|>"
|
||||
|
||||
|
||||
def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Separate Gemma4 chain-of-thought from the final answer.

    Safe to call on **all** Gemma4 output, with or without thinking mode:

    1. **Delimiters present** — the text before ``<channel|>`` (minus the
       ``<|channel>`` tag and the ``thought\\n`` role label) is returned
       as ``"thinking"`` and the remainder as ``"answer"``.
    2. **No delimiters, spurious label** — a bare leading ``thought\\n``
       is stripped and the rest is the answer.
    3. **Clean output** — returned unchanged as the answer.

    Trailing sentinel tokens (``<turn|>``, ``<eos>``) are always removed
    from the answer.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with ``"thinking"`` (chain-of-thought text or ``None``)
        and ``"answer"`` (final answer text).

    Example::

        >>> from vllm.reasoning.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])    # final answer
    """
    if _THINKING_END_TAG not in text:
        # Cases 2 and 3: no delimiters at all. Strip the spurious
        # "thought\n" role label, then clean trailing sentinels.
        cleaned = _clean_answer(_strip_thought_label(text))
        return {"thinking": None, "answer": cleaned}

    # Case 1: split once on the end tag; everything after it is answer.
    thinking_block, _, tail = text.partition(_THINKING_END_TAG)
    answer = _clean_answer(tail)

    # Keep only what follows the start tag, when the model emitted one.
    if _THINKING_START_TAG in thinking_block:
        thinking_block = thinking_block.split(_THINKING_START_TAG, 1)[1]

    # Drop the "thought\n" channel role label (analogous to "user\n" in
    # <|turn>user\n...<turn|>) and tidy whitespace.
    thinking = _strip_thought_label(thinking_block.strip()).strip()
    return {"thinking": thinking, "answer": answer}
|
||||
|
||||
|
||||
def _strip_thought_label(text: str) -> str:
|
||||
"""Strip the spurious ``thought\\n`` label from the start of text.
|
||||
|
||||
Only strips when ``thought`` appears as the very first word followed by
|
||||
a newline — preserving the word ``thought`` in any other context.
|
||||
"""
|
||||
if text.startswith("thought\n"):
|
||||
return text[len("thought\n") :]
|
||||
return text
|
||||
|
||||
|
||||
def _clean_answer(text: str) -> str:
    """Strip trailing sentinel tokens from the answer text.

    Removes surrounding whitespace, then a trailing ``<turn|>`` (the
    Gemma4 turn-end marker), then a trailing ``<eos>``, each followed by
    a trailing-whitespace trim.
    """
    text = text.strip()
    # Check the sentinels in emission order: turn-end marker first,
    # then a possible end-of-sequence token.
    for sentinel in (_TURN_END_TAG, "<eos>"):
        if text.endswith(sentinel):
            text = text[: -len(sentinel)].rstrip()
    return text
|
||||
@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
|
||||
"functiongemma_tool_parser",
|
||||
"FunctionGemmaToolParser",
|
||||
),
|
||||
"gemma4": (
|
||||
"gemma4_tool_parser",
|
||||
"Gemma4ToolParser",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
724
vllm/tool_parsers/gemma4_tool_parser.py
Normal file
724
vllm/tool_parsers/gemma4_tool_parser.py
Normal file
@@ -0,0 +1,724 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Tool call parser for Google Gemma4 models.
|
||||
|
||||
Gemma4 uses a custom serialization format (not JSON) for tool calls::
|
||||
|
||||
<|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<tool_call|>
|
||||
|
||||
Strings are delimited by ``<|"|>`` (token 52), keys are unquoted, and
|
||||
multiple tool calls are concatenated without separators.
|
||||
|
||||
Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` are set.
|
||||
|
||||
For offline inference tool call parsing (direct ``tokenizer.decode()`` output),
|
||||
see ``vllm.tool_parsers.gemma4_utils.parse_tool_calls``.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
|
||||
import regex as re
|
||||
|
||||
from vllm.entrypoints.chat_utils import make_tool_call_id
|
||||
from vllm.entrypoints.openai.chat_completion.protocol import (
|
||||
ChatCompletionRequest,
|
||||
)
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
DeltaFunctionCall,
|
||||
DeltaMessage,
|
||||
DeltaToolCall,
|
||||
ExtractedToolCallInformation,
|
||||
FunctionCall,
|
||||
ToolCall,
|
||||
)
|
||||
from vllm.entrypoints.openai.responses.protocol import (
|
||||
ResponsesRequest,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.tool_parsers.abstract_tool_parser import ToolParser
|
||||
from vllm.tool_parsers.utils import find_common_prefix
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Gemma4 special tokens for tool calls
# (note the asymmetric pipes: "<|...>" opens, "<...|>" closes).
TOOL_CALL_START = "<|tool_call>"
TOOL_CALL_END = "<tool_call|>"
# Delimits string argument values, e.g. key:<|"|>value<|"|> (token 52).
STRING_DELIM = '<|"|>'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gemma4 argument parser (used by both streaming and non-streaming paths)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_gemma4_value(value_str: str) -> object:
|
||||
"""Parse a single Gemma4 value (after key:) into a Python object."""
|
||||
value_str = value_str.strip()
|
||||
if not value_str:
|
||||
return value_str
|
||||
|
||||
# Boolean
|
||||
if value_str == "true":
|
||||
return True
|
||||
if value_str == "false":
|
||||
return False
|
||||
|
||||
# Number (int or float)
|
||||
try:
|
||||
if "." in value_str:
|
||||
return float(value_str)
|
||||
return int(value_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Bare string (no <|"|> delimiters — shouldn't happen but be safe)
|
||||
return value_str
|
||||
|
||||
|
||||
def _parse_gemma4_args(args_str: str) -> dict:
    """Parse Gemma4's custom key:value format into a Python dict.

    Format examples::

        location:<|"|>Tokyo<|"|>
        location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>
        count:42,flag:true
        nested:{inner_key:<|"|>val<|"|>}
        items:[<|"|>a<|"|>,<|"|>b<|"|>]

    Returns a dict ready for ``json.dumps()``.
    """
    if not args_str or not args_str.strip():
        return {}

    # Single-pass cursor scan over the raw text: `i` is the cursor,
    # `n` the input length. Nested objects/arrays recurse.
    result: dict = {}
    i = 0
    n = len(args_str)

    while i < n:
        # Skip whitespace and commas
        while i < n and args_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break

        # Parse key (unquoted, ends at ':')
        key_start = i
        while i < n and args_str[i] != ":":
            i += 1
        if i >= n:
            # Trailing text with no ':' — not a key/value pair; discard.
            break
        key = args_str[key_start:i].strip()
        i += 1  # skip ':'

        # Parse value
        # A trailing "key:" with nothing after it maps to an empty string.
        if i >= n:
            result[key] = ""
            break

        # Skip whitespace after ':'
        while i < n and args_str[i] in (" ", "\n", "\t"):
            i += 1
        if i >= n:
            result[key] = ""
            break

        # String value: <|"|>...<|"|>
        if args_str[i:].startswith(STRING_DELIM):
            i += len(STRING_DELIM)
            val_start = i
            end_pos = args_str.find(STRING_DELIM, i)
            if end_pos == -1:
                # Unterminated string — take rest
                result[key] = args_str[val_start:]
                break
            result[key] = args_str[val_start:end_pos]
            i = end_pos + len(STRING_DELIM)

        # Nested object: {...}
        elif args_str[i] == "{":
            depth = 1
            obj_start = i + 1
            i += 1
            while i < n and depth > 0:
                if args_str[i:].startswith(STRING_DELIM):
                    # Skip over string contents to avoid counting { inside strings
                    i += len(STRING_DELIM)
                    next_delim = args_str.find(STRING_DELIM, i)
                    i = n if next_delim == -1 else next_delim + len(STRING_DELIM)
                    continue
                if args_str[i] == "{":
                    depth += 1
                elif args_str[i] == "}":
                    depth -= 1
                i += 1
            # Recurse on the text between the braces (i now sits just
            # past the matching '}').
            result[key] = _parse_gemma4_args(args_str[obj_start : i - 1])

        # Array: [...]
        elif args_str[i] == "[":
            depth = 1
            arr_start = i + 1
            i += 1
            while i < n and depth > 0:
                if args_str[i:].startswith(STRING_DELIM):
                    # Same string-skip as the object branch above.
                    i += len(STRING_DELIM)
                    next_delim = args_str.find(STRING_DELIM, i)
                    i = n if next_delim == -1 else next_delim + len(STRING_DELIM)
                    continue
                if args_str[i] == "[":
                    depth += 1
                elif args_str[i] == "]":
                    depth -= 1
                i += 1
            arr_content = args_str[arr_start : i - 1]
            result[key] = _parse_gemma4_array(arr_content)

        # Bare value (number, boolean, etc.)
        else:
            # Stop at separators or closing brackets so bare values inside
            # nested containers terminate correctly.
            val_start = i
            while i < n and args_str[i] not in (",", "}", "]"):
                i += 1
            result[key] = _parse_gemma4_value(args_str[val_start:i])

    return result
|
||||
|
||||
|
||||
def _parse_gemma4_array(arr_str: str) -> list:
    """Parse a Gemma4 array content string into a Python list.

    Elements may be ``<|"|>``-delimited strings, nested objects, nested
    arrays, or bare values (numbers, booleans); nested containers recurse
    through ``_parse_gemma4_args`` / ``_parse_gemma4_array``.

    Args:
        arr_str: The raw text between the array's ``[`` and ``]``.

    Returns:
        List of parsed elements.
    """
    items: list = []
    i = 0
    n = len(arr_str)

    while i < n:
        # Skip element separators and whitespace.
        while i < n and arr_str[i] in (" ", ",", "\n", "\t"):
            i += 1
        if i >= n:
            break

        # String element
        if arr_str[i:].startswith(STRING_DELIM):
            i += len(STRING_DELIM)
            end_pos = arr_str.find(STRING_DELIM, i)
            if end_pos == -1:
                # Unterminated string — take the rest.
                items.append(arr_str[i:])
                break
            items.append(arr_str[i:end_pos])
            i = end_pos + len(STRING_DELIM)

        # Nested object: scan to the matching '}', skipping string
        # contents so braces inside strings don't affect the depth count.
        elif arr_str[i] == "{":
            depth = 1
            obj_start = i + 1
            i += 1
            while i < n and depth > 0:
                if arr_str[i:].startswith(STRING_DELIM):
                    i += len(STRING_DELIM)
                    nd = arr_str.find(STRING_DELIM, i)
                    i = nd + len(STRING_DELIM) if nd != -1 else n
                    continue
                if arr_str[i] == "{":
                    depth += 1
                elif arr_str[i] == "}":
                    depth -= 1
                i += 1
            items.append(_parse_gemma4_args(arr_str[obj_start : i - 1]))

        # Nested array: scan to the matching ']'.
        # Fix: skip string contents here too (mirroring the object branch
        # above and both container branches of _parse_gemma4_args) so
        # brackets inside <|"|>...<|"|> strings don't corrupt the depth
        # count. Previously this branch counted every '[' and ']'.
        elif arr_str[i] == "[":
            depth = 1
            sub_start = i + 1
            i += 1
            while i < n and depth > 0:
                if arr_str[i:].startswith(STRING_DELIM):
                    i += len(STRING_DELIM)
                    nd = arr_str.find(STRING_DELIM, i)
                    i = nd + len(STRING_DELIM) if nd != -1 else n
                    continue
                if arr_str[i] == "[":
                    depth += 1
                elif arr_str[i] == "]":
                    depth -= 1
                i += 1
            items.append(_parse_gemma4_array(arr_str[sub_start : i - 1]))

        # Bare value
        else:
            val_start = i
            while i < n and arr_str[i] not in (",", "]"):
                i += 1
            items.append(_parse_gemma4_value(arr_str[val_start:i]))

    return items
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Gemma4ToolParser(ToolParser):
    """
    Tool call parser for Google Gemma4 models.

    Handles the Gemma4 function call format::

        <|tool_call>call:func_name{key:<|"|>value<|"|>}<tool_call|>

    Used when ``--enable-auto-tool-choice --tool-call-parser gemma4``
    are set.

    Streaming strategy: **accumulate-then-parse-then-diff**

    Instead of trying to convert Gemma4's custom format to JSON
    token-by-token (which fails because Gemma4 uses bare keys, custom
    delimiters, and structural braces that differ from JSON), this parser:

    1. Accumulates the raw Gemma4 argument string during streaming
    2. Parses it with ``_parse_gemma4_args()`` into a Python dict
    3. Converts to JSON with ``json.dumps()``
    4. Diffs against the previously-streamed JSON string
    5. Emits only the new JSON fragment as the delta

    This follows the same pattern used by FunctionGemma, Hermes, and Llama
    tool parsers.
    """

    def __init__(self, tokenizer: TokenizerLike):
        """Initialize the parser and resolve the tool-call special tokens.

        Raises:
            ValueError: If no tokenizer was supplied.
            RuntimeError: If the vocabulary has no ``<|tool_call>`` token.
        """
        super().__init__(tokenizer)

        if not self.model_tokenizer:
            raise ValueError(
                "The model tokenizer must be passed to the ToolParser "
                "constructor during construction."
            )

        # Token strings
        self.tool_call_start_token = TOOL_CALL_START
        self.tool_call_end_token = TOOL_CALL_END

        # Token IDs
        # NOTE(review): self.vocab is presumably populated by the
        # ToolParser base class from the tokenizer — confirm.
        self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START)
        self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END)

        if self.tool_call_start_token_id is None:
            raise RuntimeError(
                "Gemma4 ToolParser could not locate the tool call start "
                f"token '{TOOL_CALL_START}' in the tokenizer!"
            )

        # Regex for non-streaming: extract complete tool calls.
        # Supports function names with letters, digits, underscores,
        # hyphens, and dots (e.g. "get-weather", "module.func").
        self.tool_call_regex = re.compile(
            r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}<tool_call\|>",
            re.DOTALL,
        )

        # Streaming state — reset per-request via _reset_streaming_state()
        self._reset_streaming_state()

        # Delta buffer for handling multi-token special sequences
        self.buffered_delta_text = ""

    def _reset_streaming_state(self) -> None:
        """Reset all streaming state for a new request."""
        # NOTE(review): buffered_delta_text is NOT cleared here — it is
        # only set in __init__. This is fine if the serving layer creates
        # a fresh parser per request; confirm that assumption.
        self.current_tool_id = -1
        self.current_tool_name_sent = False
        self.prev_tool_call_arr: list[dict] = []
        self.streamed_args_for_tool: list[str] = []

    def adjust_request(
        self, request: ChatCompletionRequest | ResponsesRequest
    ) -> ChatCompletionRequest | ResponsesRequest:
        """Adjust the request so tool-call special tokens survive decoding."""
        request = super().adjust_request(request)
        if (
            isinstance(request, ChatCompletionRequest)
            and request.tools
            and request.tool_choice != "none"
        ):
            # Don't skip special tokens — <|tool_call> etc. are needed
            request.skip_special_tokens = False
        return request

    # ------------------------------------------------------------------
    # Delta buffering for multi-token special sequences
    # ------------------------------------------------------------------

    def _buffer_delta_text(self, delta_text: str) -> str:
        """Buffer incoming delta text to handle multi-token special sequences.

        Accumulates partial tokens that could be the start of
        ``<|tool_call>`` or ``<tool_call|>`` and only flushes them
        when the complete sequence is recognized or the sequence breaks.

        This prevents partial special tokens (e.g., ``<|tool``) from being
        emitted prematurely as content text.
        """
        combined = self.buffered_delta_text + delta_text

        # Check if combined ends with a complete special token
        if combined.endswith(TOOL_CALL_START) or combined.endswith(TOOL_CALL_END):
            self.buffered_delta_text = ""
            return combined

        # Check if combined ends with a partial prefix of a special token
        # (longest prefix wins implicitly: shorter suffixes are checked
        # first, and the first suffix match ends the scan).
        for tag in [TOOL_CALL_START, TOOL_CALL_END]:
            for i in range(1, len(tag)):
                if combined.endswith(tag[:i]):
                    self.buffered_delta_text = combined[-i:]
                    return combined[:-i]

        # No partial match — flush everything
        self.buffered_delta_text = ""
        return combined

    # ------------------------------------------------------------------
    # Non-streaming extraction
    # ------------------------------------------------------------------

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract all complete tool calls from a finished generation.

        Args:
            model_output: The full decoded model output text.
            request: The originating request (not consulted here).

        Returns:
            ExtractedToolCallInformation with the parsed calls and any
            free text before the first call as ``content``. On any parse
            error, the raw output is returned as plain content instead.
        """
        if self.tool_call_start_token not in model_output:
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            # Each match is a (func_name, raw_args) tuple from the two
            # capture groups of tool_call_regex.
            matches = self.tool_call_regex.findall(model_output)
            if not matches:
                return ExtractedToolCallInformation(
                    tools_called=False, tool_calls=[], content=model_output
                )

            tool_calls: list[ToolCall] = []
            for func_name, args_str in matches:
                arguments = _parse_gemma4_args(args_str)
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=func_name,
                            arguments=json.dumps(arguments, ensure_ascii=False),
                        ),
                    )
                )

            # Content = text before first tool call (if any)
            content_end = model_output.find(self.tool_call_start_token)
            content = model_output[:content_end].strip() if content_end > 0 else None

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=content if content else None,
            )

        except Exception:
            logger.exception("Error extracting tool calls from Gemma4 response")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    # ------------------------------------------------------------------
    # Streaming extraction — accumulate-then-parse-then-diff
    # ------------------------------------------------------------------

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incrementally extract tool calls during streaming generation.

        Buffers partial special-token fragments, then delegates to
        ``_extract_streaming``. Exceptions are logged and swallowed so
        a parse hiccup doesn't abort the stream.
        """
        # Buffer delta text to handle multi-token special sequences
        delta_text = self._buffer_delta_text(delta_text)
        # Reconstruct current_text after buffering to stay in sync
        current_text = previous_text + delta_text

        # If no tool call token seen yet, emit as content
        if self.tool_call_start_token not in current_text:
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None

        try:
            return self._extract_streaming(
                previous_text=previous_text,
                current_text=current_text,
                delta_text=delta_text,
            )
        except Exception:
            logger.exception("Error in Gemma4 streaming tool call extraction")
            return None

    def _extract_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
    ) -> DeltaMessage | None:
        """Tag-counting streaming parser.

        Uses the proven approach from FunctionGemma/Hermes: count start/end
        tags in previous vs current text to determine phase, then
        accumulate-parse-diff for arguments.

        Format: ``<|tool_call>call:name{args}<tool_call|>``
        """
        start_count = current_text.count(self.tool_call_start_token)
        end_count = current_text.count(self.tool_call_end_token)
        prev_start_count = previous_text.count(self.tool_call_start_token)
        prev_end_count = previous_text.count(self.tool_call_end_token)

        # Case 1: Not inside any tool call — emit as content
        if (
            start_count == end_count
            and prev_end_count == end_count
            and self.tool_call_end_token not in delta_text
        ):
            if delta_text:
                return DeltaMessage(content=delta_text)
            return None

        # Case 2: Starting a new tool call
        if start_count > prev_start_count and start_count > end_count:
            self.current_tool_id += 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            self.prev_tool_call_arr.append({})
            logger.debug("Starting new tool call %d", self.current_tool_id)
            # Don't return yet — fall through to try parsing if there's
            # content after <|tool_call> in this same delta
            # (but usually it's just the token itself, so return None)
            if len(delta_text) <= len(self.tool_call_start_token):
                return None

        # Case 3: Tool call just ended
        if end_count > prev_end_count:
            return self._handle_tool_call_end(current_text)

        # Case 4: In the middle of a tool call — parse partial content
        if start_count > end_count:
            return self._handle_tool_call_middle(current_text)

        # Default: generate text outside tool calls
        # (special tokens are scrubbed so they never reach the client).
        if delta_text:
            text = delta_text.replace(self.tool_call_start_token, "")
            text = text.replace(self.tool_call_end_token, "")
            if text:
                return DeltaMessage(content=text)
        return None

    def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]:
        """Extract function name and raw argument string from partial text.

        Returns (func_name, raw_args_str) or (None, "") if not parseable yet.
        """
        # Get the text after the last <|tool_call> token
        last_start = current_text.rfind(self.tool_call_start_token)
        if last_start == -1:
            return None, ""

        partial_call = current_text[last_start + len(self.tool_call_start_token) :]

        # Strip end token if present
        if self.tool_call_end_token in partial_call:
            partial_call = partial_call.split(self.tool_call_end_token)[0]

        # Expect "call:name{args...}" or "call:name{args...}"
        if not partial_call.startswith("call:"):
            return None, ""

        func_part = partial_call[5:]  # skip "call:"

        if "{" not in func_part:
            # Still accumulating function name, not ready yet
            return None, ""

        func_name, _, args_part = func_part.partition("{")
        func_name = func_name.strip()

        # Strip trailing '}' if present (Gemma4 structural brace)
        if args_part.endswith("}"):
            args_part = args_part[:-1]

        return func_name, args_part

    def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when we're inside an active tool call.

        Accumulates the raw Gemma4 arguments, parses them into JSON, and
        diffs against the previously-streamed JSON to emit only the new
        fragment.
        """
        func_name, args_part = self._extract_partial_call(current_text)

        if func_name is None:
            return None

        # Step 1: Send function name (once)
        if not self.current_tool_name_sent and func_name:
            self.current_tool_name_sent = True
            self.prev_tool_call_arr[self.current_tool_id] = {
                "name": func_name,
                "arguments": {},
            }
            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        type="function",
                        id=make_tool_call_id(),
                        function=DeltaFunctionCall(
                            name=func_name,
                            arguments="",
                        ).model_dump(exclude_none=True),
                    )
                ]
            )

        # Step 2: Parse and diff arguments
        if self.current_tool_name_sent and args_part:
            return self._emit_argument_diff(args_part)

        return None

    def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None:
        """Handle streaming when a tool call has just completed.

        Performs a final parse of the complete tool call and flushes
        any remaining un-streamed argument fragments.
        """
        if self.current_tool_id < 0 or self.current_tool_id >= len(
            self.prev_tool_call_arr
        ):
            logger.debug(
                "Tool call end detected but no active tool call (current_tool_id=%d)",
                self.current_tool_id,
            )
            return None

        # Parse the complete tool call using regex for accuracy
        all_matches = self.tool_call_regex.findall(current_text)
        if self.current_tool_id < len(all_matches):
            _, args_str = all_matches[self.current_tool_id]
            final_args = _parse_gemma4_args(args_str)
            final_args_json = json.dumps(final_args, ensure_ascii=False)

            # Flush only what was withheld during streaming (the suffix
            # beyond what _emit_argument_diff already sent).
            prev_streamed = self.streamed_args_for_tool[self.current_tool_id]
            if len(final_args_json) > len(prev_streamed):
                diff = final_args_json[len(prev_streamed) :]
                self.streamed_args_for_tool[self.current_tool_id] = final_args_json
                self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args

                return DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            function=DeltaFunctionCall(arguments=diff).model_dump(
                                exclude_none=True
                            ),
                        )
                    ]
                )

        return None

    def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None:
        """Parse raw Gemma4 arguments, convert to JSON, diff, and emit.

        This is the core of the accumulate-then-parse-then-diff strategy:
        1. Parse ``raw_args_str`` with ``_parse_gemma4_args()``
        2. Convert to JSON string with ``json.dumps()``
        3. Withhold trailing closing characters (``"}``) that may move
           as more tokens arrive
        4. Diff against previously streamed JSON and emit only new chars

        **Why withholding is necessary:**

        Gemma4's custom format produces *structurally incomplete* JSON
        during streaming. For example, when ``<|"|>Paris`` arrives
        without a closing delimiter, ``_parse_gemma4_args`` treats it
        as a complete value and produces ``{"location": "Paris"}``. But
        when ``, France<|"|>`` arrives next, the JSON becomes
        ``{"location": "Paris, France"}``. If we had sent the closing
        ``"}`` from the first parse, the concatenated client output
        would be ``{"location": "Paris"}France"}``, which is garbage.

        The solution: **never send trailing closing chars during
        streaming**. They get flushed by ``_handle_tool_call_end()``
        when the ``<tool_call|>`` end marker arrives.

        Args:
            raw_args_str: The raw Gemma4 argument text accumulated so far
                (without the surrounding ``{`` ``}``).

        Returns:
            DeltaMessage with the argument diff, or None if no new content.
        """
        try:
            current_args = _parse_gemma4_args(raw_args_str)
        except Exception:
            logger.debug(
                "Could not parse partial Gemma4 args yet: %s",
                raw_args_str[:100],
            )
            return None

        if not current_args:
            return None

        current_args_json = json.dumps(current_args, ensure_ascii=False)

        # Withhold trailing closing characters that may shift as more
        # tokens arrive. Strip trailing '}', '"', and ']' sequences
        # to get the "safe prefix".
        safe_json = current_args_json
        while safe_json and safe_json[-1] in ("}", '"', "]"):
            safe_json = safe_json[:-1]

        prev_streamed = self.streamed_args_for_tool[self.current_tool_id]

        if not safe_json or safe_json == prev_streamed:
            return None

        # Use find_common_prefix to handle cases where the value changed
        # structurally (e.g., a string grew).
        if prev_streamed:
            prefix = find_common_prefix(prev_streamed, safe_json)
            sent_len = len(prev_streamed)
            prefix_len = len(prefix)

            if prefix_len < sent_len:
                # Structure changed — we sent too much. Truncate our
                # tracking to the common prefix and wait for the final
                # flush in _handle_tool_call_end.
                self.streamed_args_for_tool[self.current_tool_id] = prefix
                return None

            # Stream the new stable portion
            diff = safe_json[sent_len:]
        else:
            # First emission
            diff = safe_json

        if diff:
            self.streamed_args_for_tool[self.current_tool_id] = safe_json
            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args

            return DeltaMessage(
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        function=DeltaFunctionCall(arguments=diff).model_dump(
                            exclude_none=True
                        ),
                    )
                ]
            )

        return None
|
||||
183
vllm/tool_parsers/gemma4_utils.py
Normal file
183
vllm/tool_parsers/gemma4_utils.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
|
||||
|
||||
"""Gemma4 tool call parsing utilities for offline inference.
|
||||
|
||||
Standalone functions that parse decoded model text to extract tool calls
|
||||
from Gemma4 models. These are pure-Python utilities with zero heavy
|
||||
dependencies — they work on raw decoded strings from any inference
|
||||
backend (vLLM, HuggingFace, TGI, etc.).
|
||||
|
||||
For the OpenAI-compatible API server tool parser (streaming +
|
||||
non-streaming), see ``vllm.tool_parsers.gemma4_tool_parser``.
|
||||
For thinking/reasoning output parsing, see
|
||||
``vllm.reasoning.gemma4_utils``.
|
||||
|
||||
Usage with vLLM offline inference::
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.tool_parsers.gemma4_utils import (
|
||||
parse_tool_calls,
|
||||
has_tool_response_tag,
|
||||
)
|
||||
|
||||
llm = LLM(model="google/gemma-4-it")
|
||||
outputs = llm.generate(prompt, SamplingParams(...))
|
||||
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
|
||||
|
||||
# Extract tool calls
|
||||
tool_calls = parse_tool_calls(text)
|
||||
for tc in tool_calls:
|
||||
print(f"{tc['name']}({tc['arguments']})")
|
||||
|
||||
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
|
||||
do not need a transformers dependency for output parsing.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import regex as re
|
||||
|
||||
# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
# Marks the start of a tool response turn; see has_tool_response_tag().
_TOOL_RESPONSE_START_TAG = "<|tool_response>"

# Gemma4 escape token as it appears in decoded text.
# Wraps string argument values, e.g. key:<|"|>value<|"|>.
_ESCAPE_TOKEN = '<|"|>'
|
||||
|
||||
|
||||
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
    """Parse the argument blob of a Gemma4 tool call into a dict.

    Accepts the native ``key:<|"|>value<|"|>`` format as well as the
    ``key: "value"`` variant (space + plain quotes) some chat templates
    produce, falling back to looser heuristics when neither parses
    cleanly.

    Args:
        args_str: Raw argument text from inside ``call:name{...}``.

    Returns:
        Mapping of argument name to value.
    """
    if not args_str or not args_str.strip():
        return {}

    # Normalize the Gemma4 string delimiter to a plain double quote.
    cleaned = args_str.replace(_ESCAPE_TOKEN, '"')

    # Tier 1: interpret the blob as a JSON object body (handles nested
    # values, arrays, numbers, ...); non-string values are stringified
    # for consistency.
    try:
        decoded = json.loads("{" + cleaned + "}")
    except (json.JSONDecodeError, ValueError):
        decoded = None
    if decoded is not None:
        return {
            key: value if isinstance(value, str) else str(value)
            for key, value in decoded.items()
        }

    # Tier 2: quoted key/value pairs (optional space after the colon).
    quoted = {
        key: value for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned)
    }
    if quoted:
        return quoted

    # Tier 3: last resort — unquoted key:value pairs from the raw text.
    return {
        key: value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
        for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str)
    }
|
||||
|
||||
|
||||
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Extract tool calls from decoded Gemma4 model output.

    A tiered strategy covers the known output variations of Gemma4
    models, which may emit non-standard tool call formats:

    1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>`` —
       special token IDs 48/49 in decoded text; some models terminate
       with ``<turn|>`` instead of ``<tool_call|>``.
    2. **Fallback** (only when ``strict`` is ``False`` and tier 1 found
       nothing): bare ``call:name{args}`` patterns, including the
       ``<call>name{args}`` fragments seen with multimodal inputs.

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...,
            skip_special_tokens=False)``).
        strict: When ``True``, only the standard ``<|tool_call>`` format
            is matched.

    Returns:
        A list of dicts, each with keys:
        - ``"name"``: The tool function name (e.g. ``"get_weather"``).
        - ``"arguments"``: A dict of argument name → value.

    Example::

        >>> from vllm.tool_parsers.gemma4_utils import parse_tool_calls
        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> for tc in parse_tool_calls(output):
        ...     print(f"Call: {tc['name']}({tc['arguments']})")
    """

    def _collect(pattern: str) -> list[dict]:
        # One dict per regex match: group(1) = name, group(2) = raw args.
        return [
            {
                "name": match.group(1),
                "arguments": _parse_tool_arguments(match.group(2)),
            }
            for match in re.finditer(pattern, text, re.DOTALL)
        ]

    # Tier 1: standard special-token format.
    results = _collect(
        r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
    )
    if results or strict:
        return results

    # Tier 2: fallback for known Gemma4 output variations.
    # Matches: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
    return _collect(r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}")
|
||||
|
||||
|
||||
def has_tool_response_tag(text: str) -> bool:
    """Report whether the output terminates with the tool-response tag.

    Some Gemma4 models occasionally emit ``<eos>`` rather than
    ``<|tool_response>`` after producing a tool call. Callers can use
    this predicate to decide whether ``<|tool_response>`` must be
    injected manually into the follow-up prompt.

    Args:
        text: Decoded model output text.

    Returns:
        ``True`` when the output (ignoring trailing whitespace) ends
        with ``<|tool_response>``; ``False`` otherwise.

    Example::

        >>> from vllm.tool_parsers.gemma4_utils import has_tool_response_tag
        >>> if not has_tool_response_tag(model_output):
        ...     # Model used <eos> instead — inject <|tool_response> manually
        ...     next_prompt = "<|tool_response>" + tool_result
    """
    # Trailing whitespace/newlines after the tag are tolerated; only the
    # final non-whitespace token matters.
    return text.rstrip().endswith(_TOOL_RESPONSE_START_TAG)
|
||||
@@ -448,6 +448,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
|
||||
return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
|
||||
|
||||
|
||||
class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
    """Arch-config convertor for Gemma4, which mixes two attention head sizes."""

    def get_head_size(self) -> int:
        """Return the head size attention backends should size buffers for.

        Gemma4 uses dual head dimensions: ``head_dim`` for sliding-window
        attention and ``global_head_dim`` for full attention. The larger of
        the two is reported so backend buffers fit both; if neither is set
        (both absent or zero), defer to the base-class default.
        """
        sliding_dim = getattr(self.hf_text_config, "head_dim", 0)
        global_dim = getattr(self.hf_text_config, "global_head_dim", 0)
        largest = max(sliding_dim, global_dim)
        # `largest` is falsy only when both dims are missing/zero, in which
        # case the generic computation from the base class applies.
        return largest if largest else super().get_head_size()
|
||||
|
||||
|
||||
# hf_config.model_type -> convertor class
|
||||
MODEL_ARCH_CONFIG_CONVERTORS = {
|
||||
"cohere_asr": CohereAsrModelArchConfigConvertor,
|
||||
@@ -471,4 +481,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
|
||||
"ernie_mtp": ErnieMTPModelArchConfigConvertor,
|
||||
"pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
|
||||
"longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
|
||||
"gemma4": Gemma4ModelArchConfigConvertor,
|
||||
"gemma4_text": Gemma4ModelArchConfigConvertor,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user