diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py
index 1404d9628..bf5119cf4 100644
--- a/tests/models/multimodal/generation/test_common.py
+++ b/tests/models/multimodal/generation/test_common.py
@@ -394,6 +394,22 @@ VLM_TEST_SETTINGS = {
         vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}},
         patch_hf_runner=model_utils.gemma3_patch_hf_runner,
     ),
+    "gemma4": VLMTestInfo(
+        models=["google/gemma-4-E2B-it"],
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        prompt_formatter=lambda img_prompt: f"<start_of_turn>user\n{img_prompt}<end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
+        single_image_prompts=IMAGE_ASSETS.prompts(
+            {
+                "stop_sign": "What's the content in the center of the image?",
+                "cherry_blossom": "What is the season?",
+            }
+        ),
+        multi_image_prompt="Describe the two images in detail.",
+        max_model_len=4096,
+        max_num_seqs=2,
+        auto_cls=AutoModelForImageTextToText,
+        vllm_runner_kwargs={"limit_mm_per_prompt": {"image": 4}},
+    ),
     "granite_vision": VLMTestInfo(
         models=["ibm-granite/granite-vision-3.3-2b"],
         test_type=(VLMTestType.IMAGE),
diff --git a/tests/models/multimodal/processing/test_gemma4.py b/tests/models/multimodal/processing/test_gemma4.py
new file mode 100644
index 000000000..808fab6a0
--- /dev/null
+++ b/tests/models/multimodal/processing/test_gemma4.py
@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ....conftest import ImageTestAssets
+from ...utils import build_model_context
+
+# TODO: to be updated to "google/gemma-4-e2b-it" once the models are available
+GEMMA4_MODEL_ID = "google/gemma-4-E2B-it"
+
+
+@pytest.mark.parametrize("model_id", [GEMMA4_MODEL_ID])
+def test_limit_mm_per_prompt(
+    image_assets: ImageTestAssets,
+    model_id: str,
+):
+    """Test that limit_mm_per_prompt accurately restricts multiple images."""
+    # We only allow 1 image
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs={},
+        limit_mm_per_prompt={"image": 1},
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+
+    # Provide 2 images in the prompt
+    prompt = "<start_of_image><start_of_image>"
+    # image_assets provides multiple images
+    images = [asset.pil_image for asset in image_assets][:2]
+    if len(images) < 2:
+        images = [images[0], images[0]]
+
+    mm_data = {"image": images}
+
+    # Expect ValueError when exceeding limit
+    with pytest.raises(ValueError, match="At most 1 image"):
+        processor.apply(
+            prompt,
+            mm_data=mm_data,
+            hf_processor_mm_kwargs={},
+        )
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 3d4ecf8c2..d2b561d60 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -277,6 +277,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
         "google/gemma-2-9b", extras={"tiny": "google/gemma-2-2b-it"}
     ),
     "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
+    "Gemma4ForCausalLM": _HfExamplesInfo(
+        "google/gemma-4-E2B-it",
+        min_transformers_version="5.0.0",
+    ),
     "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it"),
     "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
     "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),
@@ -805,6 +809,10 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     ),
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
+    "Gemma4ForConditionalGeneration": _HfExamplesInfo(
+        "google/gemma-4-E2B-it",
min_transformers_version="5.5.0", + ), "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"), "GlmAsrForConditionalGeneration": _HfExamplesInfo( "zai-org/GLM-ASR-Nano-2512", diff --git a/tests/reasoning/test_gemma4_reasoning_parser.py b/tests/reasoning/test_gemma4_reasoning_parser.py new file mode 100644 index 000000000..cdda7dea5 --- /dev/null +++ b/tests/reasoning/test_gemma4_reasoning_parser.py @@ -0,0 +1,196 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest + +from tests.reasoning.utils import run_reasoning_extraction +from vllm.reasoning import ReasoningParser, ReasoningParserManager + +# Using mistral tokenizer as a generic mock since the actual model is not on HF +from vllm.tokenizers.registry import get_tokenizer + +parser_name = "gemma4" + + +@pytest.fixture(scope="module") +def generic_tokenizer(): + return get_tokenizer("google/gemma-4-E2B-it") + + +INVALID_SIMPLE_NONSTREAMING = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +INVALID_SIMPLE_STREAMING = { + "output": "This is a reasoning sectionThis is the rest", + "reasoning": None, + "content": "This is a reasoning sectionThis is the rest", + "is_reasoning_end": True, +} +INVALID_COMPLETE_NONSTREAMING = { + "output": "This is a reasoning section", + "reasoning": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +INVALID_COMPLETE_STREAMING = { + "output": "This is a reasoning section", + "reasoning": None, + "content": "This is a reasoning section", + "is_reasoning_end": True, +} +NO_CONTENT = { + "output": "<|channel>This is reasoning", + "reasoning": "This is reasoning", + "content": None, + "is_reasoning_end": False, +} +NO_REASONING = { + "output": "This is content", + "reasoning": None, + "content": "This is content", + "is_reasoning_end": False, +} +REASONING_WITH_CHANNEL = { + "output": "<|channel>This is a reasoning sectionThis is the rest", + "reasoning": "This is a reasoning section", + "content": "This is the rest", + "is_reasoning_end": True, +} +COMPLETE_REASONING_WITH_CHANNEL = { + "output": "<|channel>This is a reasoning section", + "reasoning": "This is a reasoning section", + "content": None, + "is_reasoning_end": True, +} +MULTIPLE_LINES_WITH_CHANNEL = { + "output": "<|channel>This\nThatThis is the rest\nThat", + "reasoning": "This\nThat", + "content": "This is the rest\nThat", + "is_reasoning_end": True, +} +CHANNEL_NO_END = { + "output": "<|channel>This is a reasoning section", + "reasoning": "This is a reasoning section", + "content": None, + "is_reasoning_end": False, +} +EMPTY = { + "output": "", + "reasoning": None, + "content": "", + "is_reasoning_end": False, +} +NEW_LINE_NONSTREAMING = { + "output": ( + "Before\n<|channel>This is a reasoning section\nThis is the rest" + ), + "reasoning": "This is a reasoning section", + "content": "\nThis is the rest", + "is_reasoning_end": True, +} +NEW_LINE_STREAMING = { + "output": ( + "Before\n<|channel>This is a reasoning section\nThis is the rest" + ), + "reasoning": "This is a reasoning section", + "content": "Before\n\nThis is the rest", + "is_reasoning_end": True, +} + +TEST_CASES = [ + pytest.param(False, INVALID_SIMPLE_NONSTREAMING, id="invalid_simple"), + pytest.param(True, INVALID_SIMPLE_STREAMING, id="invalid_simple_streaming"), + pytest.param(False, INVALID_COMPLETE_NONSTREAMING, 
id="invalid_complete"), + pytest.param(True, INVALID_COMPLETE_STREAMING, id="invalid_complete_streaming"), + pytest.param(False, NO_CONTENT, id="no_content"), + pytest.param(False, NO_REASONING, id="no_reasoning"), + pytest.param(False, REASONING_WITH_CHANNEL, id="reasoning"), + pytest.param(True, REASONING_WITH_CHANNEL, id="reasoning_streaming"), + pytest.param(False, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning"), + pytest.param( + True, COMPLETE_REASONING_WITH_CHANNEL, id="complete_reasoning_streaming" + ), + pytest.param(False, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines"), + pytest.param(True, MULTIPLE_LINES_WITH_CHANNEL, id="multiple_lines_streaming"), + pytest.param(False, CHANNEL_NO_END, id="no_end"), + pytest.param(True, CHANNEL_NO_END, id="no_end_streaming"), + pytest.param(False, EMPTY, id="empty"), + pytest.param(False, NEW_LINE_NONSTREAMING, id="new_line"), + pytest.param(True, NEW_LINE_STREAMING, id="new_line_streaming"), +] + + +@pytest.mark.parametrize("streaming, param_dict", TEST_CASES) +def test_gemma4_reasoning( + streaming: bool, + param_dict: dict, + generic_tokenizer, +): + output = param_dict["output"] + + # Resolve token IDs dynamically from the real tokenizer + vocab = generic_tokenizer.get_vocab() + start_token_id = vocab["<|channel>"] + end_token_id = vocab[""] + + index_start = output.find("<|channel>") + len_start = len("<|channel>") + index_end = output.find("") + len_end = len("") + + output_tokens = [] + + def _encode(text: str) -> list[int]: + if not text: + return [] + # Handle both raw transformers and vLLM wrappers + enc = getattr(generic_tokenizer, "tokenizer", generic_tokenizer) + try: + return enc.encode(text, add_special_tokens=False) + except TypeError: + return enc.encode(text) + + if index_start != -1: + output_before = output[:index_start] + output_tokens += _encode(output_before) + output_tokens += [start_token_id] + + if index_end != -1: + output_middle = output[index_start + len_start : index_end] + output_after = output[index_end + len_end :] + output_tokens += _encode(output_middle) + output_tokens += [end_token_id] + output_tokens += _encode(output_after) + else: + output_middle = output[index_start + len_start :] + output_tokens += _encode(output_middle) + elif index_end != -1: + output_before = output[:index_end] + output_after = output[index_end + len_end :] + output_tokens += _encode(output_before) + output_tokens += [end_token_id] + output_tokens += _encode(output_after) + else: + output_tokens += _encode(output) + + parser: ReasoningParser = ReasoningParserManager.get_reasoning_parser(parser_name)( + generic_tokenizer + ) + + # We use the generic run_reasoning_extraction from utils + # Use decode per token to get standard spaces instead of + # SentencePiece space characters + output_token_strings = [generic_tokenizer.decode([t]) for t in output_tokens] + reasoning, content = run_reasoning_extraction( + parser, output_token_strings, streaming=streaming + ) + + assert reasoning == param_dict["reasoning"] + assert content == param_dict["content"] + + # Test is_reasoning_end + is_reasoning_end = parser.is_reasoning_end(output_tokens) + assert is_reasoning_end == param_dict["is_reasoning_end"] diff --git a/tests/tool_parsers/test_gemma4_tool_parser.py b/tests/tool_parsers/test_gemma4_tool_parser.py new file mode 100644 index 000000000..80cf70d6c --- /dev/null +++ b/tests/tool_parsers/test_gemma4_tool_parser.py @@ -0,0 +1,504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the 
+
+import json
+from typing import Any
+from unittest.mock import MagicMock
+
+import pytest
+
+from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
+from vllm.tool_parsers.gemma4_tool_parser import (
+    TOOL_CALL_END,
+    TOOL_CALL_START,
+    Gemma4ToolParser,
+    _parse_gemma4_args,
+    _parse_gemma4_array,
+)
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture
+def mock_tokenizer():
+    tokenizer = MagicMock()
+    tokenizer.encode.return_value = [1, 2, 3]
+    # Include the tool call start token in the vocab for the parser
+    tokenizer.get_vocab.return_value = {TOOL_CALL_START: 48, TOOL_CALL_END: 49}
+    return tokenizer
+
+
+@pytest.fixture
+def parser(mock_tokenizer):
+    return Gemma4ToolParser(mock_tokenizer)
+
+
+@pytest.fixture
+def mock_request():
+    request = MagicMock(spec=ChatCompletionRequest)
+    request.tools = []
+    request.tool_choice = "auto"
+    return request
+
+
+# ---------------------------------------------------------------------------
+# Unit tests for _parse_gemma4_args (shared parser logic)
+# ---------------------------------------------------------------------------
+
+
+class TestParseGemma4Args:
+    def test_empty_string(self):
+        assert _parse_gemma4_args("") == {}
+
+    def test_whitespace_only(self):
+        assert _parse_gemma4_args(" ") == {}
+
+    def test_single_string_value(self):
+        result = _parse_gemma4_args('location:<|"|>Paris<|"|>')
+        assert result == {"location": "Paris"}
+
+    def test_string_value_with_comma(self):
+        result = _parse_gemma4_args('location:<|"|>Paris, France<|"|>')
+        assert result == {"location": "Paris, France"}
+
+    def test_multiple_string_values(self):
+        result = _parse_gemma4_args(
+            'location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>'
+        )
+        assert result == {"location": "San Francisco", "unit": "celsius"}
+
+    def test_integer_value(self):
+        result = _parse_gemma4_args("count:42")
+        assert result == {"count": 42}
+
+    def test_float_value(self):
+        result = _parse_gemma4_args("score:3.14")
+        assert result == {"score": 3.14}
+
+    def test_boolean_true(self):
+        result = _parse_gemma4_args("flag:true")
+        assert result == {"flag": True}
+
+    def test_boolean_false(self):
+        result = _parse_gemma4_args("flag:false")
+        assert result == {"flag": False}
+
+    def test_mixed_types(self):
+        result = _parse_gemma4_args(
+            'name:<|"|>test<|"|>,count:42,active:true,score:3.14'
+        )
+        assert result == {
+            "name": "test",
+            "count": 42,
+            "active": True,
+            "score": 3.14,
+        }
+
+    def test_nested_object(self):
+        result = _parse_gemma4_args('nested:{inner:<|"|>value<|"|>}')
+        assert result == {"nested": {"inner": "value"}}
+
+    def test_array_of_strings(self):
+        result = _parse_gemma4_args('items:[<|"|>a<|"|>,<|"|>b<|"|>]')
+        assert result == {"items": ["a", "b"]}
+
+    def test_unterminated_string(self):
+        """Unterminated strings should take everything after the delimiter."""
+        result = _parse_gemma4_args('key:<|"|>unterminated')
+        assert result == {"key": "unterminated"}
+
+    def test_empty_value(self):
+        """Key with no value after colon."""
+        result = _parse_gemma4_args("key:")
+        assert result == {"key": ""}
+
+
+class TestParseGemma4Array:
+    def test_string_array(self):
+        result = _parse_gemma4_array('<|"|>a<|"|>,<|"|>b<|"|>')
+        assert result == ["a", "b"]
+
+    def test_empty_array(self):
+        result = _parse_gemma4_array("")
+        assert result == []
+
+    def test_bare_values(self):
+        result = _parse_gemma4_array("42,true,3.14")
+        assert result == [42, True, 3.14]
+
+
+# ---------------------------------------------------------------------------
+# Non-streaming extraction tests
+# ---------------------------------------------------------------------------
+
+
+class TestExtractToolCalls:
+    def test_no_tool_calls(self, parser, mock_request):
+        model_output = "Hello, how can I help you today?"
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is False
+        assert result.tool_calls == []
+        assert result.content == model_output
+
+    def test_single_tool_call(self, parser, mock_request):
+        model_output = (
+            '<|tool_call>call:get_weather{location:<|"|>London<|"|>}'
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "get_weather"
+        args = json.loads(result.tool_calls[0].function.arguments)
+        assert args == {"location": "London"}
+
+    def test_multiple_arguments(self, parser, mock_request):
+        model_output = (
+            "<|tool_call>call:get_weather{"
+            'location:<|"|>San Francisco<|"|>,'
+            'unit:<|"|>celsius<|"|>}'
+            ""
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "get_weather"
+        args = json.loads(result.tool_calls[0].function.arguments)
+        assert args == {"location": "San Francisco", "unit": "celsius"}
+
+    def test_text_before_tool_call(self, parser, mock_request):
+        model_output = (
+            "Let me check the weather for you. "
+            '<|tool_call>call:get_weather{location:<|"|>Paris<|"|>}'
+            ""
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
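+        # The trailing space before the tool call marker is stripped from
+        # the surfaced content.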
+        assert result.content == "Let me check the weather for you."
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "get_weather"
+
+    def test_multiple_tool_calls(self, parser, mock_request):
+        model_output = (
+            '<|tool_call>call:get_weather{location:<|"|>London<|"|>}'
+            ""
+            '<|tool_call>call:get_time{location:<|"|>London<|"|>}'
+            ""
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 2
+        assert result.tool_calls[0].function.name == "get_weather"
+        assert result.tool_calls[1].function.name == "get_time"
+
+    def test_nested_arguments(self, parser, mock_request):
+        model_output = (
+            "<|tool_call>call:complex_function{"
+            'nested:{inner:<|"|>value<|"|>},'
+            'list:[<|"|>a<|"|>,<|"|>b<|"|>]}'
+            ""
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "complex_function"
+        args = json.loads(result.tool_calls[0].function.arguments)
+        assert args == {"nested": {"inner": "value"}, "list": ["a", "b"]}
+
+    def test_tool_call_with_number_and_boolean(self, parser, mock_request):
+        model_output = (
+            "<|tool_call>call:set_status{"
+            "is_active:true,"
+            "count:42,"
+            "score:3.14}"
+            ""
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "set_status"
+        args = json.loads(result.tool_calls[0].function.arguments)
+        assert args == {"is_active": True, "count": 42, "score": 3.14}
+
+    def test_incomplete_tool_call(self, parser, mock_request):
+        model_output = '<|tool_call>call:get_weather{location:<|"|>London'
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        # Incomplete — no end marker, regex won't match
+        assert result.tools_called is False
+        assert result.content == model_output
+
+    def test_hyphenated_function_name(self, parser, mock_request):
+        """Ensure function names with hyphens are parsed correctly."""
+        model_output = (
+            '<|tool_call>call:get-weather{location:<|"|>London<|"|>}'
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert result.tool_calls[0].function.name == "get-weather"
+
+    def test_dotted_function_name(self, parser, mock_request):
+        """Ensure function names with dots are parsed correctly."""
+        model_output = (
+            '<|tool_call>call:weather.get{location:<|"|>London<|"|>}'
+        )
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert result.tool_calls[0].function.name == "weather.get"
+
+    def test_no_arguments(self, parser, mock_request):
+        """Tool calls with empty arguments."""
+        model_output = "<|tool_call>call:get_status{}"
+        result = parser.extract_tool_calls(model_output, mock_request)
+
+        assert result.tools_called is True
+        assert result.tool_calls[0].function.name == "get_status"
+        args = json.loads(result.tool_calls[0].function.arguments)
+        assert args == {}
+
+
+# ---------------------------------------------------------------------------
+# Streaming extraction tests
+# ---------------------------------------------------------------------------
+
+
+class TestStreamingExtraction:
+    """Tests for the streaming tool call extraction.
+
+    These simulate the token-by-token streaming that vLLM performs,
+    feeding incremental text to extract_tool_calls_streaming() and
+    verifying that the accumulated argument deltas form valid JSON.
+    """
+
+    def _simulate_streaming(
+        self, parser: Gemma4ToolParser, mock_request: Any, chunks: list[str]
+    ) -> list[tuple[Any, str]]:
+        """Feed chunks through the streaming parser and collect results.
+
+        Returns a list of (delta_message, accumulated_text) tuples.
+        """
+        results: list[tuple[Any, str]] = []
+        previous_text: str = ""
+        previous_token_ids: list[int] = []
+
+        for chunk in chunks:
+            current_text = previous_text + chunk
+            # Use token ID 48 for tool_call start, 49 for end, 0 otherwise
+            delta_token_ids: list[int] = []
+            if TOOL_CALL_START in chunk:
+                delta_token_ids.append(48)
+            elif TOOL_CALL_END in chunk:
+                delta_token_ids.append(49)
+            else:
+                delta_token_ids.append(0)
+
+            current_token_ids = previous_token_ids + delta_token_ids
+
+            delta = parser.extract_tool_calls_streaming(
+                previous_text=previous_text,
+                current_text=current_text,
+                delta_text=chunk,
+                previous_token_ids=tuple(previous_token_ids),
+                current_token_ids=tuple(current_token_ids),
+                delta_token_ids=tuple(delta_token_ids),
+                request=mock_request,
+            )
+            results.append((delta, current_text))
+            previous_text = current_text
+            previous_token_ids = list(current_token_ids)
+
+        return results
+
+    def _collect_arguments(self, results):
+        """Collect all argument deltas from streaming results into one string."""
+        args_text = ""
+        for delta, _ in results:
+            if delta and delta.tool_calls:
+                for tc in delta.tool_calls:
+                    func = tc.function
+                    if isinstance(func, dict):
+                        arg = func.get("arguments", "")
+                    else:
+                        arg = getattr(func, "arguments", "") or ""
+                    if arg:
+                        args_text += arg
+        return args_text
+
+    def _collect_function_name(self, results):
+        """Extract the function name from streaming results."""
+        for delta, _ in results:
+            if delta and delta.tool_calls:
+                for tc in delta.tool_calls:
+                    func = tc.function
+                    if isinstance(func, dict):
+                        name = func.get("name")
+                    else:
+                        name = getattr(func, "name", None)
+                    if name:
+                        return name
+        return None
+
+    def test_basic_streaming_single_tool(self, parser, mock_request):
+        """Simulate the exact streaming scenario from the bug report.
+
+        Model generates:
+        <|tool_call>call:get_weather{location:<|"|>Paris, France<|"|>}
+
+        Expected: arguments should be valid JSON {"location": "Paris, France"}
+        """
+        chunks = [
+            "<|tool_call>",
+            "call:get_weather{",
+            'location:<|"|>Paris',
+            ", France",
+            '<|"|>}',
+            "",
+        ]
+
+        results = self._simulate_streaming(parser, mock_request, chunks)
+
+        # Verify function name
+        name = self._collect_function_name(results)
+        assert name == "get_weather", f"Expected 'get_weather', got '{name}'"
+
+        # Verify arguments form valid JSON
+        args_text = self._collect_arguments(results)
+        assert args_text, "No arguments were streamed"
+        parsed_args = json.loads(args_text)
+        assert parsed_args == {"location": "Paris, France"}
+
+    def test_streaming_multi_arg(self, parser, mock_request):
+        """Streaming with multiple arguments."""
+        chunks = [
+            "<|tool_call>",
+            "call:get_weather{",
+            'location:<|"|>Tokyo<|"|>,',
+            'unit:<|"|>celsius<|"|>}',
+            "",
+        ]
+
+        results = self._simulate_streaming(parser, mock_request, chunks)
+
+        name = self._collect_function_name(results)
+        assert name == "get_weather"
+
+        args_text = self._collect_arguments(results)
+        assert args_text
+        parsed_args = json.loads(args_text)
+        assert parsed_args == {"location": "Tokyo", "unit": "celsius"}
+
+    def test_streaming_no_extra_brace(self, parser, mock_request):
+        """Verify the closing } is NOT leaked into arguments (Bug #2)."""
+        chunks = [
+            "<|tool_call>",
+            "call:get_weather{",
+            'location:<|"|>London<|"|>}',
+            "",
+        ]
+
+        results = self._simulate_streaming(parser, mock_request, chunks)
+        args_text = self._collect_arguments(results)
+        assert args_text
+
+        # The args text must be valid JSON (no extra })
+        parsed = json.loads(args_text)
+        assert parsed == {"location": "London"}
+
+        # Specifically assert no double-brace
+        assert args_text.count("}") <= 1, (
+            f"Arguments contain extra closing brace: {args_text!r}"
+        )
+
+    def test_streaming_no_unquoted_keys(self, parser, mock_request):
+        """Verify keys are properly quoted in JSON (Bug #1)."""
+        chunks = [
+            "<|tool_call>",
+            "call:get_weather{",
+            'location:<|"|>Paris<|"|>}',
+            "",
+        ]
+
+        results = self._simulate_streaming(parser, mock_request, chunks)
+        args_text = self._collect_arguments(results)
+
+        # Must start with { and contain quoted key
+        assert args_text.lstrip().startswith("{"), (
+            f"Arguments don't start with '{{': {args_text!r}"
+        )
+        assert '"location"' in args_text, (
+            f"Key 'location' not properly quoted: {args_text!r}"
+        )
+
+    def test_streaming_name_no_call_prefix(self, parser, mock_request):
+        """Verify function name has no 'call:' prefix."""
+        chunks = [
+            "<|tool_call>",
+            "call:get_weather{",
+            'location:<|"|>Paris<|"|>}',
+            "",
+        ]
+
+        results = self._simulate_streaming(parser, mock_request, chunks)
+        name = self._collect_function_name(results)
+        assert name == "get_weather"
+        assert not name.startswith("call:"), f"Name has 'call:' prefix: {name!r}"
+
+    def test_streaming_text_before_tool_call(self, parser, mock_request):
+        """Text before tool call should be emitted as content."""
+        chunks = [
+            "Let me check ",
+            "the weather. ",
", + "<|tool_call>", + "call:get_weather{", + 'location:<|"|>London<|"|>}', + "", + ] + + results = self._simulate_streaming(parser, mock_request, chunks) + + # First chunks should be content + content_parts = [] + for delta, _ in results: + if delta and delta.content: + content_parts.append(delta.content) + + assert "".join(content_parts).strip().startswith("Let me check") + + def test_streaming_numeric_args(self, parser, mock_request): + """Streaming with numeric and boolean argument values.""" + chunks = [ + "<|tool_call>", + "call:set_config{", + "count:42,", + "active:true}", + "", + ] + + results = self._simulate_streaming(parser, mock_request, chunks) + args_text = self._collect_arguments(results) + if args_text: + parsed_args = json.loads(args_text) + assert parsed_args["count"] == 42 + assert parsed_args["active"] is True + + def test_streaming_empty_args(self, parser, mock_request): + """Tool call with no arguments.""" + chunks = [ + "<|tool_call>", + "call:get_status{}", + "", + ] + + results = self._simulate_streaming(parser, mock_request, chunks) + name = self._collect_function_name(results) + assert name == "get_status" diff --git a/vllm/model_executor/layers/rotary_embedding/__init__.py b/vllm/model_executor/layers/rotary_embedding/__init__.py index 9ad7c9cda..28157daab 100644 --- a/vllm/model_executor/layers/rotary_embedding/__init__.py +++ b/vllm/model_executor/layers/rotary_embedding/__init__.py @@ -12,6 +12,7 @@ from .dual_chunk_rope import DualChunkRotaryEmbedding from .dynamic_ntk_alpha_rope import DynamicNTKAlphaRotaryEmbedding from .dynamic_ntk_scaling_rope import DynamicNTKScalingRotaryEmbedding from .fope import FourierRotaryEmbedding +from .gemma4_rope import Gemma4RotaryEmbedding from .linear_scaling_rope import LinearScalingRotaryEmbedding from .llama3_rope import Llama3RotaryEmbedding from .llama4_vision_rope import Llama4VisionRotaryEmbedding @@ -134,6 +135,17 @@ def get_rope( is_neox_style, dtype, ) + elif scaling_type == "proportional": + # Proportional RoPE is used by Gemma4 for global (full) attention. + # Gemma4 uses a sparse/fractional RoPE with cross-mixing between halves. + rotary_emb = Gemma4RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) elif scaling_type == "llama3": scaling_factor = rope_parameters["factor"] low_freq_factor = rope_parameters["low_freq_factor"] diff --git a/vllm/model_executor/layers/rotary_embedding/gemma4_rope.py b/vllm/model_executor/layers/rotary_embedding/gemma4_rope.py new file mode 100644 index 000000000..48253f469 --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding/gemma4_rope.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Gemma4-specific Rotary Positional Embeddings (proportional scaling). + +Gemma4 uses "proportional" RoPE which computes inv_freq frequencies scaled +by head_dim (not rotary_dim), and zero-pads for non-rotated dimensions when +partial_rotary_factor < 1. The actual rotation uses standard neox-style +rotate_half, matching HF transformers' apply_rotary_pos_emb. +""" + +import torch + +from .base import RotaryEmbedding + + +class Gemma4RotaryEmbedding(RotaryEmbedding): + """Gemma4 proportional RoPE. 
+
+    Extends RotaryEmbedding (which provides standard neox-style rotation
+    via ops.rotary_embedding CUDA kernel) but overrides the inv_freq
+    computation to match HF's _compute_proportional_rope_parameters:
+    - Frequency exponents use head_dim (not rotary_dim) as denominator
+    - Non-rotated dims are zero-padded (cos=1, sin=0 = identity rotation)
+
+    When partial_rotary_factor=1.0 (the default for some variants), ALL dims
+    are rotated and this is equivalent to standard RotaryEmbedding with
+    head_dim-scaled frequencies.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: float,
+        is_neox_style: bool,
+        dtype: torch.dtype,
+    ) -> None:
+        # Number of rotation angle pairs (from partial_rotary_factor)
+        self.rope_angles = rotary_dim // 2
+        # Non-rotated angle pairs per half
+        self.nope_angles = (head_size // 2) - self.rope_angles
+
+        # Important: set rotary_dim = head_size so the base class's
+        # forward_static applies rotation to ALL dims of the cos/sin cache.
+        # The non-rotated dims will have cos=1, sin=0 (identity) thanks
+        # to our _compute_inv_freq zero-padding.
+        super().__init__(
+            head_size,
+            head_size,  # rotary_dim = head_size (full application)
+            max_position_embeddings,
+            base,
+            is_neox_style,
+            dtype,
+        )
+
+    def _compute_inv_freq(self, base: float) -> torch.Tensor:
+        """Compute frequencies matching HF proportional RoPE.
+
+        Key difference from base: exponent denominator is head_size (not
+        rotary_dim), and non-rotated dims are zero-padded.
+        """
+        # HF formula: base ** (arange(0, 2*rope_angles, 2) / head_dim)
+        freq_exponents = (
+            torch.arange(0, 2 * self.rope_angles, 2, dtype=torch.float)
+            / self.head_size
+        )
+        inv_freq = 1.0 / (base**freq_exponents)
+
+        # Zero-pad for non-rotated dims (identity rotation: cos=1, sin=0)
+        if self.nope_angles > 0:
+            inv_freq = torch.cat(
+                [
+                    inv_freq,
+                    torch.zeros(self.nope_angles, dtype=torch.float),
+                ]
+            )
+        return inv_freq
+
+    def extra_repr(self) -> str:
+        s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
+        s += f", rope_angles={self.rope_angles}, nope_angles={self.nope_angles}"
+        s += f", max_position_embeddings={self.max_position_embeddings}"
+        s += f", base={self.base}, is_neox_style={self.is_neox_style}"
+        return s
diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py
index a5644a414..c13ea5393 100644
--- a/vllm/model_executor/models/config.py
+++ b/vllm/model_executor/models/config.py
@@ -14,6 +14,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentio
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig, VllmConfig
 
+
 logger = init_logger(__name__)
 
@@ -57,6 +58,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
         hf_config.is_causal = not hf_config.use_bidirectional_attention
 
 
+class Gemma4Config(VerifyAndUpdateConfig):
+    @staticmethod
+    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
+        """Force unified attention backend for models with heterogeneous
+        head dimensions.
+
+        Some Gemma4 variants use different head dimensions for
+        sliding window (head_dim) vs full attention (global_head_dim) layers.
+        When global_head_dim > 256, FlashAttention rejects those layers
+        (head_size <= 256 kernel limit), causing vLLM to select a different
+        backend for each layer type. This mixed-backend execution produces
+        numerical divergence and output corruption.
+
+        The fix detects heterogeneous head dimensions from the model config
+        and forces TRITON_ATTN (which has no head_size ceiling) for all
+        layers when the user hasn't explicitly chosen a backend.
+
+        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
+        require NixlConnector changes to support per-layer KV transfer
+        with different head dimensions for prefill-decode disaggregation.
+        """
+        hf_text_config = vllm_config.model_config.hf_text_config
+        head_dim = getattr(hf_text_config, "head_dim", None)
+        global_head_dim = getattr(hf_text_config, "global_head_dim", None)
+
+        # Only force Triton when head dimensions actually differ AND the
+        # larger one exceeds FlashAttention's kernel limit (head_size <= 256).
+        # This avoids unnecessary backend forcing on smaller models where
+        # the config carries global_head_dim but all layers can still use
+        # the same FA backend.
+        max_head_dim = max(head_dim or 0, global_head_dim or 0)
+        if (
+            head_dim is not None
+            and global_head_dim is not None
+            and head_dim != global_head_dim
+            and max_head_dim > 256
+            and vllm_config.attention_config.backend is None
+        ):
+            from vllm.v1.attention.backends.registry import (
+                AttentionBackendEnum,
+            )
+
+            vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
+            logger.info(
+                "Gemma4 model has heterogeneous head dimensions "
+                "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
+                "backend to prevent mixed-backend numerical divergence.",
+                head_dim,
+                global_head_dim,
+            )
+
+
 class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
     @staticmethod
     def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -668,6 +721,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
     "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig,  # noqa: E501
     "FalconMambaForCausalLM": MambaModelConfig,
     "Gemma3TextModel": Gemma3TextModelConfig,
+    "Gemma4ForCausalLM": Gemma4Config,
+    "Gemma4ForConditionalGeneration": Gemma4Config,
     "GptOssForCausalLM": GptOssForCausalLMConfig,
     "GteModel": SnowflakeGteNewModelConfig,
     "GteNewForSequenceClassification": GteNewModelConfig,
diff --git a/vllm/model_executor/models/gemma4.py b/vllm/model_executor/models/gemma4.py
new file mode 100644
index 000000000..edb533134
--- /dev/null
+++ b/vllm/model_executor/models/gemma4.py
@@ -0,0 +1,1239 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright 2025 The vLLM team.
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Gemma 4 model implementation for vLLM."""
+
+from collections.abc import Iterable
+from itertools import islice
+
+import regex as re
+import torch
+from torch import nn
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import CacheConfig, VllmConfig
+from vllm.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.activation import GeluAndMul
+from vllm.model_executor.layers.attention import Attention
+from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from vllm.sequence import IntermediateTensors
+
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .utils import (
+    AutoWeightsLoader,
+    extract_layer_index,
+    is_pp_missing_parameter,
+    make_layers,
+    maybe_prefix,
+)
+
+logger = init_logger(__name__)
+
+
+def _get_text_config(config):
+    """Dereference text_config if config is a nested Gemma4Config.
+
+    Gemma4 checkpoints use architectures=["Gemma4ForConditionalGeneration"]
+    which yields a Gemma4Config with nested text_config. This function
+    transparently returns the text config regardless of nesting.
+    """
+    if hasattr(config, "text_config"):
+        return config.text_config
+    return config
+
+
+class Gemma4MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_activation: str,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        if hidden_activation != "gelu_pytorch_tanh":
+            raise ValueError(
+                "Gemma4 uses `gelu_pytorch_tanh` as the hidden activation "
+                "function. Please set `hidden_act` and `hidden_activation` to "
+                "`gelu_pytorch_tanh`."
+            )
+        self.act_fn = GeluAndMul(approximate="tanh")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class Gemma4Router(nn.Module):
+    """Router for Gemma4 MoE that preprocesses input before projection.
+
+    Applies RMSNorm (no learned weight), root_size scaling
+    (hidden_size^{-0.5}), then a learned per-dimension scale before
+    projecting to expert logits.
+
+    This preprocessing is applied ONLY to the router's input, not to
+    the expert MLPs' input.
+ """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + + # RMSNorm without learned weight — pure normalization only + self.norm = RMSNorm(self.hidden_size, eps=config.rms_norm_eps, has_weight=False) + # Per-dimension learned scale, applied after norm + root_size + self.scale = nn.Parameter(torch.ones(self.hidden_size)) + # Constant 1/sqrt(hidden_size) scaling factor + self.register_buffer( + "root_size", + torch.tensor(self.hidden_size**-0.5), + persistent=False, + ) + # Project to expert logits; replicated across TP for consistent routing + # GateLinear supports bf16 W/A → fp32 output, which is important + # because the topk kernel often needs fp32 for stable routing. + self.proj = GateLinear( + self.hidden_size, + config.num_experts, + bias=False, + out_dtype=torch.float32, + prefix=f"{prefix}.proj", + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Returns raw router logits [T, E].""" + x = self.norm(x) + x = x * self.root_size.to(x.dtype) + x = x * self.scale.to(x.dtype) + router_logits, _ = self.proj(x) + return router_logits + + +class Gemma4MoE(nn.Module): + """Mixture of Experts for Gemma4 using vLLM's FusedMoE. + + Wraps FusedMoE with custom routing. The router projection is + external (Gemma4Router) — this class only handles expert dispatch. + + Gemma4 routing: softmax over ALL experts → top-k → renormalize. + per_expert_scale is folded into routing weights for mathematical + correctness with FusedMoE's fused kernel. + """ + + def __init__( + self, + config, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.num_experts = config.num_experts + + # Per-expert output scale folded into routing weights so that + # FusedMoE's fused kernel computes: Σ_e (expert_e * w_e * scale_e) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_experts)) + + # Gemma4 routing: softmax over ALL experts → top-k → renormalize. + # FusedMoE's built-in fused_topk scopes softmax differently, so + # a custom routing function is needed for numerical correctness. 
+        per_expert_scale = self.per_expert_scale
+
+        def routing_function(
+            hidden_states: torch.Tensor,
+            gating_output: torch.Tensor,
+            topk: int,
+            renormalize: bool,
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
+            router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)
+            indicator = torch.nn.functional.one_hot(
+                topk_ids, num_classes=gating_output.size(-1)
+            ).sum(dim=-2)
+            gate_weights = indicator * router_probabilities
+            renorm_factor = torch.sum(gate_weights, dim=-1, keepdim=True)
+            renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
+            dispatch_weights = gate_weights / renorm_factor
+
+            topk_weights = dispatch_weights.gather(1, topk_ids)
+
+            # Fold per_expert_scale into routing weights
+            expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
+            topk_weights = topk_weights * expert_scales
+            return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+        # FusedMoE experts with custom Gemma4 routing
+        self.experts = FusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.top_k_experts,
+            hidden_size=config.hidden_size,
+            intermediate_size=getattr(
+                config,
+                "moe_intermediate_size",
+                getattr(config, "expert_intermediate_size", None),
+            ),
+            reduce_results=True,
+            renormalize=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+            custom_routing_function=routing_function,
+            activation="gelu",
+        )
+
+    def forward(self, x: torch.Tensor, router_logits: torch.Tensor) -> torch.Tensor:
+        return self.experts(x, router_logits)
+
+
+class Gemma4Attention(nn.Module):
+    def __init__(
+        self,
+        config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_position_embeddings: int,
+        use_k_eq_v: bool = False,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        attn_logits_soft_cap: float | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = hidden_size
+        self.use_k_eq_v = use_k_eq_v
+
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = head_dim
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        # Gemma4 uses scaling=1.0.
+        # Unlike Gemma2/3, query_pre_attn_scalar is NOT used here;
+        # Q/K norms with learnable weights handle scaling implicitly.
+        self.scaling = 1.0
+
+        # QKVParallelLinear handles GQA correctly for all layer types.
+        # k_eq_v layers load K weights into both K and V slots via
+        # _weight_iterator remapping — no structural difference needed.
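+        # The fused QKV output layout along the last dim is
+        # [q_size | kv_size | kv_size], i.e. num_heads * head_dim followed by
+        # two blocks of num_kv_heads * head_dim, matching the split performed
+        # in forward().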
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=config.attention_bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        # Q/K norms: output = norm(x) * weight (learnable per-head scale)
+        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        # V norm: no learnable scale (pure normalization only)
+        self.v_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps, has_weight=False)
+
+        # Determine layer type and sliding window
+        layer_idx = extract_layer_index(prefix)
+        layer_type = config.layer_types[layer_idx]
+        self.is_sliding = layer_type == "sliding_attention"
+        sliding_window = config.sliding_window if self.is_sliding else None
+
+        # Initialize RoPE based on layer type.
+        # Gemma4 uses different RoPE parameters for sliding vs full attention.
+        if layer_type in config.rope_parameters:
+            # Per-layer-type rope config (dict format).
+            # rope_parameters already contains the correct
+            # partial_rotary_factor for each layer type, so do NOT
+            # override it with global_partial_rotary_factor; Gemma4
+            # configs express this via per-layer rope_parameters.
+            rope_parameters = dict(config.rope_parameters[layer_type])
+        else:
+            # Legacy config format fallback.
+            rope_parameters = dict(config.rope_parameters.copy())
+            if self.is_sliding:
+                rope_parameters["rope_theta"] = getattr(
+                    config, "rope_local_base_freq", 10000.0
+                )
+
+        # KV sharing: layers in the last `num_kv_shared_layers` share KV
+        # cache with earlier layers of the same type.
+        kv_sharing_target_layer_name = None
+        self.is_kv_shared_layer = False
+        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0)
+        if num_kv_shared_layers > 0:
+            first_kv_shared_layer_idx = config.num_hidden_layers - num_kv_shared_layers
+            if layer_idx >= first_kv_shared_layer_idx:
+                self.is_kv_shared_layer = True
+                # Find the last non-shared layer of the same attention type
+                prev_layers = config.layer_types[:first_kv_shared_layer_idx]
+                current_layer_type = config.layer_types[layer_idx]
+                kv_shared_layer_index = (
+                    len(prev_layers) - 1 - prev_layers[::-1].index(current_layer_type)
+                )
+                if kv_shared_layer_index >= 0:
+                    if ".layers." in prefix:
+                        param_name_before_layers = prefix.split(".layers.")[0]
+                    else:
+                        raise ValueError(
+                            "Unexpected prefix format for Gemma4Attention: "
+                            f"'{prefix}'. Expected to contain '.layers.'."
+                        )
+                    kv_sharing_target_layer_name = (
+                        f"{param_name_before_layers}.layers."
+                        f"{kv_shared_layer_index}.self_attn.attn"
+                    )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            max_position=max_position_embeddings,
+            rope_parameters=rope_parameters,
+            is_neox_style=True,
+        )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            logits_soft_cap=attn_logits_soft_cap,
+            per_layer_sliding_window=sliding_window,
+            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        # Unified QKV path (works for both k_eq_v and standard layers).
+        # For k_eq_v, K weights are loaded into both K and V slots of
+        # qkv_proj, so V == K automatically.
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        # Q norm (always applied)
+        q = q.unflatten(-1, (self.num_heads, self.head_dim))
+        q = self.q_norm(q)
+        q = q.flatten(-2, -1)
+
+        if not self.is_kv_shared_layer:
+            # Non-shared: apply K norm + RoPE, V norm
+            k = k.unflatten(-1, (self.num_kv_heads, self.head_dim))
+            k = self.k_norm(k)
+            k = k.flatten(-2, -1)
+            q, k = self.rotary_emb(positions, q, k)
+
+            v = v.unflatten(-1, (self.num_kv_heads, self.head_dim))
+            v = self.v_norm(v)
+            v = v.flatten(-2, -1)
+        else:
+            # Shared: only apply RoPE to Q
+            q = self.rotary_emb(positions, q, k)[0]
+
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+
+        return output
+
+
+class Gemma4DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.hidden_size_per_layer_input = getattr(
+            config, "hidden_size_per_layer_input", 0
+        )
+
+        layer_idx = extract_layer_index(prefix)
+        self.layer_idx = layer_idx
+
+        # Gemma4 uses different head dimensions for sliding vs full attention
+        layer_type = config.layer_types[layer_idx]
+        self.is_full_attention = layer_type == "full_attention"
+        if self.is_full_attention:
+            head_dim = getattr(config, "global_head_dim", config.head_dim)
+        else:
+            head_dim = config.head_dim
+
+        # Determine if this full-attention layer uses k_eq_v
+        # (laptop variant: no v_proj, K reused as V on full attention layers)
+        use_k_eq_v = self.is_full_attention and getattr(
+            config, "attention_k_eq_v", False
+        )
+
+        # For k_eq_v full-attention layers, use num_global_key_value_heads
+        # as the KV head count when k_eq_v is enabled.
+        if use_k_eq_v:
+            num_kv_heads = getattr(
+                config, "num_global_key_value_heads", config.num_key_value_heads
+            )
+        else:
+            num_kv_heads = config.num_key_value_heads
+
+        self.self_attn = Gemma4Attention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            max_position_embeddings=config.max_position_embeddings,
+            use_k_eq_v=use_k_eq_v,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            attn_logits_soft_cap=getattr(config, "attn_logit_softcapping", None),
+            prefix=f"{prefix}.self_attn",
+        )
+
+        # Compute per-layer intermediate_size from config.
+        # When use_double_wide_mlp is set, intermediate_size doubles for
+        # KV-shared layers (layers >= first_kv_shared_layer_idx).
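+        # Worked example (hypothetical sizes): with num_hidden_layers=30 and
+        # num_kv_shared_layers=10, layers 20..29 are KV-shared, and with
+        # use_double_wide_mlp=True their MLP width is 2 * intermediate_size.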
+        first_kv_shared_layer_idx = config.num_hidden_layers - getattr(
+            config, "num_kv_shared_layers", 0
+        )
+        is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
+        use_double_wide_mlp = (
+            getattr(config, "use_double_wide_mlp", False) and is_kv_shared_layer
+        )
+        layer_intermediate_size = config.intermediate_size * (
+            2 if use_double_wide_mlp else 1
+        )
+
+        self.mlp = Gemma4MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=layer_intermediate_size,
+            hidden_activation=config.hidden_activation,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+
+        # Layer norms: output = norm(x) * weight
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.pre_feedforward_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        # MoE (Mixture of Experts) — router + expert block parallel to MLP
+        self.enable_moe_block = getattr(config, "enable_moe_block", False) or getattr(
+            config, "use_second_mlp_block", False
+        )
+        if self.enable_moe_block:
+            self.router = Gemma4Router(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.router",
+            )
+            self.moe = Gemma4MoE(
+                config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.moe",
+            )
+            self.post_feedforward_layernorm_1 = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+            self.post_feedforward_layernorm_2 = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+            self.pre_feedforward_layernorm_2 = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+        else:
+            self.router = None
+            self.moe = None
+            self.post_feedforward_layernorm_1 = None
+            self.post_feedforward_layernorm_2 = None
+            self.pre_feedforward_layernorm_2 = None
+
+        # Per-Layer Embedding (PLE) components — present in each decoder layer
+        if (
+            self.hidden_size_per_layer_input is not None
+            and self.hidden_size_per_layer_input > 0
+        ):
+            # Gate: projects hidden_states → per-layer dim for gating
+            self.per_layer_input_gate = ReplicatedLinear(
+                self.hidden_size,
+                self.hidden_size_per_layer_input,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.per_layer_input_gate",
+                return_bias=False,
+            )
+            # Projection: projects gated per-layer input back → hidden size
+            self.per_layer_projection = ReplicatedLinear(
+                self.hidden_size_per_layer_input,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.per_layer_projection",
+                return_bias=False,
+            )
+            # Post-PLE norm: output = norm(x) * weight
+            self.post_per_layer_input_norm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+        else:
+            self.per_layer_input_gate = None
+            self.per_layer_projection = None
+            self.post_per_layer_input_norm = None
+
+        # Layer scalar (loaded from checkpoint) — applies to ALL text layers
+        self.register_buffer("layer_scalar", torch.ones(1))
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+        per_layer_input: torch.Tensor | None = None,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Gemma4 residual pattern:
+        #   1. input_norm(x) → attn → post_attn_norm → ADD residual
+        #   2. pre_ff_norm → mlp → post_ff_norm → ADD residual
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(residual)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            **kwargs,
+        )
+
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = hidden_states + residual
+        residual = hidden_states
+
+        # MLP runs unconditionally (same inputs for MoE and non-MoE)
+        hidden_states = self.pre_feedforward_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+
+        if self.enable_moe_block:
+            hidden_states_1 = self.post_feedforward_layernorm_1(hidden_states)
+
+            # Router and MoE experts see the residual (pre-MLP state),
+            # matching the HF transformers forward path
+            router_logits = self.router(residual)
+            hidden_states_2 = self.pre_feedforward_layernorm_2(residual)
+            hidden_states_2 = self.moe(hidden_states_2, router_logits)
+            hidden_states_2 = self.post_feedforward_layernorm_2(hidden_states_2)
+
+            # Combine MLP and MoE outputs
+            hidden_states = hidden_states_1 + hidden_states_2
+
+        hidden_states = self.post_feedforward_layernorm(hidden_states)
+        hidden_states = hidden_states + residual
+
+        # Apply PLE (Per-Layer Embedding) if configured
+        if per_layer_input is not None and self.per_layer_input_gate is not None:
+            gate = self.per_layer_input_gate(hidden_states)
+            gate = torch.nn.functional.gelu(gate, approximate="tanh")
+            gated_per_layer = gate * per_layer_input
+            per_layer_contribution = self.per_layer_projection(gated_per_layer)
+            per_layer_contribution = self.post_per_layer_input_norm(
+                per_layer_contribution
+            )
+            hidden_states = hidden_states + per_layer_contribution
+
+        # Apply the per-layer scalar (all text layers)
+        hidden_states = hidden_states * self.layer_scalar
+
+        return hidden_states, None
+
+
+@support_torch_compile
+class Gemma4Model(nn.Module):
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = _get_text_config(vllm_config.model_config.hf_config)
+        cache_config = vllm_config.cache_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+
+        # PLE config values (default to 0 if not present — disables PLE)
+        self.hidden_size_per_layer_input = getattr(
+            config, "hidden_size_per_layer_input", 0
+        )
+        self.vocab_size_per_layer_input = getattr(
+            config, "vocab_size_per_layer_input", config.vocab_size
+        )
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
+        )
+
+        # Per-Layer Embedding (PLE) components
+        if (
+            self.hidden_size_per_layer_input is not None
+            and self.hidden_size_per_layer_input > 0
+        ):
+            total_ple_dim = self.hidden_size_per_layer_input * config.num_hidden_layers
+            self.embed_tokens_per_layer = VocabParallelEmbedding(
+                self.vocab_size_per_layer_input,
+                total_ple_dim,
+                quant_config=quant_config,
+                prefix=f"{prefix}.embed_tokens_per_layer",
+            )
+            # Scaled embedding factor (from config, not hardcoded)
+            # Register as buffer so it moves to GPU with the model
+            # and interacts correctly with torch.compile AOT caching.
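+            # The scale value is sqrt(hidden_size_per_layer_input), mirroring
+            # the sqrt(hidden_size) normalizer used on the main embeddings.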
+            self.register_buffer(
+                "embed_scale_per_layer",
+                torch.tensor(self.hidden_size_per_layer_input**0.5),
+                persistent=False,
+            )
+            # Projection: hidden_size → total_ple_dim
+            # ColumnParallelLinear with gather_output=True
+            self.per_layer_model_projection = ColumnParallelLinear(
+                config.hidden_size,
+                total_ple_dim,
+                bias=False,
+                gather_output=True,
+                return_bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.per_layer_model_projection",
+            )
+            # PLE projection norm: output = norm(x) * weight
+            self.per_layer_projection_norm = RMSNorm(
+                self.hidden_size_per_layer_input,
+                eps=config.rms_norm_eps,
+            )
+            # Scale factor for combining projection + per_layer_inputs
+            # Register as buffer so it moves to GPU with the model
+            # and interacts correctly with torch.compile AOT caching.
+            self.register_buffer(
+                "per_layer_input_scale",
+                torch.rsqrt(torch.tensor(2.0)),
+                persistent=False,
+            )
+            # Scaled projection: multiply output by hidden_size**-0.5.
+            # Register as buffer for GPU placement and torch.compile.
+            self.register_buffer(
+                "per_layer_projection_scale",
+                torch.tensor(config.hidden_size**-0.5),
+                persistent=False,
+            )
+        else:
+            self.embed_tokens_per_layer = None
+            self.embed_scale_per_layer = None
+            self.per_layer_model_projection = None
+            self.per_layer_projection_norm = None
+            self.per_layer_input_scale = None
+            self.per_layer_projection_scale = None
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: Gemma4DecoderLayer(
+                config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=prefix,
+            ),
+            prefix=f"{prefix}.layers",
+        )
+        # Final norm: output = norm(x) * weight
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        # Embedding scale = sqrt(hidden_size)
+        # Downcast to model dtype (bfloat16 etc.) for numerical parity
+        self.register_buffer(
+            "normalizer",
+            torch.tensor(config.hidden_size**0.5),
+            persistent=False,
+        )
+        # Custom factory that includes per_layer_inputs for PLE-enabled PP.
+        # per_layer_inputs has shape (batch, num_layers, per_layer_dim),
+        # which differs from the standard (batch, hidden_size) shape,
+        # so we can't use the default factory.
+        ple_dim = self.hidden_size_per_layer_input
+        num_layers = config.num_hidden_layers
+        hidden_size = config.hidden_size
+
+        def _make_empty_intermediate_tensors(
+            batch_size: int,
+            dtype: torch.dtype,
+            device: torch.device,
+        ) -> IntermediateTensors:
+            tensors: dict[str, torch.Tensor] = {
+                "hidden_states": torch.zeros(
+                    (batch_size, hidden_size),
+                    dtype=dtype,
+                    device=device,
+                ),
+                "residual": torch.zeros(
+                    (batch_size, hidden_size),
+                    dtype=dtype,
+                    device=device,
+                ),
+            }
+            if ple_dim and ple_dim > 0:
+                tensors["per_layer_inputs"] = torch.zeros(
+                    (batch_size, num_layers, ple_dim),
+                    dtype=dtype,
+                    device=device,
+                )
+            return IntermediateTensors(tensors)
+
+        self.make_empty_intermediate_tensors = _make_empty_intermediate_tensors
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids) * self.normalizer
+
+    def get_per_layer_inputs(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """Get per-layer embeddings from embed_tokens_per_layer.
+
+        Returns:
+            Per-layer embeddings (num_tokens, num_layers,
+            hidden_size_per_layer_input)
+        """
+        if self.embed_tokens_per_layer is None:
+            return None
+
+        # Handle out-of-vocab tokens for PLE (vocab_size_per_layer_input may
+        # be smaller than the main vocab_size).
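+        # Ids outside [0, vocab_size_per_layer_input) are remapped to 0 below,
+        # so such tokens index row 0 instead of reading out of bounds.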
+ per_layer_inputs_mask = torch.logical_and( + input_ids >= 0, + input_ids < self.vocab_size_per_layer_input, + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids) + ) + + # Get packed per-layer embeddings: (num_tokens, total_ple_dim) + per_layer_embeds = self.embed_tokens_per_layer(per_layer_inputs_tokens) + + # Apply embed_scale (sqrt of per-layer hidden dim) + per_layer_embeds = per_layer_embeds * self.embed_scale_per_layer + + # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) + per_layer_embeds = per_layer_embeds.reshape( + *input_ids.shape, + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + return per_layer_embeds + + def project_per_layer_inputs( + self, + inputs_embeds: torch.Tensor, + per_layer_inputs: torch.Tensor | None, + ) -> torch.Tensor: + """Project inputs_embeds and combine with per_layer_inputs. + + Steps: + 1. Project inputs_embeds: hidden_size → total_ple_dim + 2. Scale by hidden_size^{-0.5} + 3. Reshape to (num_tokens, num_layers, per_layer_dim) + 4. Normalize with per_layer_projection_norm + 5. Combine: (projection + per_layer_inputs) * 1/sqrt(2) + """ + if self.per_layer_model_projection is None: + return None + + # Project from hidden_size to total_ple_dim + # Scaled projection: output = linear(input, weight) * scale + per_layer_projection = self.per_layer_model_projection(inputs_embeds) + per_layer_projection = per_layer_projection * self.per_layer_projection_scale + + # Reshape to (num_tokens, num_layers, hidden_size_per_layer_input) + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + + # Normalize + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + + if per_layer_inputs is None: + return per_layer_projection + + # Combine: (projection + per_layer_inputs) * scale + return (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None, + inputs_embeds: torch.Tensor | None = None, + per_layer_inputs: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + # When called from the multimodal wrapper, raw PLE + # embeddings are pre-computed and passed explicitly. + # Project them through per_layer_model_projection. 
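Reviewer note: a standalone replay of the five steps listed in `project_per_layer_inputs` above, with toy sizes and a weightless RMS normalization standing in for the patch's `RMSNorm`.

```python
import torch

num_tokens, hidden, num_layers, d_ple = 3, 32, 4, 8   # toy sizes, not the real config
proj = torch.nn.Linear(hidden, num_layers * d_ple, bias=False)

x = torch.randn(num_tokens, hidden)
p = proj(x) * hidden ** -0.5                          # steps 1-2: project, then scale
p = p.reshape(num_tokens, num_layers, d_ple)          # step 3: reshape per layer
p = p * torch.rsqrt(p.pow(2).mean(-1, keepdim=True) + 1e-6)  # step 4: RMS-normalize
ple = torch.randn(num_tokens, num_layers, d_ple)      # per-layer embeddings to merge
out = (p + ple) * torch.rsqrt(torch.tensor(2.0))      # step 5: combine at 1/sqrt(2)
print(out.shape)  # torch.Size([3, 4, 8])
```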
+ per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_inputs + ) + else: + hidden_states = self.embed_input_ids(input_ids) + # Compute per-layer inputs for PLE + per_layer_embeds = self.get_per_layer_inputs(input_ids) + per_layer_inputs = self.project_per_layer_inputs( + hidden_states, per_layer_embeds + ) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + per_layer_inputs = intermediate_tensors.get("per_layer_inputs") + + for layer_idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): + # Extract the per-layer embedding for this specific layer + if per_layer_inputs is not None: + actual_layer_idx = self.start_layer + layer_idx + layer_per_input = per_layer_inputs[ + :, actual_layer_idx, : + ] # (num_tokens, per_layer_dim) + else: + layer_per_input = None + hidden_states, residual = layer( + positions, + hidden_states, + residual, + per_layer_input=layer_per_input, + **kwargs, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors( + { + "hidden_states": hidden_states, + "residual": residual, + "per_layer_inputs": per_layer_inputs, + } + ) + # Gemma4 incorporates residual into hidden_states directly + # Apply norm without residual fusion when possible. + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # MoE expert weight mapping: checkpoint 3D packed tensors are + # exploded in _weight_iterator to per-expert 2D weights like: + # moe.experts.{id}.gate_proj → FusedMoE w1 (shard of w13) + # moe.experts.{id}.up_proj → FusedMoE w3 (shard of w13) + # moe.experts.{id}.down_proj → FusedMoE w2 + # We build the mapping directly since Gemma4 uses bare param + # names (no .weight suffix) unlike standard MoE checkpoints. + num_experts = getattr(self.config, "num_experts", None) or 0 + expert_params_mapping = [ + # (param_name, weight_name, expert_id, shard_id) + ( + "experts.w13_weight" + if proj_name in ["gate_proj", "up_proj"] + else "experts.w2_weight", + f"experts.{expert_id}.{proj_name}", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, proj_name in [ + ("w1", "gate_proj"), + ("w2", "down_proj"), + ("w3", "up_proj"), + ] + ] + params_dict = dict(self.named_parameters()) + # Include buffers (e.g. 
layer_scalar) so they can be loaded too + params_dict.update(dict(self.named_buffers())) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if self.quant_config is not None and ( + scale_name := self.quant_config.get_cache_scale(name) + ): + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + + if name.endswith((".k_scale", ".v_scale", ".q_scale", ".prob_scale")): + remapped_name = maybe_remap_kv_scale_name(name, params_dict) + if remapped_name is not None and remapped_name in params_dict: + param = params_dict[remapped_name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(remapped_name) + continue + + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + stacked_name = name.replace(shard_name, param_name) + # k_eq_v layers use separate q_proj/k_proj instead of + # packed qkv_proj. If the stacked param doesn't exist, + # skip this mapping and fall through to direct load. + if stacked_name not in params_dict: + continue + if is_pp_missing_parameter(stacked_name, self): + continue + param = params_dict[stacked_name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(stacked_name) + break + else: + for ( + param_name, + weight_name, + expert_id, + shard_id, + ) in expert_params_mapping: + if weight_name not in name: + continue + moe_name = name.replace(weight_name, param_name) + if moe_name not in params_dict: + continue + if is_pp_missing_parameter(moe_name, self): + continue + param = params_dict[moe_name] + # Expert weights are already in the correct + # orientation for FusedMoE after _weight_iterator: + # gate/up: [I, H] → w1/w3 expects [I, H] + # down: [H, I] → w2 expects [H, I] + assert loaded_weight.dim() == 2, ( + f"Expected 2D expert weight for {weight_name}, " + f"got shape {loaded_weight.shape}" + ) + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name + ".weight", + shard_id=shard_id, + expert_id=expert_id, + ) + loaded_params.add(moe_name) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params + + +class Gemma4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, MixtureOfExperts): + # Note: qkv_proj packing applies to non-k_eq_v layers (sliding + # attention and full attention without k_eq_v). k_eq_v layers use + # separate q_proj + k_proj without packing. 
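Reviewer note: looking back at `load_weights`, the `expert_params_mapping` comprehension is dense; this is what it evaluates to for a hypothetical 2-expert config (pure Python, runnable as-is).

```python
num_experts = 2  # hypothetical; the real value comes from config.num_experts
expert_params_mapping = [
    (
        "experts.w13_weight" if proj in ("gate_proj", "up_proj") else "experts.w2_weight",
        f"experts.{eid}.{proj}",
        eid,
        sid,
    )
    for eid in range(num_experts)
    for sid, proj in [("w1", "gate_proj"), ("w2", "down_proj"), ("w3", "up_proj")]
]
for entry in expert_params_mapping:
    print(entry)
# ('experts.w13_weight', 'experts.0.gate_proj', 0, 'w1')
# ('experts.w2_weight', 'experts.0.down_proj', 0, 'w2')
# ('experts.w13_weight', 'experts.0.up_proj', 0, 'w3')
# ...and the same three entries for expert 1
```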
+ packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + config = _get_text_config(vllm_config.model_config.hf_config) + quant_config = vllm_config.quant_config + + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = Gemma4Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + if config.tie_word_embeddings: + self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens) + + self.logits_processor = LogitsProcessor( + config.vocab_size, + soft_cap=getattr(config, "final_logit_softcapping", None), + ) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors + ) + + # --- MixtureOfExperts protocol --- + self.expert_weights: list[list[torch.Tensor]] = [] + self.moe_layers: list[nn.Module] = [] + example_moe: Gemma4MoE | None = None + + for layer in self.model.layers: + if hasattr(layer, "moe") and isinstance(layer.moe, Gemma4MoE): + example_moe = layer.moe + self.moe_layers.append(layer.moe.experts) + + self.num_moe_layers = len(self.moe_layers) + + if example_moe is not None: + self.num_logical_experts = example_moe.num_experts + self.num_physical_experts = example_moe.num_experts + self.num_local_physical_experts = example_moe.num_experts + self.num_routed_experts = example_moe.num_experts + else: + self.num_logical_experts = 0 + self.num_physical_experts = 0 + self.num_local_physical_experts = 0 + self.num_routed_experts = 0 + + self.num_expert_groups = 1 + self.num_shared_experts = 0 + self.num_redundant_experts = 0 + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor | IntermediateTensors: + hidden_states = self.model( + input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor | None: + return self.logits_processor(self.lm_head, hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + # Checkpoint weight names use "language_model." prefix (from the + # Gemma4ForConditionalGeneration wrapper). Strip it to map to our + # model tree which is just "model.*". + def _weight_iterator(): + use_k_eq_v = getattr(self.config, "attention_k_eq_v", False) + # Build set of k_eq_v layer indices (full_attention layers + # when attention_k_eq_v is enabled). These layers have k_proj + # but no v_proj in checkpoint — we duplicate k_proj as v_proj. + k_eq_v_layer_indices: set[int] = set() + if use_k_eq_v: + for idx, lt in enumerate(self.config.layer_types): + if lt == "full_attention": + k_eq_v_layer_indices.add(idx) + + for name, weight in weights: + # Remap "language_model." → "" to match our model tree. 
+ # Checkpoint: model.language_model.layers.X.* + # Our model: model.layers.X.* + name = name.replace("language_model.", "") + + # Remap new HF checkpoint naming to internal vLLM + # naming: HF moved per_expert_scale to router and + # renamed moe → experts in the MoE block. + name = name.replace( + ".router.per_expert_scale", + ".moe.per_expert_scale", + ) + if ".experts.gate_up_proj" in name: + name = name.replace( + ".experts.gate_up_proj", + ".moe.gate_up_proj", + ) + elif ".experts.down_proj" in name: + name = name.replace( + ".experts.down_proj", + ".moe.down_proj", + ) + + # MoE expert weights: checkpoint stores as 3D packed + # tensors. Explode into per-expert 2D weights for + # FusedMoE weight_loader. + # + # Checkpoint format: + # moe.gate_up_proj: [E, 2*I, H] (fused gate + up) + # moe.down_proj: [E, H, I] + # + # FusedMoE expects per-expert: + # w1 (gate): [I, H] — first half of gate_up + # w3 (up): [I, H] — second half of gate_up + # w2 (down): [H, I] — as-is from checkpoint + # + # No transpose needed: checkpoint orientation already + # matches FusedMoE's expected layout. + if "moe.gate_up_proj" in name and weight.dim() == 3: + num_experts = weight.size(0) + intermediate_size = weight.size(1) // 2 + for expert_id in range(num_experts): + gate_weight = weight[expert_id, :intermediate_size, :] + up_weight = weight[expert_id, intermediate_size:, :] + base = name.replace("moe.", f"moe.experts.{expert_id}.") + yield base.replace("gate_up_proj", "gate_proj"), gate_weight + yield base.replace("gate_up_proj", "up_proj"), up_weight + continue + + if "moe.down_proj" in name and weight.dim() == 3: + num_experts = weight.size(0) + for expert_id in range(num_experts): + expert_name = name.replace("moe.", f"moe.experts.{expert_id}.") + yield expert_name, weight[expert_id] + continue + + # k_eq_v layers: checkpoint has k_proj but no v_proj. + # QKVParallelLinear expects both, so duplicate k_proj + # as v_proj so V gets identical weights to K. + # ONLY for full_attention layers — sliding layers have + # their own real v_proj weights. + if "self_attn.k_proj" in name and k_eq_v_layer_indices: + m = re.search(r"layers\.(\d+)\.", name) + if m and int(m.group(1)) in k_eq_v_layer_indices: + yield name, weight + yield name.replace("k_proj", "v_proj"), weight.clone() + continue + + yield name, weight + + # Skip multimodal weights — handled by the multimodal wrapper. + # Also skip lm_head when weights are tied. + skip = [ + "audio_tower.", + "vision_tower.", + "embed_audio.", + "embed_vision.", + ] + if self.config.tie_word_embeddings: + skip.append("lm_head.") + + loader = AutoWeightsLoader(self, skip_substrs=skip) + return loader.load_weights(_weight_iterator()) diff --git a/vllm/model_executor/models/gemma4_mm.py b/vllm/model_executor/models/gemma4_mm.py new file mode 100644 index 000000000..9aff787a1 --- /dev/null +++ b/vllm/model_executor/models/gemma4_mm.py @@ -0,0 +1,1341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Gemma 4 multimodal model (image + audio + video support). + +Adds vision tower, audio tower, and multimodal embedders on top of the +text-only Gemma4ForCausalLM. The vision/audio encoders are loaded via +AutoModel.from_config and run in eager mode while the language model uses +the vLLM-optimized path. + +Video support: Gemma4 does **not** have a native video tower. 
Videos are +decomposed into timestamped image frames (up to 32 frames at 70 soft tokens +each) and fed through the same vision tower as regular images. The +processor inserts ``mm:ss`` timestamps between frames so the model can +reason about temporal order. +""" + +import math +import sys +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Any, Literal + +import numpy as np +import torch +from PIL import Image as PILImage +from torch import nn +from transformers import AutoModel, BatchFeature +from transformers.models.gemma4 import ( + Gemma4Config, + Gemma4Processor, + Gemma4VisionConfig, +) +from transformers.models.gemma4.configuration_gemma4 import ( + Gemma4AudioConfig, + Gemma4TextConfig, +) + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions, VideoDummyOptions +from vllm.inputs import MultiModalDataDict +from vllm.logger import init_logger +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import ReplicatedLinear +from vllm.model_executor.models.gemma4 import Gemma4ForCausalLM +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalFieldConfig, + MultiModalKwargsItems, + VideoItem, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + ImageProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import BaseDummyInputsBuilder +from vllm.multimodal.processing.processor import ( + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.sequence import IntermediateTensors +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + init_vllm_registered_model, + maybe_prefix, +) + +logger = init_logger(__name__) + +# Video constants — match transformers Gemma4VideoProcessor defaults. +_VIDEO_MAX_SOFT_TOKENS = 70 # soft tokens per video frame (vs 280 for images) +_VIDEO_MAX_FRAMES = 32 # max sampled frames per video + + +# --------------------------------------------------------------------------- +# Input schema +# --------------------------------------------------------------------------- + + +class Gemma4ImagePixelInputs(TensorSchema): + """ + Pre-patchified image inputs from the Gemma4 image processor. + + Dimensions: + - bn: Batch size * number of images + - np: Number of patches (max_patches = max_soft_tokens * pooling_kernel_size²) + - pp: Patch pixels (patch_size² * 3) + + The HF Gemma4ImageProcessor outputs pixel_values as + (batch, max_patches, patch_pixels) — already patchified with + zero-padding for patches beyond the real image content. + pixel_position_ids provides (x, y) coordinates per patch, + with (-1, -1) for padding patches. 
+    """
+
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "np", "pp"),
+    ]
+    pixel_position_ids: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "np", 2),
+    ]
+
+
+class Gemma4AudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: Sequence length (MEL spectrogram frames)
+        - f: Number of features (MEL bins)
+    """
+
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
+
+
+Gemma4ImageInputs = Gemma4ImagePixelInputs
+
+
+class Gemma4VideoInputs(TensorSchema):
+    """Video frame inputs — same tensor format as image inputs.
+
+    Gemma4 has no separate video tower; video frames are processed
+    through the vision tower at lower resolution (max_soft_tokens=70).
+    """
+
+    type: Literal["pixel_values_videos"] = "pixel_values_videos"
+    pixel_values_videos: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "np", "pp"),
+    ]
+    pixel_position_ids_videos: Annotated[
+        torch.Tensor,
+        TensorShape("bn", "np", 2),
+    ]
+
+
+# ---------------------------------------------------------------------------
+# Processing info
+# ---------------------------------------------------------------------------
+
+
+class Gemma4ProcessingInfo(BaseProcessingInfo):
+    def get_hf_config(self):
+        return self.ctx.get_hf_config(Gemma4Config)
+
+    def get_default_tok_params(self):
+        """Gemma4's chat template already embeds a literal ``<bos>`` token in
+        the rendered text. If ``add_special_tokens=True`` (the base-class
+        default), the tokenizer prepends *another* BOS, producing a
+        ``[2, 2, ...]`` double-BOS sequence that the model was not trained on.
+
+        Setting ``add_special_tokens=False`` here prevents the duplicate and
+        ensures both ``llm.generate()`` and the chat/completions API behave
+        correctly.
+        """
+        params = super().get_default_tok_params()
+        params = params.with_kwargs(add_special_tokens=False)
+        return params
+
+    def get_hf_processor(self, **kwargs: object) -> Gemma4Processor:
+        return self.ctx.get_hf_processor(
+            Gemma4Processor,
+            **kwargs,
+        )
+
+    def validate_num_items(self, modality: str, num_items: int) -> None:
+        if (
+            modality == "audio"
+            and num_items > 0
+            and self.get_hf_config().audio_config is None
+        ):
+            model = self.ctx.model_config.model
+            raise ValueError(
+                f"Audio input was provided but the model "
+                f"'{model}' does not have an audio tower. "
+                f"Audio inference is only supported for Gemma4 "
+                f"variants that include an audio_config."
+            )
+        super().validate_num_items(modality, num_items)
+
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        limits: dict[str, int | None] = {"image": None}
+        if self.get_hf_config().audio_config is not None:
+            limits["audio"] = None
+        limits["video"] = None
+        return limits
+
+    def get_mm_max_tokens_per_item(
+        self, seq_len: int, mm_counts: Mapping[str, int]
+    ) -> Mapping[str, int] | None:
+        config = self.get_hf_config()
+        # Upper bound: the pooler outputs default_output_length slots
+        # per image (280). After padding is stripped the actual count
+        # is ≤ this value, but vLLM needs the max for memory planning.
+        tokens_per_image = config.vision_config.default_output_length
+        tokens: dict[str, int] = {"image": tokens_per_image}
+        if config.audio_config is not None:
+            # Audio max tokens from the processor's audio_seq_length.
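Reviewer note: for reference, the per-item budgets this method assembles work out as below; the audio value is an assumption for illustration, and the video arithmetic matches the line that follows.

```python
_VIDEO_MAX_FRAMES, _VIDEO_MAX_SOFT_TOKENS = 32, 70  # module constants above
tokens = {
    "image": 280,   # vision_config.default_output_length (typical value)
    "audio": 188,   # processor.audio_seq_length (assumed for illustration)
    "video": _VIDEO_MAX_FRAMES * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6),
}
print(tokens["video"])  # 2496 = 32 frames * (70 soft + boi/eoi + ~6 timestamp tokens)
```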
+ processor = self.get_hf_processor() + tokens["audio"] = processor.audio_seq_length + # Video: each frame ≤ 70 soft tokens + boi + eoi + ~6 ts tokens. + tokens["video"] = _VIDEO_MAX_FRAMES * (_VIDEO_MAX_SOFT_TOKENS + 2 + 6) + return tokens + + def get_data_parser(self) -> MultiModalDataParser: + config = self.get_hf_config() + kwargs: dict[str, Any] = {"video_needs_metadata": True} + if getattr(config, "audio_config", None) is not None: + processor = self.get_hf_processor() + kwargs["target_sr"] = processor.feature_extractor.sampling_rate + return MultiModalDataParser(**kwargs) + + def _compute_num_soft_tokens( + self, + image_width: int, + image_height: int, + max_soft_tokens: int | None = None, + ) -> int: + """Compute the number of soft tokens the vision tower produces + for an image of the given dimensions, after padding is stripped. + + Args: + max_soft_tokens: Override for the vision config's + ``default_output_length``. When *None*, the value from + the model config is used. + """ + vision_cfg = self.get_hf_config().vision_config + patch_size = vision_cfg.patch_size + pooling_kernel_size = vision_cfg.pooling_kernel_size + + if max_soft_tokens is None: + max_soft_tokens = vision_cfg.default_output_length + + unit = patch_size * pooling_kernel_size + max_patches = max_soft_tokens * pooling_kernel_size**2 + num_patches_orig = (image_height / patch_size) * (image_width / patch_size) + scale = math.sqrt(max_patches / num_patches_orig) + target_h = max(unit, int(math.floor(image_height * scale / unit)) * unit) + target_w = max(unit, int(math.floor(image_width * scale / unit)) * unit) + num_patches = (target_h // patch_size) * (target_w // patch_size) + return num_patches // (pooling_kernel_size**2) + + def get_image_repl( + self, + *, + image_width: int, + image_height: int, + processor: Gemma4Processor | None, + max_soft_tokens: int | None = None, + ) -> PromptUpdateDetails[list[int]]: + """Return the dynamic image token sequence for this image. + + Computes the exact number of soft tokens the vision tower will + produce after stripping padding. + + Args: + max_soft_tokens: Override for the default token budget. + When *None*, falls back to the model config value. + """ + if processor is None: + processor = self.get_hf_processor() + + num_soft = self._compute_num_soft_tokens( + image_width, + image_height, + max_soft_tokens=max_soft_tokens, + ) + config = self.get_hf_config() + token_ids = ( + [config.boi_token_id] + + [processor.image_token_id] * num_soft + + [config.eoi_token_id] + ) + return PromptUpdateDetails.select_token_id(token_ids, processor.image_token_id) + + def get_audio_repl( + self, + *, + audio_len: int, + processor: Gemma4Processor | None, + ) -> PromptUpdateDetails[list[int]]: + """Return the dynamic audio token sequence for this audio. + + Computes the number of soft tokens from the audio waveform + length using ``ceil(duration_ms / audio_ms_per_token)``. 
+ """ + if processor is None: + processor = self.get_hf_processor() + + sampling_rate = processor.feature_extractor.sampling_rate + num_tokens = processor._compute_audio_num_tokens( + torch.zeros(audio_len), sampling_rate + ) + config = self.get_hf_config() + token_ids = ( + [config.boa_token_id] + + [processor.audio_token_id] * num_tokens + + [config.eoa_token_id] + ) + return PromptUpdateDetails.select_token_id(token_ids, processor.audio_token_id) + + def get_video_repl( + self, + *, + timestamps: list[float], + num_soft_tokens_per_frame: list[int], + processor: Gemma4Processor, + ) -> PromptUpdateDetails[list[int]]: + """Build the full token replacement for one video. + + Produces the same interleaved sequence as the HF Gemma4Processor: + mm:ss <|video|>*N mm:ss <|video|>*N ... + """ + tokenizer = self.ctx.get_tokenizer() + config = self.get_hf_config() + + boi_token_id = config.boi_token_id + eoi_token_id = config.eoi_token_id + video_token_id = processor.video_token_id + + all_token_ids: list[int] = [] + for i, (ts, n_tokens) in enumerate(zip(timestamps, num_soft_tokens_per_frame)): + # mm:ss timestamp — matches transformers: int-truncated, + # zero-padded. + minutes = int(ts // 60) + seconds = int(ts % 60) + ts_str = f"{minutes:02d}:{seconds:02d}" + + prefix = f" {ts_str} " if i > 0 else f"{ts_str} " + ts_token_ids = tokenizer.encode(prefix, add_special_tokens=False) + all_token_ids.extend(ts_token_ids) + + all_token_ids.append(boi_token_id) + all_token_ids.extend([video_token_id] * n_tokens) + all_token_ids.append(eoi_token_id) + + return PromptUpdateDetails.select_token_id(all_token_ids, video_token_id) + + +# --------------------------------------------------------------------------- +# Dummy inputs builder +# --------------------------------------------------------------------------- + + +class Gemma4DummyInputsBuilder(BaseDummyInputsBuilder[Gemma4ProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_images = mm_counts.get("image", 0) + num_audios = mm_counts.get("audio", 0) + num_videos = mm_counts.get("video", 0) + processor = self.info.get_hf_processor() + # Use image_token (<|image|>) with tab prefix — this is what the + # Gemma4 chat template inserts per image (\t<|image|>). + # _get_prompt_updates targets image_token and expands it to the + # full_image_sequence. + text = ("\t" + processor.image_token) * num_images + if num_audios > 0 and processor.audio_token: + text += processor.audio_token * num_audios + if num_videos > 0: + text += processor.video_token * num_videos + return text + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + num_images = mm_counts.get("image", 0) + num_audios = mm_counts.get("audio", 0) + num_videos = mm_counts.get("video", 0) + processor = self.info.get_hf_processor() + image_processor = processor.image_processor + # Use processor's configured image size for dummies. + # Gemma4ImageProcessor sets size=None (it uses patch_size / + # max_soft_tokens instead of the standard size dict), so we + # guard against None with `or {}`. 
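Reviewer note: it may help to see `_compute_num_soft_tokens` (defined earlier in this file) replayed standalone. The config values here are assumptions; the math mirrors the floor-to-unit logic exactly.

```python
import math

patch_size, pooling_k, max_soft = 16, 2, 280  # assumed config values
unit = patch_size * pooling_k
max_patches = max_soft * pooling_k ** 2

def num_soft_tokens(w: int, h: int) -> int:
    scale = math.sqrt(max_patches / ((h / patch_size) * (w / patch_size)))
    th = max(unit, int(math.floor(h * scale / unit)) * unit)
    tw = max(unit, int(math.floor(w * scale / unit)) * unit)
    return (th // patch_size) * (tw // patch_size) // pooling_k ** 2

print(num_soft_tokens(640, 480))  # 266, always <= max_soft by construction
```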
+ size = getattr(image_processor, "size", None) or {} + img_width = size.get("width", 224) + img_height = size.get("height", 224) + + image_overrides = mm_options.get("image") if mm_options else None + audio_overrides = mm_options.get("audio") if mm_options else None + video_overrides = mm_options.get("video") if mm_options else None + + data: MultiModalDataDict = { + "image": self._get_dummy_images( + width=img_width, + height=img_height, + num_images=num_images, + overrides=image_overrides, + ), + } + + if num_audios > 0: + audio_len = processor.feature_extractor.fft_length + data["audio"] = self._get_dummy_audios( + length=audio_len, + num_audios=num_audios, + overrides=audio_overrides, + ) + + if num_videos > 0: + data["video"] = self._get_dummy_videos( + width=img_width, + height=img_height, + num_frames=_VIDEO_MAX_FRAMES, + num_videos=num_videos, + overrides=video_overrides, + ) + + return data + + def _get_dummy_videos( + self, + *, + width: int, + height: int, + num_frames: int, + num_videos: int, + overrides: VideoDummyOptions | None = None, + ) -> list[VideoItem]: + num_frames = max(num_frames, 2) + videos = super()._get_dummy_videos( + width=width, + height=height, + num_frames=num_frames, + num_videos=num_videos, + overrides=overrides, + ) + videos = [v.copy() for v in videos] + + video_items: list[VideoItem] = [] + for video in videos: + video_num_frames = video.shape[0] + video_metadata = { + "fps": 2.0, + "duration": video_num_frames / 2.0, + "total_num_frames": video_num_frames, + "frames_indices": list(range(video_num_frames)), + "video_backend": "opencv", + "do_sample_frames": False, + } + video_items.append((video, video_metadata)) + + return video_items + + +# --------------------------------------------------------------------------- +# Multimodal processor +# --------------------------------------------------------------------------- + + +class Gemma4MultiModalProcessor(BaseMultiModalProcessor[Gemma4ProcessingInfo]): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + # Validate max_soft_tokens early and exit cleanly on bad values. + _SUPPORTED_SOFT_TOKENS = (70, 140, 280, 560, 1120) + + merged_kwargs = self.info.ctx.get_merged_mm_kwargs(mm_kwargs) + val = merged_kwargs.get("max_soft_tokens") + if val is None: + val = merged_kwargs.get("images_kwargs", {}).get("max_soft_tokens") + + if val is not None and val not in _SUPPORTED_SOFT_TOKENS: + logger.error( + "Unsupported max_soft_tokens value: %d. Valid values are %s. Exiting.", + val, + _SUPPORTED_SOFT_TOKENS, + ) + sys.exit(1) + + mm_data = dict(mm_data) + + # ---- VIDEO HANDLING ---- + # Gemma4 decomposes video into timestamped image frames. + # Each frame is processed with max_soft_tokens=70 through the + # same vision tower, matching transformers processing_gemma4.py. 
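Reviewer note: the `mm:ss` timestamp rendering used for video frames is simple enough to verify in isolation (int-truncated, zero-padded, matching the code below).

```python
for ts in (0.0, 9.5, 65.0, 600.4):
    print(f"{int(ts // 60):02d}:{int(ts % 60):02d}")
# 00:00
# 00:09
# 01:05
# 10:00
```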
+ video_outputs: dict[str, Any] = {} + if videos := mm_data.pop("videos", []): + processor = self.info.get_hf_processor() + + all_video_pixel_values: list[torch.Tensor] = [] + all_video_position_ids: list[torch.Tensor] = [] + video_num_soft_tokens_per_video: list[list[int]] = [] + video_timestamps_per_video: list[list[float]] = [] + video_frame_counts: list[int] = [] + + for item in videos: + video_array, metadata = item + + # Convert frames to PIL images + if isinstance(video_array, np.ndarray): + frames = [ + PILImage.fromarray(video_array[i]) + for i in range(video_array.shape[0]) + ] + else: + frames = list(video_array) + + # Compute timestamps from metadata (same as transformers) + fps = metadata.get("fps") or 24 + frame_indices = metadata.get("frames_indices", list(range(len(frames)))) + timestamps = [idx / fps for idx in frame_indices] + + # Process frames as images with max_soft_tokens=70 + video_mm_kwargs = dict(mm_kwargs) + video_mm_kwargs["max_soft_tokens"] = _VIDEO_MAX_SOFT_TOKENS + + dummy_prompt = ("\t" + processor.image_token) * len(frames) + + frame_outputs = super()._call_hf_processor( + prompt=dummy_prompt, + mm_data={"images": frames}, + mm_kwargs=video_mm_kwargs, + tok_kwargs=tok_kwargs, + ) + + # Remap HF key name + if "image_position_ids" in frame_outputs: + frame_outputs["pixel_position_ids"] = frame_outputs.pop( + "image_position_ids" + ) + + all_video_pixel_values.append(frame_outputs["pixel_values"]) + all_video_position_ids.append(frame_outputs["pixel_position_ids"]) + + # Compute soft tokens per frame + num_soft_per_frame = [] + for img in frames: + w, h = img.size + n = self.info._compute_num_soft_tokens( + w, h, max_soft_tokens=_VIDEO_MAX_SOFT_TOKENS + ) + num_soft_per_frame.append(n) + + video_num_soft_tokens_per_video.append(num_soft_per_frame) + video_timestamps_per_video.append(timestamps) + video_frame_counts.append(len(frames)) + + # Build expanded replacement text and replace the + # <|video|> placeholder in the prompt. + # Use split(token, 1) to avoid collision — the + # replacement text itself contains <|video|> tokens. + ts_strs = [f"{int(s // 60):02d}:{int(s % 60):02d}" for s in timestamps] + replacement = " ".join( + f"{t} {processor.boi_token}" + f"{processor.video_token * n}" + f"{processor.eoi_token}" + for t, n in zip(ts_strs, num_soft_per_frame) + ) + parts = prompt.split(processor.video_token, 1) + if len(parts) == 2: + prompt = parts[0] + replacement + parts[1] + + video_outputs = { + "pixel_values_videos": torch.cat(all_video_pixel_values, dim=0), + "pixel_position_ids_videos": torch.cat(all_video_position_ids, dim=0), + "video_frame_counts": torch.tensor(video_frame_counts), + "video_num_soft_tokens": video_num_soft_tokens_per_video, + "video_timestamps": video_timestamps_per_video, + } + + # The processor accepts 'audio' not 'audios'. + if "audios" in mm_data: + mm_data["audio"] = mm_data.pop("audios") + + # Warn if any audio waveform exceeds the model's max duration. 
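Reviewer note: the duration warning implemented next reduces to this arithmetic; all three processor values are assumptions for illustration.

```python
sampling_rate = 16_000    # Hz (assumed feature-extractor value)
audio_seq_length = 188    # max soft tokens (assumed)
audio_ms_per_token = 160  # ms of audio per soft token (assumed)

max_duration_s = audio_seq_length * audio_ms_per_token / 1000.0  # 30.08 s
waveform = [0.0] * (16_000 * 35)                                 # a 35 s clip
print(len(waveform) / sampling_rate > max_duration_s)            # True: would warn
```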
+ if "audio" in mm_data: + processor = self.info.get_hf_processor() + sr = processor.feature_extractor.sampling_rate + max_tokens = processor.audio_seq_length + ms_per_tok = processor.audio_ms_per_token + max_duration_s = max_tokens * ms_per_tok / 1000.0 + audios = mm_data["audio"] + if not isinstance(audios, (list, tuple)): + audios = [audios] + for i, waveform in enumerate(audios): + duration_s = len(waveform) / sr + if duration_s > max_duration_s: + logger.warning( + "Audio duration exceeds max: %f > %f seconds", + duration_s, + max_duration_s, + ) + # vLLM's call_hf_processor (context.py) re-merges + # mm_processor_kwargs from the model config on every call via: + # config_kwargs | incoming_kwargs (right side wins) + # + # If we strip max_soft_tokens from incoming, the re-merge puts + # back the config's global default (e.g. 280), ignoring any + # per-prompt override. Instead, we keep it in the kwargs with + # the validated per-prompt value so it wins during the merge. + # + # NOTE: This requires a corresponding type annotation on the + # HF side (Gemma4ProcessorKwargs.images_kwargs) so that + # _merge_kwargs routes max_soft_tokens into images_kwargs. + patched_mm_kwargs = dict(mm_kwargs) + if val is not None: + patched_mm_kwargs["max_soft_tokens"] = val + + processed_outputs = super()._call_hf_processor( + prompt, + mm_data, + patched_mm_kwargs, + tok_kwargs, + ) + + # HF uses 'image_position_ids'; vLLM uses 'pixel_position_ids'. + # Remap here to keep a single translation point. + if "image_position_ids" in processed_outputs: + processed_outputs["pixel_position_ids"] = processed_outputs.pop( + "image_position_ids" + ) + + if "input_features" in processed_outputs: + # Keep padded features for batched audio tower execution. + processed_outputs["input_features_padded"] = processed_outputs[ + "input_features" + ] + # Unpad per-item so each item's cache entry is self-contained. + unpadded_features = [ + f[mask] + for f, mask in zip( + processed_outputs["input_features"], + processed_outputs["input_features_mask"], + ) + ] + processed_outputs["input_features"] = unpadded_features + + # Merge video outputs into the final result + combined_outputs = dict(processed_outputs, **video_outputs) + return BatchFeature(combined_outputs) + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + fields = dict( + pixel_values=MultiModalFieldConfig.batched("image"), + pixel_position_ids=MultiModalFieldConfig.batched("image"), + input_features_padded=MultiModalFieldConfig.batched("audio"), + input_features_mask=MultiModalFieldConfig.batched("audio"), + ) + + # Video fields: frames stored flat, split per video by + # video_frame_counts. 
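Reviewer note: the flat-then-split frame layout in miniature, mirroring the `torch.split` used later in `_process_video_input`.

```python
import torch

frames = torch.randn(5, 4)           # 5 frames total, flattened across videos
frame_counts = [2, 3]                # video 0 has 2 frames, video 1 has 3
per_video = torch.split(frames, frame_counts, dim=0)
print([t.shape for t in per_video])  # [torch.Size([2, 4]), torch.Size([3, 4])]
```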
+ video_frame_counts = hf_inputs.get("video_frame_counts") + if video_frame_counts is not None: + vfc = video_frame_counts + if not isinstance(vfc, torch.Tensor): + vfc = torch.tensor(vfc) + fields.update( + pixel_values_videos=( + MultiModalFieldConfig.flat_from_sizes("video", vfc) + ), + pixel_position_ids_videos=( + MultiModalFieldConfig.flat_from_sizes("video", vfc) + ), + video_frame_counts=MultiModalFieldConfig.batched( + "video", + ), + video_num_soft_tokens=MultiModalFieldConfig.batched( + "video", keep_on_cpu=True + ), + video_timestamps=MultiModalFieldConfig.batched( + "video", keep_on_cpu=True + ), + ) + + return fields + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, Any], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + + prompt_updates = [] + + if "image" in mm_items: + # Target image_token (<|image|>) — the single placeholder the + # Gemma4 chat template inserts once per image in the prompt. + # vLLM tokenizes the prompt without token expansion, so only + # one image_token exists per image in the token stream. + # The replacement expands it to the full image sequence + # (boi + N×image_token + eoi, where N = max_soft_tokens). + image_token = hf_processor.image_token + + def get_replacement_image(item_idx: int): + images = mm_items.get_items("image", ImageProcessorItems) + image_size = images.get_image_size(item_idx) + # Resolve the effective max_soft_tokens by merging + # per-prompt kwargs with the config-level defaults, + # consistent with how _call_hf_processor resolves it. + # Without this merge, a missing per-prompt override + # would fall back to vision_cfg.default_output_length + # instead of the config's mm_processor_kwargs default. + merged_kwargs = self.info.ctx.get_merged_mm_kwargs( + hf_processor_mm_kwargs, + ) + max_soft_tokens = merged_kwargs.get("max_soft_tokens") + return self.info.get_image_repl( + image_width=image_size.width, + image_height=image_size.height, + processor=hf_processor, + max_soft_tokens=max_soft_tokens, + ) + + prompt_updates.append( + PromptReplacement( + modality="image", + target=image_token, + replacement=get_replacement_image, + ) + ) + + if "video" in mm_items: + video_token = hf_processor.video_token + + def get_replacement_video(item_idx: int): + out_item = out_mm_kwargs["video"][item_idx] + timestamps = out_item["video_timestamps"].data + num_soft = out_item["video_num_soft_tokens"].data + return self.info.get_video_repl( + timestamps=timestamps, + num_soft_tokens_per_frame=num_soft, + processor=hf_processor, + ) + + prompt_updates.append( + PromptReplacement( + modality="video", + target=video_token, + replacement=get_replacement_video, + ) + ) + + if "audio" in mm_items: + audio_token = hf_processor.audio_token + + def get_replacement_audio(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio_len = audios.get_audio_length(item_idx) + return self.info.get_audio_repl( + audio_len=audio_len, + processor=hf_processor, + ) + + prompt_updates.append( + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_audio, + ) + ) + + return prompt_updates + + # NOTE: Gemma3/Gemma3n override _apply_token_matches and + # _find_mm_placeholders to merge adjacent newline tokens that arise + # when full_image_sequence contains "\n\n" wrappers. 
Gemma4's + # full_image_sequence has NO newlines (just BOI + 280×image_token + + # EOI), so the base class implementations work correctly as-is. + + +# --------------------------------------------------------------------------- +# Multimodal embedder +# --------------------------------------------------------------------------- + + +class Gemma4MultimodalEmbedder(nn.Module): + """Projects vision/audio soft tokens into LM embedding space. + + Architecture: + inputs_embeds → embedding_projection → embedding_post_projection_norm + + Unlike Gemma3n which has separate hard/soft embedding paths with + per-path normalization and a learned embedding table, Gemma4 uses a + simplified 2-layer design: a linear projection followed by RMSNorm + (without learnable scale). The checkpoint confirms this — only + ``embedding_projection.weight`` exists; there is no embedding table + or pre-projection norm weights. + """ + + def __init__( + self, + multimodal_config: Gemma4VisionConfig | Gemma4AudioConfig, + text_config: Gemma4TextConfig, + ): + super().__init__() + + self.eps = multimodal_config.rms_norm_eps + self.text_hidden_size = text_config.hidden_size + + # Audio tower uses output_proj_dims (1536) rather than hidden_size + # (1024); vision uses hidden_size (768) directly. + embedding_dim = ( + getattr(multimodal_config, "output_proj_dims", None) + or multimodal_config.hidden_size + ) + + self.embedding_projection = ReplicatedLinear( + embedding_dim, + self.text_hidden_size, + bias=False, + ) + + self.embedding_post_projection_norm = RMSNorm( + self.text_hidden_size, + eps=self.eps, + has_weight=False, + ) + + def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: + """Project soft tokens from a multimodal tower into LM space.""" + embs_proj, _ = self.embedding_projection(inputs_embeds) + return self.embedding_post_projection_norm(embs_proj) + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + + +@MULTIMODAL_REGISTRY.register_processor( + Gemma4MultiModalProcessor, + info=Gemma4ProcessingInfo, + dummy_inputs=Gemma4DummyInputsBuilder, +) +class Gemma4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # Maps checkpoint prefixes to vLLM module paths. 
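Reviewer note: the effect of the prefix mapper defined next, shown with plain string replacement on one illustrative weight name (`WeightsMapper` itself also handles ordering and the bare `model` fallback).

```python
orig_to_new = {
    "model.language_model.": "language_model.model.",
    "model.vision_tower.": "vision_tower.",
    "lm_head.": "language_model.lm_head.",
}
name = "model.language_model.layers.0.self_attn.q_proj.weight"
for old, new in orig_to_new.items():
    if name.startswith(old):
        name = new + name[len(old):]
        break
print(name)  # language_model.model.layers.0.self_attn.q_proj.weight
```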
+ hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.embed_audio.": "embed_audio.", + "model.embed_vision.": "embed_vision.", + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_tower.", + "model.audio_tower.": "audio_tower.", + "lm_head.": "language_model.lm_head.", + "model": "language_model.model", + } + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.quant_config = quant_config + self.multimodal_config = multimodal_config + + # ---- Vision tower (shared by image and video) ---- + with self._mark_tower_model(vllm_config, {"image", "video"}): + self.vision_tower = AutoModel.from_config(config=config.vision_config) + self.embed_vision = Gemma4MultimodalEmbedder( + config.vision_config, config.text_config + ) + + # ---- Audio tower (variants with audio_config) ---- + if config.audio_config is not None: + with self._mark_tower_model(vllm_config, "audio"): + self.audio_tower = AutoModel.from_config(config=config.audio_config) + # AutoModel.from_config does NOT call post_init(), + # which is needed to initialize buffers that are absent + # from the checkpoint (e.g. inv_timescales for relative + # position embeddings, softcap, gradient_clipping). + self.audio_tower.post_init() + self.embed_audio = Gemma4MultimodalEmbedder( + config.audio_config, config.text_config + ) + else: + self.audio_tower = None + self.embed_audio = None + + # ---- Language model (vLLM optimised) ---- + with self._mark_language_model(vllm_config): + self.language_model: Gemma4ForCausalLM = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Gemma4ForCausalLM"], + ) + + # Pre-allocate PLE buffer for CUDA graph compatibility. + # Some variants have hidden_size_per_layer_input=None (no PLE). 
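Reviewer note: the reason for a fixed max-size buffer is that CUDA graph replay requires stable tensor addresses, so each step writes into a slice of a pre-allocated tensor instead of allocating anew. In miniature (toy shapes; `max_num_batched_tokens` assumed to be 8):

```python
import torch

buf = torch.zeros(8, 4, 16)          # (max_tokens, num_layers, ple_dim), allocated once
batch = torch.randn(3, 4, 16)        # this step's per-layer embeddings
buf[: batch.shape[0]].copy_(batch)   # in-place write, no new allocation
print(bool(buf[:3].eq(batch).all()))  # True
```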
+ ple_dim = config.text_config.hidden_size_per_layer_input + if ple_dim is not None: + self.per_layer_embeddings = torch.zeros( + vllm_config.scheduler_config.max_num_batched_tokens, + config.text_config.num_hidden_layers, + ple_dim, + device=(self.language_model.model.embed_tokens.weight.device), + dtype=(self.language_model.model.embed_tokens.weight.dtype), + ) + else: + self.per_layer_embeddings = None + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + # --- MixtureOfExperts delegation to language_model --- + self.expert_weights = self.language_model.expert_weights + self.moe_layers = self.language_model.moe_layers + self.num_moe_layers = self.language_model.num_moe_layers + self.num_logical_experts = self.language_model.num_logical_experts + self.num_physical_experts = self.language_model.num_physical_experts + self.num_local_physical_experts = self.language_model.num_local_physical_experts + self.num_routed_experts = self.language_model.num_routed_experts + self.num_expert_groups = self.language_model.num_expert_groups + self.num_shared_experts = self.language_model.num_shared_experts + self.num_redundant_experts = self.language_model.num_redundant_experts + + # ------------------------------------------------------------------ # + # Input parsing + # ------------------------------------------------------------------ # + + def _parse_and_validate_image_input( + self, **kwargs: object + ) -> Gemma4ImageInputs | None: + pixel_values = kwargs.pop("pixel_values", None) + pixel_position_ids = kwargs.pop("pixel_position_ids", None) + image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma4 does not support image_embeds." + if pixel_values is None: + return None + return Gemma4ImagePixelInputs( + pixel_values=pixel_values, + pixel_position_ids=pixel_position_ids, + ) + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> Gemma4AudioInputs | None: + input_features_padded = kwargs.pop("input_features_padded", None) + if input_features_padded is None: + return None + input_features_mask = kwargs.pop("input_features_mask", None) + if input_features_mask is None: + return None + return Gemma4AudioInputs( + input_features_padded=input_features_padded, + input_features_mask=input_features_mask, + ) + + def _parse_and_validate_video_input( + self, **kwargs: object + ) -> dict[str, torch.Tensor] | None: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + pixel_position_ids_videos = kwargs.pop("pixel_position_ids_videos", None) + video_frame_counts = kwargs.pop("video_frame_counts", None) + if pixel_values_videos is None: + return None + return { + "pixel_values_videos": pixel_values_videos, + "pixel_position_ids_videos": pixel_position_ids_videos, + "video_frame_counts": video_frame_counts, + } + + def _parse_and_validate_multimodal_inputs( + self, **kwargs: object + ) -> dict[str, Gemma4ImageInputs | Gemma4AudioInputs | Gemma4VideoInputs | None]: + mm_input_by_modality = {} + for input_key in list(kwargs): + if ( + input_key in ("pixel_values", "image_embeds") + and "image" not in mm_input_by_modality + ): + mm_input_by_modality["image"] = self._parse_and_validate_image_input( + **kwargs + ) + if ( + input_key == "pixel_values_videos" + and "video" not in mm_input_by_modality + ): + mm_input_by_modality["video"] = self._parse_and_validate_video_input( + **kwargs + ) + if ( + input_key == "input_features_padded" + and "audio" not in mm_input_by_modality + ): + 
mm_input_by_modality["audio"] = self._parse_and_validate_audio_input( + **kwargs + ) + return mm_input_by_modality + + # ------------------------------------------------------------------ # + # Image processing + # ------------------------------------------------------------------ # + + def _process_image_input( + self, + image_input: Gemma4ImageInputs, + ) -> list[torch.Tensor]: + pixel_values = image_input["pixel_values"] + pixel_position_ids = image_input["pixel_position_ids"] + + # The HF image processor now outputs pre-patchified data: + # pixel_values: (num_images, max_patches, patch_pixels) + # pixel_position_ids: (num_images, max_patches, 2) + # We call the vision tower's forward() directly, which handles + # patch embedding, encoding, pooling, padding removal, and + # optional standardization internally. + vt = self.vision_tower + pooling_k2 = self.config.vision_config.pooling_kernel_size**2 + + # TODO: Move this per-image loop into the input processor to + # reduce dynamism at the model runner / engine core. This + # requires spatially padding all images to uniform (H_max, + # W_max) in _call_hf_processor() so they arrive as a single + # stacked tensor, tracking padded regions via image_sizes + # metadata, and validating numerical equivalence with the + # current per-image path. + # + # Process each image individually through the vision tower. + # The vision tower's forward() strips padding and returns a + # flat tensor of valid tokens. We process per-image to get + # variable-length outputs matching the dynamic token count + # from get_image_repl. + per_image_features = [] + for i in range(pixel_values.shape[0]): + pv = pixel_values[i].unsqueeze(0) # (1, max_patches, patch_pixels) + pp = pixel_position_ids[i].unsqueeze(0) # (1, max_patches, 2) + + # Derive the pooler's output_length from the total patch + # count (including padding). The vision tower encoder + # processes ALL patches — padding patches get zero hidden + # states but still occupy sequence positions. The pooler's + # _avg_pool_by_positions requires: + # input_seq_len / output_length == k² + # where k == pooling_kernel_size. The image processor + # allocates max_patches = max_soft_tokens * k² total slots, + # so output_length = max_patches / k² == max_soft_tokens. + # Without this, the pooler falls back to + # config.image_seq_length (e.g. 280), which fails when a + # different max_soft_tokens was used at preprocessing time. + max_patches = pv.shape[1] + output_length = max_patches // pooling_k2 + + vt_output = vt(pv, pp, output_length=output_length) + # last_hidden_state: (num_valid_tokens, hidden_size) + # — already flat with padding stripped by the vision tower + per_image_features.append(vt_output.last_hidden_state) + + # Project each image's features into LM embedding space. + # Per-image loop is required because images have variable + # token counts after padding removal. + # Cast to match the projection layer's dtype (model may be + # bf16 while the vision tower outputs fp32). 
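Reviewer note: the dtype-cast pattern used next, in isolation; bf16 is the assumed model dtype.

```python
import torch

proj = torch.nn.Linear(8, 4, bias=False).to(torch.bfloat16)  # LM-space projection
feats = torch.randn(5, 8)                                    # fp32 tower output
out = proj(feats.to(proj.weight.dtype))                      # cast, then project
print(out.dtype)  # torch.bfloat16
```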
+ target_dtype = self.embed_vision.embedding_projection.weight.dtype + return [ + self.embed_vision(inputs_embeds=img.unsqueeze(0).to(target_dtype)).squeeze( + 0 + ) + for img in per_image_features + ] + + # ------------------------------------------------------------------ # + # Video processing (frames through vision tower) + # ------------------------------------------------------------------ # + + def _process_video_input( + self, + video_input: dict[str, torch.Tensor], + ) -> list[torch.Tensor]: + """Process video frames through the vision tower. + + Reuses the image processing pipeline — Gemma4 has no separate + video tower; video frames are just images at lower resolution + (max_soft_tokens=70). + + Returns one concatenated embedding tensor per video (not per + frame), because vLLM treats one video as one multimodal item. + The flat_from_sizes field config groups all frames of a video + together, so embed_multimodal must return one tensor per video. + """ + pixel_values = video_input["pixel_values_videos"] + pixel_position_ids = video_input["pixel_position_ids_videos"] + frame_counts = video_input["video_frame_counts"] + + vt = self.vision_tower + pooling_k2 = self.config.vision_config.pooling_kernel_size**2 + target_dtype = self.embed_vision.embedding_projection.weight.dtype + + # Split flat tensors into per-video chunks + if isinstance(frame_counts, torch.Tensor): + fc_list = frame_counts.tolist() + else: + fc_list = list(frame_counts) + + pv_per_video = torch.split(pixel_values, fc_list, dim=0) + pp_per_video = torch.split(pixel_position_ids, fc_list, dim=0) + + per_video_embeddings = [] + for pv_chunk, pp_chunk in zip(pv_per_video, pp_per_video): + frame_embs = [] + for i in range(pv_chunk.shape[0]): + pv = pv_chunk[i].unsqueeze(0) + pp = pp_chunk[i].unsqueeze(0) + + max_patches = pv.shape[1] + output_length = max_patches // pooling_k2 + + vt_output = vt(pv, pp, output_length=output_length) + frame_emb = self.embed_vision( + inputs_embeds=( + vt_output.last_hidden_state.unsqueeze(0).to(target_dtype) + ) + ).squeeze(0) + frame_embs.append(frame_emb) + + # Concatenate all frames of this video into one tensor. + per_video_embeddings.append(torch.cat(frame_embs, dim=0)) + + return per_video_embeddings + + # ------------------------------------------------------------------ # + # Audio processing + # ------------------------------------------------------------------ # + + def _process_audio_input( + self, + audio_input: Gemma4AudioInputs, + ) -> list[torch.Tensor]: + input_features = audio_input["input_features_padded"].squeeze(1) + input_features_mask = audio_input["input_features_mask"].squeeze(1) + + # Run audio tower — mask uses standard HF convention + # (True=valid, False=padding). + audio_outputs = self.audio_tower(input_features, input_features_mask) + if isinstance(audio_outputs, tuple): + audio_encodings, audio_mask = audio_outputs + else: + audio_encodings = audio_outputs.last_hidden_state + audio_mask = audio_outputs.attention_mask + + # Project into LM embedding space. + audio_features = self.embed_audio(inputs_embeds=audio_encodings) + + # Strip padding per-batch element: only keep real (non-padding) + # tokens. audio_mask is True for valid positions (HF convention). 
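Reviewer note: boolean-mask unpadding as used next, in miniature (True marks real frames).

```python
import torch

enc = torch.randn(2, 4, 8)  # (batch, seq, hidden)
mask = torch.tensor([[True, True, False, False],
                     [True, True, True, False]])
per_audio = [e[m] for e, m in zip(enc, mask)]
print([t.shape for t in per_audio])  # [torch.Size([2, 8]), torch.Size([3, 8])]
```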
+ per_audio = [] + for enc, mask in zip(audio_features, audio_mask, strict=True): + per_audio.append(enc[mask]) # [num_real, hidden_size] + + return per_audio + + # ------------------------------------------------------------------ # + # MultiModalEmbeddings interface + # ------------------------------------------------------------------ # + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs) + multimodal_embeddings: list[torch.Tensor] = [] + + for modality, multimodal_input in mm_input_by_modality.items(): + if multimodal_input is None: + continue + if modality == "image": + multimodal_embeddings.extend( + self._process_image_input(multimodal_input) + ) + elif modality == "video": + multimodal_embeddings.extend( + self._process_video_input(multimodal_input) + ) + elif modality == "audio": + multimodal_embeddings.extend( + self._process_audio_input(multimodal_input) + ) + + return multimodal_embeddings + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + ) -> torch.Tensor: + # Cache per-layer embeddings (PLE) for the language model's + # forward pass. During profiling embed_input_ids is not called, + # so the pre-allocated zeros are used instead. + if self.per_layer_embeddings is not None: + # Mask multimodal tokens (image/audio) to 0 for PLE + # computation (using token_type_ids == 0 as text_mask). + # Replicate this: map image token positions to token 0. + if is_multimodal is not None: + is_multimodal = is_multimodal.to(input_ids.device) + ple_input_ids = torch.where( + is_multimodal, torch.zeros_like(input_ids), input_ids + ) + else: + ple_input_ids = input_ids + + per_layer_inputs = self.language_model.model.get_per_layer_inputs( + ple_input_ids + ) + if per_layer_inputs is not None: + per_layer_inputs = per_layer_inputs.reshape( + -1, + self.config.text_config.num_hidden_layers, + self.config.text_config.hidden_size_per_layer_input, + ) + self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_( + per_layer_inputs + ) + + if multimodal_embeddings is None or is_multimodal is None: + return super().embed_input_ids(input_ids) + + return super().embed_input_ids( + input_ids, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=is_multimodal, + ) + + # ------------------------------------------------------------------ # + # Forward + # ------------------------------------------------------------------ # + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: IntermediateTensors | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs: object, + ) -> IntermediateTensors: + if intermediate_tensors is not None: + inputs_embeds = None + + # Select the pre-cached PLEs for this batch (None when PLE + # is disabled for variants without PLE). 
+        per_layer_inputs = (
+            self.per_layer_embeddings[: inputs_embeds.shape[0]]
+            if self.per_layer_embeddings is not None and inputs_embeds is not None
+            else None
+        )
+
+        hidden_states = self.language_model.model(
+            input_ids,
+            positions,
+            per_layer_inputs=per_layer_inputs,
+            intermediate_tensors=intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+            **kwargs,
+        )
+
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        return self.language_model.compute_logits(hidden_states)
+
+    # ------------------------------------------------------------------ #
+    # Weight loading
+    # ------------------------------------------------------------------ #
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        # Some checkpoints have vestigial embed_vision.embedding and
+        # embed_audio.embedding weights from the Gemma3n architecture
+        # that are not used by Gemma4's MultimodalEmbedder (which only
+        # has embedding_projection + embedding_post_projection_norm).
+        ignore_prefixes = [
+            "embed_vision.embedding.",
+            "embed_audio.embedding.",
+        ]
+        # Models without an audio tower should skip audio weights
+        # entirely.
+        if self.audio_tower is None:
+            ignore_prefixes.extend(
+                [
+                    "audio_tower.",
+                    "embed_audio.",
+                ]
+            )
+        loader = AutoWeightsLoader(
+            self,
+            ignore_unexpected_prefixes=ignore_prefixes,
+        )
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+    # ------------------------------------------------------------------ #
+    # LoRA / multimodal mapping
+    # ------------------------------------------------------------------ #
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """Get the module prefix mapping for multimodal models."""
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector=["embed_vision", "embed_audio"],
+            tower_model=["vision_tower", "audio_tower"],
+        )
+
+    @classmethod
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
+        if modality == "image":
+            return "<|image|>"
+        if modality == "audio":
+            return "<|audio|>"
+        if modality == "video":
+            return "<|video|>"
+        raise ValueError(f"Unsupported modality: {modality}")
diff --git a/vllm/model_executor/models/gemma4_utils.py b/vllm/model_executor/models/gemma4_utils.py
new file mode 100644
index 000000000..061f8a1cc
--- /dev/null
+++ b/vllm/model_executor/models/gemma4_utils.py
@@ -0,0 +1,292 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+
+"""Gemma4 output parsing utilities for offline inference.
+
+Standalone functions that parse decoded model text to extract structured
+thinking content and tool calls from Gemma4 models. These are pure-Python
+utilities with zero heavy dependencies — they work on raw decoded strings
+from any inference backend (vLLM, HuggingFace, TGI, etc.).
+
+Usage with vLLM offline inference::
+
+    from vllm import LLM, SamplingParams
+    from vllm.model_executor.models.gemma4_utils import (
+        parse_thinking_output,
+        parse_tool_calls,
+    )
+
+    llm = LLM(model="google/gemma-4-it")
+    outputs = llm.generate(prompt, SamplingParams(...))
+    text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
+
+    # Extract thinking / answer (works with or without enable_thinking)
+    result = parse_thinking_output(text)
+    print(result["thinking"])  # chain-of-thought or None
+    print(result["answer"])    # final answer
+
+    # Extract tool calls
+    tool_calls = parse_tool_calls(text)
+    for tc in tool_calls:
+        print(f"{tc['name']}({tc['arguments']})")
+
+Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
+do not need a transformers dependency for output parsing.
+"""
+
+import json
+
+import regex as re
+
+# ---- Thinking Mode Utility ----
+
+# Thinking delimiter tokens as they appear in decoded text.
+# Gemma4 uses <|channel> (start) and <|channel_end> (end) as thinking
+# delimiters.
+_THINKING_START_TAG = "<|channel>"
+_THINKING_END_TAG = "<|channel_end>"
+
+# Sentinel tokens that may appear in decoded output.
+_TURN_END_TAG = "<|turn_end>"
+
+
+def parse_thinking_output(text: str) -> dict[str, str | None]:
+    """Parse decoded Gemma4 model output.
+
+    Use this on **all** Gemma4 output regardless of whether thinking mode
+    was enabled. It handles three cases:
+
+    1. **Thinking enabled, tags present** — splits on ``<|channel>``/
+       ``<|channel_end>`` to separate chain-of-thought from the answer and
+       strips the ``thought\\n`` role label.
+    2. **Thinking disabled, spurious label** — strips the bare
+       ``thought\\n`` prefix that some Gemma4 models emit even
+       without thinking mode.
+    3. **Clean output** — returns the text unchanged.
+
+    The answer text is always cleaned of trailing sentinel tokens
+    (``<|turn_end>``, ``<eos>``, etc.).
+
+    Args:
+        text: Decoded model output text (from ``tokenizer.decode(...)``).
+
+    Returns:
+        A dict with keys:
+        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
+          thinking delimiters were found.
+        - ``"answer"``: The final answer text.
+
+    Example::
+
+        >>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
+        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
+        >>> result = parse_thinking_output(output_text)
+        >>> print(result["thinking"])  # chain-of-thought reasoning or None
+        >>> print(result["answer"])   # final answer
+    """
+    if _THINKING_END_TAG in text:
+        parts = text.split(_THINKING_END_TAG, 1)
+        thinking_block = parts[0]
+        answer = _clean_answer(parts[1])
+
+        # Extract thinking content: strip the start tag if present
+        if _THINKING_START_TAG in thinking_block:
+            thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
+        else:
+            thinking = thinking_block
+
+        # Strip the "thought\n" channel role label the model emits inside
+        # <|channel>thought\n... (analogous to "user\n" in
+        # <|turn>user\n...).
+        thinking = _strip_thought_label(thinking.strip())
+        thinking = thinking.strip()
+
+        return {"thinking": thinking, "answer": answer}
+
+    # No thinking delimiters found.
+    # Strip the spurious "thought\n" role label that some Gemma4 models
+    # emit even without thinking mode enabled, then clean trailing tokens.
+    answer = _strip_thought_label(text)
+    answer = _clean_answer(answer)
+    return {"thinking": None, "answer": answer}
+
+
+def _strip_thought_label(text: str) -> str:
+    """Strip the spurious ``thought\\n`` label from the start of text.
+
+    Only strips when ``thought`` appears as the very first word followed by
+    a newline — preserving the word ``thought`` in any other context.
+    """
+    if text.startswith("thought\n"):
+        return text[len("thought\n") :]
+    return text
+
+
+def _clean_answer(text: str) -> str:
+    """Clean trailing sentinel tokens from the answer text.
+
+    Strips ``<|turn_end>``, ``<eos>``, and surrounding whitespace that the
+    model appends at the end of its response.
+    """
+    text = text.strip()
+    # Strip trailing <|turn_end> (Gemma4 turn-end marker)
+    if text.endswith(_TURN_END_TAG):
+        text = text[: -len(_TURN_END_TAG)].rstrip()
+    # Strip trailing <eos> if present
+    if text.endswith("<eos>"):
+        text = text[:-5].rstrip()
+    return text
+
+
+# ---- Tool Call Parsing Utility ----
+#
+# NOTE: For the OpenAI-compatible API server tool parser (streaming +
+# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
+# This module provides offline inference utilities for direct user import.
+
+# Tool call delimiter tokens as they appear in decoded text.
+# Standard format: <|tool_call>call:name{args}<|tool_call_end>
+_TOOL_CALL_START_TAG = "<|tool_call>"
+_TOOL_CALL_END_TAG = "<|tool_call_end>"
+_TOOL_RESPONSE_START_TAG = "<|tool_response>"
+
+# Gemma4 escape token as it appears in decoded text.
+_ESCAPE_TOKEN = '<|"|>'
+
+
+def _parse_tool_arguments(args_str: str) -> dict[str, str]:
+    """Parse tool call arguments from the Gemma4 compact format.
+
+    Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
+    to heuristic key-value extraction. Also tolerates the slightly different
+    ``key: "value"`` format (space + plain quotes) that some chat templates
+    produce.
+
+    Args:
+        args_str: Raw argument string from inside ``call:name{...}``.
+
+    Returns:
+        Dictionary of argument name → value.
+    """
+    if not args_str or not args_str.strip():
+        return {}
+
+    # Replace Gemma4 escape tokens with standard quotes.
+    cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
+
+    # Try JSON parsing first (handles nested values, arrays, etc.).
+    try:
+        parsed = json.loads("{" + cleaned + "}")
+        # Ensure all values are strings for consistency.
+        return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
+    except (json.JSONDecodeError, ValueError):
+        pass
+
+    # Fallback: extract key:"value" pairs (allow optional space after colon).
+    arguments = {}
+    for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
+        arguments[key] = value
+
+    if not arguments:
+        # Last resort: extract key:value pairs (unquoted).
+        for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
+            arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
+
+    return arguments
+
+
+def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
+    """Parse tool calls from decoded Gemma4 model output.
+
+    Uses a tiered parsing strategy to handle known output variations in
+    Gemma4 models, which may emit non-standard tool call formats.
+
+    Parsing tiers:
+    1. **Standard**: ``<|tool_call>call:name{args}<|tool_call_end>``
+       (special token IDs 48/49 in decoded text)
+    2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
+       patterns, including ``name{args}`` (fragmented tokens from
+       multimodal inputs)
+
+    Args:
+        text: Decoded model output text (from ``tokenizer.decode(...,
+            skip_special_tokens=False)``).
+        strict: If ``True``, only match the standard ``<|tool_call>`` format.
+            If ``False`` (default), also try fallback patterns for
+            known Gemma4 output variations.
+
+    Returns:
+        A list of dicts, each with keys:
+        - ``"name"``: The tool function name (e.g. ``"get_weather"``).
+ - ``"arguments"``: A dict of argument name → value. + + Example:: + + >>> from vllm.model_executor.models.gemma4_utils import ( + ... parse_tool_calls + ... ) + >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False) + >>> tool_calls = parse_tool_calls(output) + >>> for tc in tool_calls: + ... print(f"Call: {tc['name']}({tc['arguments']})") + """ + results = [] + + # Tier 1: Standard format with special tokens. + # <|tool_call>call:name{args} + # Note: Some Gemma4 models emit instead of . + standard_pattern = r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:|)" + for match in re.finditer(standard_pattern, text, re.DOTALL): + name, args_str = match.group(1), match.group(2) + results.append( + { + "name": name, + "arguments": _parse_tool_arguments(args_str), + } + ) + + if results or strict: + return results + + # Tier 2: Fallback for known Gemma4 output variations. + # Matches: name{args}, call:name{args}, or bare call:name{args} + fallback_pattern = r"(?:|(?:^|\s)call:)(\w+)\{(.*?)\}" + for match in re.finditer(fallback_pattern, text, re.DOTALL): + name, args_str = match.group(1), match.group(2) + results.append( + { + "name": name, + "arguments": _parse_tool_arguments(args_str), + } + ) + + return results + + +def has_tool_response_tag(text: str) -> bool: + """Check if model output properly ends with a tool response tag. + + Some Gemma4 models sometimes emit ```` instead of + ``<|tool_response>`` after a tool call. This helper detects + whether the model used the proper termination, so callers can + decide whether to inject ``<|tool_response>`` into the next prompt. + + Args: + text: Decoded model output text. + + Returns: + ``True`` if the output ends with ``<|tool_response>`` + (proper behavior), ``False`` otherwise. + + Example:: + + >>> from vllm.model_executor.models.gemma4_utils import ( + ... has_tool_response_tag + ... ) + >>> if not has_tool_response_tag(model_output): + ... # Model used instead — inject <|tool_response> manually + ... next_prompt = "<|tool_response>" + tool_result + """ + stripped = text.rstrip() + return stripped.endswith(_TOOL_RESPONSE_START_TAG) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 2c72c5d68..0ed348adf 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = { "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"), "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"), + "Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"), "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"), "GlmForCausalLM": ("glm", "GlmForCausalLM"), "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"), @@ -381,6 +382,7 @@ _MULTIMODAL_MODELS = { "gemma3n_mm", "Gemma3nForConditionalGeneration", ), + "Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"), "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"), "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"), "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"), diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 8abaa557f..4489e0c08 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -233,8 +233,15 @@ class AutoWeightsLoader: ): """ Add tensor names that are not in the model params that may be in the - safetensors, e.g., batch normalization stats. 
+        safetensors, e.g., batch normalization stats and registered buffers.
         """
+        # Add persistent registered buffers.
+        # Non-persistent buffers are excluded, matching PyTorch state_dict().
+        non_persistent = getattr(module, "_non_persistent_buffers_set", set())
+        for buf_name, buf in module.named_buffers(recurse=False):
+            if buf_name not in child_params and buf_name not in non_persistent:
+                child_params[buf_name] = buf
+
         if isinstance(
             module,
             (
diff --git a/vllm/reasoning/__init__.py b/vllm/reasoning/__init__.py
index 8c78db6f1..2d57b9336 100644
--- a/vllm/reasoning/__init__.py
+++ b/vllm/reasoning/__init__.py
@@ -32,6 +32,10 @@ _REASONING_PARSERS_TO_REGISTER = {
         "ernie45_reasoning_parser",
         "Ernie45ReasoningParser",
     ),
+    "gemma4": (
+        "gemma4_reasoning_parser",
+        "Gemma4ReasoningParser",
+    ),
     "glm45": (
         "deepseek_v3_reasoning_parser",
         "DeepSeekV3ReasoningWithThinkingParser",
diff --git a/vllm/reasoning/gemma4_reasoning_parser.py b/vllm/reasoning/gemma4_reasoning_parser.py
new file mode 100644
index 000000000..efcdcca23
--- /dev/null
+++ b/vllm/reasoning/gemma4_reasoning_parser.py
@@ -0,0 +1,193 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import TYPE_CHECKING
+
+from vllm.entrypoints.openai.engine.protocol import DeltaMessage
+from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+from vllm.tokenizers import TokenizerLike
+
+if TYPE_CHECKING:
+    from vllm.entrypoints.openai.chat_completion.protocol import (
+        ChatCompletionRequest,
+    )
+    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest
+
+# Role label that Gemma4 emits at the start of the thinking channel.
+# The model generates: <|channel>thought\n...reasoning...<|channel_end>
+# This prefix must be stripped to expose only the actual reasoning content.
+_THOUGHT_PREFIX = "thought\n"
+
+
+class Gemma4ReasoningParser(BaseThinkingReasoningParser):
+    """
+    Reasoning parser for Google Gemma4 thinking models.
+
+    Gemma4 uses <|channel>...<|channel_end> tokens to delimit
+    reasoning/thinking content within its output. Thinking mode is
+    activated by passing ``enable_thinking=True`` in the chat template
+    kwargs, which injects a system turn containing <|think|> (token 98)
+    to trigger chain-of-thought reasoning.
+
+    Output pattern when thinking is enabled::
+
+        <|channel>thought
+        ...chain of thought reasoning...
+        <|channel_end>
+        Final answer text here.
+
+    The ``thought\\n`` role label inside the channel delimiters is a
+    structural artefact (analogous to ``user\\n`` in ``<|turn>user\\n...``).
+    This parser strips it so that downstream consumers see only the
+    actual reasoning text, consistent with the offline parser
+    (``vllm.reasoning.gemma4_utils._strip_thought_label``).
+    """
+
+    def __init__(self, tokenizer: TokenizerLike, *args, **kwargs):
+        super().__init__(tokenizer, *args, **kwargs)
+        # Instance state for streaming prefix stripping.
+        # Tracks only the reasoning text received from the base parser,
+        # independent of current_text (which may contain pre-reasoning
+        # content and lacks special token text due to
+        # skip_special_tokens=True).
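+        # _prefix_stripped flips to True once the "thought\n" prefix has
+        # been fully consumed or ruled out for this request.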
+        self._reasoning_text: str = ""
+        self._prefix_stripped: bool = False
+
+    @property
+    def start_token(self) -> str:
+        """The token that starts reasoning content."""
+        return "<|channel>"
+
+    @property
+    def end_token(self) -> str:
+        """The token that ends reasoning content."""
+        return "<|channel_end>"
+
+    # ------------------------------------------------------------------
+    # Non-streaming path
+    # ------------------------------------------------------------------
+
+    def extract_reasoning(
+        self,
+        model_output: str,
+        request: "ChatCompletionRequest | ResponsesRequest",
+    ) -> tuple[str | None, str | None]:
+        """Extract reasoning, stripping the ``thought\\n`` role label."""
+        if self.start_token not in model_output and self.end_token not in model_output:
+            # Treat the whole output as content if no tags are present
+            # (or if they were stripped).
+            return None, model_output
+
+        reasoning, content = super().extract_reasoning(model_output, request)
+        if reasoning is not None:
+            reasoning = _strip_thought_label(reasoning)
+        return reasoning, content
+
+    # ------------------------------------------------------------------
+    # Streaming path
+    # ------------------------------------------------------------------
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        """Extract streaming reasoning, stripping ``thought\\n`` from the
+        first reasoning delta(s).
+
+        The ``thought\\n`` prefix may arrive as a single delta or split
+        across multiple deltas (e.g. ``"thought"`` then ``"\\n"``). We
+        buffer early reasoning tokens until we can determine whether the
+        prefix is present, then emit the buffered content minus the
+        prefix.
+
+        Unlike a naive implementation that reconstructs accumulated
+        reasoning from ``current_text``, this uses instance state
+        (``_reasoning_text``) to track only the reasoning content returned
+        by the base parser. This is necessary because
+        ``skip_special_tokens=True`` (the vLLM default) causes the
+        ``<|channel>`` delimiter to be invisible in ``current_text``,
+        making it impossible to separate pre-reasoning content from
+        reasoning content via string matching.
+        """
+        result = super().extract_reasoning_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+        )
+        if result is None:
+            return None
+
+        if result.reasoning is None:
+            return result
+
+        # Accumulate ONLY the reasoning text from base parser results.
+        # This is immune to pre-reasoning content pollution.
+        self._reasoning_text += result.reasoning
+
+        # Once the prefix has been handled, all subsequent reasoning
+        # deltas pass through unchanged.
+        if self._prefix_stripped:
+            return result
+
+        # ---- Prefix stripping logic ----
+
+        # Case 1: We've accumulated enough to confirm the prefix is
+        # present. Strip it and pass through the remainder.
+        if self._reasoning_text.startswith(_THOUGHT_PREFIX):
+            prefix_len = len(_THOUGHT_PREFIX)
+            # How much reasoning was accumulated before this delta?
+            prev_reasoning_len = len(self._reasoning_text) - len(result.reasoning)
+            if prev_reasoning_len >= prefix_len:
+                # Prefix was already consumed by prior deltas; this
+                # delta is entirely real content — pass through.
+                self._prefix_stripped = True
+                return result
+            else:
+                # Part or all of the prefix is in this delta.
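+                # e.g. with the 8-char prefix "thought\n": if 5 chars
+                # arrived in earlier deltas, the first 3 chars of this
+                # delta still belong to the prefix and must be dropped.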
+                chars_of_prefix_in_delta = prefix_len - prev_reasoning_len
+                stripped = result.reasoning[chars_of_prefix_in_delta:]
+                if stripped:
+                    self._prefix_stripped = True
+                    result.reasoning = stripped
+                    return result
+                else:
+                    # This entire delta was prefix — suppress it.
+                    # Don't set _prefix_stripped yet; there may be more
+                    # prefix chars to consume in the next delta.
+                    if len(self._reasoning_text) >= prefix_len:
+                        self._prefix_stripped = True
+                    return None
+
+        # Case 2: Accumulated text is a strict prefix of
+        # _THOUGHT_PREFIX (e.g. we've only seen "thou" so far).
+        # Buffer by suppressing — we can't yet tell if this will
+        # become the full prefix or diverge.
+        if _THOUGHT_PREFIX.startswith(self._reasoning_text):
+            return None
+
+        # Case 3: Accumulated text doesn't match the thought prefix
+        # at all. This means prior deltas were buffered (suppressed
+        # by Case 2) but the text diverged. Re-emit the full
+        # accumulated text to avoid data loss.
+        self._prefix_stripped = True
+        result.reasoning = self._reasoning_text
+        return result
+
+
+def _strip_thought_label(text: str) -> str:
+    """Remove the ``thought\\n`` role label from the beginning of text.
+
+    Mirrors ``vllm.reasoning.gemma4_utils._strip_thought_label`` from the
+    offline parser.
+    """
+    if text.startswith(_THOUGHT_PREFIX):
+        return text[len(_THOUGHT_PREFIX) :]
+    return text
diff --git a/vllm/reasoning/gemma4_utils.py b/vllm/reasoning/gemma4_utils.py
new file mode 100644
index 000000000..9cdac7203
--- /dev/null
+++ b/vllm/reasoning/gemma4_utils.py
@@ -0,0 +1,130 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
+
+"""Gemma4 thinking/reasoning output parsing utilities for offline inference.
+
+Standalone functions that parse decoded model text to extract structured
+thinking content from Gemma4 models. These are pure-Python utilities with
+zero heavy dependencies — they work on raw decoded strings from any
+inference backend (vLLM, HuggingFace, TGI, etc.).
+
+For the OpenAI-compatible API reasoning parser (streaming +
+non-streaming), see ``vllm.reasoning.gemma4_reasoning_parser``.
+For tool call parsing, see ``vllm.tool_parsers.gemma4_utils``.
+
+Usage with vLLM offline inference::
+
+    from vllm import LLM, SamplingParams
+    from vllm.reasoning.gemma4_utils import parse_thinking_output
+
+    llm = LLM(model="google/gemma-4-it")
+    outputs = llm.generate(prompt, SamplingParams(...))
+    text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
+
+    # Extract thinking / answer (works with or without enable_thinking)
+    result = parse_thinking_output(text)
+    print(result["thinking"])  # chain-of-thought or None
+    print(result["answer"])    # final answer
+
+Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
+do not need a transformers dependency for output parsing.
+"""
+
+# ---- Thinking Mode Utility ----
+
+# Thinking delimiter tokens as they appear in decoded text.
+# Gemma4 uses <|channel> (start) and <|channel_end> (end) as thinking
+# delimiters.
+_THINKING_START_TAG = "<|channel>"
+_THINKING_END_TAG = "<|channel_end>"
+
+# Sentinel tokens that may appear in decoded output.
+_TURN_END_TAG = "<|turn_end>"
+
+
+def parse_thinking_output(text: str) -> dict[str, str | None]:
+    """Parse decoded Gemma4 model output.
+
+    Use this on **all** Gemma4 output regardless of whether thinking mode
+    was enabled. It handles three cases:
+
+    1. **Thinking enabled, tags present** — splits on ``<|channel>``/
+       ``<|channel_end>`` to separate chain-of-thought from the answer and
+       strips the ``thought\\n`` role label.
+    2. **Thinking disabled, spurious label** — strips the bare
+       ``thought\\n`` prefix that some Gemma4 models emit even
+       without thinking mode.
+    3. **Clean output** — returns the text unchanged.
+
+    The answer text is always cleaned of trailing sentinel tokens
+    (``<|turn_end>``, ``<eos>``, etc.).
+
+    Args:
+        text: Decoded model output text (from ``tokenizer.decode(...)``).
+
+    Returns:
+        A dict with keys:
+        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
+          thinking delimiters were found.
+        - ``"answer"``: The final answer text.
+
+    Example::
+
+        >>> from vllm.reasoning.gemma4_utils import parse_thinking_output
+        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
+        >>> result = parse_thinking_output(output_text)
+        >>> print(result["thinking"])  # chain-of-thought reasoning or None
+        >>> print(result["answer"])   # final answer
+    """
+    if _THINKING_END_TAG in text:
+        parts = text.split(_THINKING_END_TAG, 1)
+        thinking_block = parts[0]
+        answer = _clean_answer(parts[1])
+
+        # Extract thinking content: strip the start tag if present
+        if _THINKING_START_TAG in thinking_block:
+            thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
+        else:
+            thinking = thinking_block
+
+        # Strip the "thought\n" channel role label the model emits inside
+        # <|channel>thought\n... (analogous to "user\n" in
+        # <|turn>user\n...).
+        thinking = _strip_thought_label(thinking.strip())
+        thinking = thinking.strip()
+
+        return {"thinking": thinking, "answer": answer}
+
+    # No thinking delimiters found.
+    # Strip the spurious "thought\n" role label that some Gemma4 models
+    # emit even without thinking mode enabled, then clean trailing tokens.
+    answer = _strip_thought_label(text)
+    answer = _clean_answer(answer)
+    return {"thinking": None, "answer": answer}
+
+
+def _strip_thought_label(text: str) -> str:
+    """Strip the spurious ``thought\\n`` label from the start of text.
+
+    Only strips when ``thought`` appears as the very first word followed by
+    a newline — preserving the word ``thought`` in any other context.
+    """
+    if text.startswith("thought\n"):
+        return text[len("thought\n") :]
+    return text
+
+
+def _clean_answer(text: str) -> str:
+    """Clean trailing sentinel tokens from the answer text.
+
+    Strips ``<|turn_end>``, ``<eos>``, and surrounding whitespace that the
+    model appends at the end of its response.
+    """
+    text = text.strip()
+    # Strip trailing <|turn_end> (Gemma4 turn-end marker)
+    if text.endswith(_TURN_END_TAG):
+        text = text[: -len(_TURN_END_TAG)].rstrip()
+    # Strip trailing <eos> if present
+    if text.endswith("<eos>"):
+        text = text[:-5].rstrip()
+    return text
diff --git a/vllm/tool_parsers/__init__.py b/vllm/tool_parsers/__init__.py
index f480a635c..bffa00c4e 100644
--- a/vllm/tool_parsers/__init__.py
+++ b/vllm/tool_parsers/__init__.py
@@ -154,6 +154,10 @@ _TOOL_PARSERS_TO_REGISTER = {
         "functiongemma_tool_parser",
         "FunctionGemmaToolParser",
     ),
+    "gemma4": (
+        "gemma4_tool_parser",
+        "Gemma4ToolParser",
+    ),
 }
diff --git a/vllm/tool_parsers/gemma4_tool_parser.py b/vllm/tool_parsers/gemma4_tool_parser.py
new file mode 100644
index 000000000..c3d29f0ab
--- /dev/null
+++ b/vllm/tool_parsers/gemma4_tool_parser.py
@@ -0,0 +1,724 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Tool call parser for Google Gemma4 models.
+
+Gemma4 uses a custom serialization format (not JSON) for tool calls::
+
+    <|tool_call>call:func_name{key:<|"|>value<|"|>,num:42}<|tool_call_end>
+
+Strings are delimited by ``<|"|>`` (token 52), keys are unquoted, and
+multiple tool calls are concatenated without separators.
+
+Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` are set.
+
+For offline inference tool call parsing (direct ``tokenizer.decode()`` output),
+see ``vllm.tool_parsers.gemma4_utils.parse_tool_calls``.
+"""
+
+import json
+from collections.abc import Sequence
+
+import regex as re
+
+from vllm.entrypoints.chat_utils import make_tool_call_id
+from vllm.entrypoints.openai.chat_completion.protocol import (
+    ChatCompletionRequest,
+)
+from vllm.entrypoints.openai.engine.protocol import (
+    DeltaFunctionCall,
+    DeltaMessage,
+    DeltaToolCall,
+    ExtractedToolCallInformation,
+    FunctionCall,
+    ToolCall,
+)
+from vllm.entrypoints.openai.responses.protocol import (
+    ResponsesRequest,
+)
+from vllm.logger import init_logger
+from vllm.tokenizers import TokenizerLike
+from vllm.tool_parsers.abstract_tool_parser import ToolParser
+from vllm.tool_parsers.utils import find_common_prefix
+
+logger = init_logger(__name__)
+
+# Gemma4 special tokens for tool calls
+TOOL_CALL_START = "<|tool_call>"
+TOOL_CALL_END = "<|tool_call_end>"
+STRING_DELIM = '<|"|>'
+
+
+# ---------------------------------------------------------------------------
+# Gemma4 argument parser (used by both streaming and non-streaming paths)
+# ---------------------------------------------------------------------------
+
+
+def _parse_gemma4_value(value_str: str) -> object:
+    """Parse a single Gemma4 value (after key:) into a Python object."""
+    value_str = value_str.strip()
+    if not value_str:
+        return value_str
+
+    # Boolean
+    if value_str == "true":
+        return True
+    if value_str == "false":
+        return False
+
+    # Number (int or float)
+    try:
+        if "." in value_str:
+            return float(value_str)
+        return int(value_str)
+    except ValueError:
+        pass
+
+    # Bare string (no <|"|> delimiters — shouldn't happen but be safe)
+    return value_str
+
+
+def _parse_gemma4_args(args_str: str) -> dict:
+    """Parse Gemma4's custom key:value format into a Python dict.
+
+    Format examples::
+
+        location:<|"|>Tokyo<|"|>
+        location:<|"|>San Francisco<|"|>,unit:<|"|>celsius<|"|>
+        count:42,flag:true
+        nested:{inner_key:<|"|>val<|"|>}
+        items:[<|"|>a<|"|>,<|"|>b<|"|>]
+
+    Returns a dict ready for ``json.dumps()``.
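+
+    Example (a quick sanity check of the format above)::
+
+        >>> _parse_gemma4_args('location:<|"|>Tokyo<|"|>,count:42')
+        {'location': 'Tokyo', 'count': 42}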
+ """ + if not args_str or not args_str.strip(): + return {} + + result: dict = {} + i = 0 + n = len(args_str) + + while i < n: + # Skip whitespace and commas + while i < n and args_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # Parse key (unquoted, ends at ':') + key_start = i + while i < n and args_str[i] != ":": + i += 1 + if i >= n: + break + key = args_str[key_start:i].strip() + i += 1 # skip ':' + + # Parse value + if i >= n: + result[key] = "" + break + + # Skip whitespace after ':' + while i < n and args_str[i] in (" ", "\n", "\t"): + i += 1 + if i >= n: + result[key] = "" + break + + # String value: <|"|>...<|"|> + if args_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + val_start = i + end_pos = args_str.find(STRING_DELIM, i) + if end_pos == -1: + # Unterminated string — take rest + result[key] = args_str[val_start:] + break + result[key] = args_str[val_start:end_pos] + i = end_pos + len(STRING_DELIM) + + # Nested object: {...} + elif args_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i:].startswith(STRING_DELIM): + # Skip over string contents to avoid counting { inside strings + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + i = n if next_delim == -1 else next_delim + len(STRING_DELIM) + continue + if args_str[i] == "{": + depth += 1 + elif args_str[i] == "}": + depth -= 1 + i += 1 + result[key] = _parse_gemma4_args(args_str[obj_start : i - 1]) + + # Array: [...] + elif args_str[i] == "[": + depth = 1 + arr_start = i + 1 + i += 1 + while i < n and depth > 0: + if args_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + next_delim = args_str.find(STRING_DELIM, i) + i = n if next_delim == -1 else next_delim + len(STRING_DELIM) + continue + if args_str[i] == "[": + depth += 1 + elif args_str[i] == "]": + depth -= 1 + i += 1 + arr_content = args_str[arr_start : i - 1] + result[key] = _parse_gemma4_array(arr_content) + + # Bare value (number, boolean, etc.) 
+ else: + val_start = i + while i < n and args_str[i] not in (",", "}", "]"): + i += 1 + result[key] = _parse_gemma4_value(args_str[val_start:i]) + + return result + + +def _parse_gemma4_array(arr_str: str) -> list: + """Parse a Gemma4 array content string into a Python list.""" + items: list = [] + i = 0 + n = len(arr_str) + + while i < n: + while i < n and arr_str[i] in (" ", ",", "\n", "\t"): + i += 1 + if i >= n: + break + + # String element + if arr_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + end_pos = arr_str.find(STRING_DELIM, i) + if end_pos == -1: + items.append(arr_str[i:]) + break + items.append(arr_str[i:end_pos]) + i = end_pos + len(STRING_DELIM) + + # Nested object + elif arr_str[i] == "{": + depth = 1 + obj_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i:].startswith(STRING_DELIM): + i += len(STRING_DELIM) + nd = arr_str.find(STRING_DELIM, i) + i = nd + len(STRING_DELIM) if nd != -1 else n + continue + if arr_str[i] == "{": + depth += 1 + elif arr_str[i] == "}": + depth -= 1 + i += 1 + items.append(_parse_gemma4_args(arr_str[obj_start : i - 1])) + + # Nested array + elif arr_str[i] == "[": + depth = 1 + sub_start = i + 1 + i += 1 + while i < n and depth > 0: + if arr_str[i] == "[": + depth += 1 + elif arr_str[i] == "]": + depth -= 1 + i += 1 + items.append(_parse_gemma4_array(arr_str[sub_start : i - 1])) + + # Bare value + else: + val_start = i + while i < n and arr_str[i] not in (",", "]"): + i += 1 + items.append(_parse_gemma4_value(arr_str[val_start:i])) + + return items + + +# --------------------------------------------------------------------------- +# Parser +# --------------------------------------------------------------------------- + + +class Gemma4ToolParser(ToolParser): + """ + Tool call parser for Google Gemma4 models. + + Handles the Gemma4 function call format:: + + <|tool_call>call:func_name{key:<|"|>value<|"|>} + + Used when ``--enable-auto-tool-choice --tool-call-parser gemma4`` + are set. + + Streaming strategy: **accumulate-then-parse-then-diff** + + Instead of trying to convert Gemma4's custom format to JSON + token-by-token (which fails because Gemma4 uses bare keys, custom + delimiters, and structural braces that differ from JSON), this parser: + + 1. Accumulates the raw Gemma4 argument string during streaming + 2. Parses it with ``_parse_gemma4_args()`` into a Python dict + 3. Converts to JSON with ``json.dumps()`` + 4. Diffs against the previously-streamed JSON string + 5. Emits only the new JSON fragment as the delta + + This follows the same pattern used by FunctionGemma, Hermes, and Llama + tool parsers. + """ + + def __init__(self, tokenizer: TokenizerLike): + super().__init__(tokenizer) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction." + ) + + # Token strings + self.tool_call_start_token = TOOL_CALL_START + self.tool_call_end_token = TOOL_CALL_END + + # Token IDs + self.tool_call_start_token_id = self.vocab.get(TOOL_CALL_START) + self.tool_call_end_token_id = self.vocab.get(TOOL_CALL_END) + + if self.tool_call_start_token_id is None: + raise RuntimeError( + "Gemma4 ToolParser could not locate the tool call start " + f"token '{TOOL_CALL_START}' in the tokenizer!" + ) + + # Regex for non-streaming: extract complete tool calls. + # Supports function names with letters, digits, underscores, + # hyphens, and dots (e.g. "get-weather", "module.func"). 
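+        # Note: the non-greedy `\{(.*?)\}` stops at the first '}', so an
+        # argument payload containing nested objects is truncated by this
+        # regex; flat argument dicts (the common case) parse fully.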
+ self.tool_call_regex = re.compile( + r"<\|tool_call>call:([\w\-\.]+)\{(.*?)\}", + re.DOTALL, + ) + + # Streaming state — reset per-request via _reset_streaming_state() + self._reset_streaming_state() + + # Delta buffer for handling multi-token special sequences + self.buffered_delta_text = "" + + def _reset_streaming_state(self) -> None: + """Reset all streaming state for a new request.""" + self.current_tool_id = -1 + self.current_tool_name_sent = False + self.prev_tool_call_arr: list[dict] = [] + self.streamed_args_for_tool: list[str] = [] + + def adjust_request( + self, request: ChatCompletionRequest | ResponsesRequest + ) -> ChatCompletionRequest | ResponsesRequest: + request = super().adjust_request(request) + if ( + isinstance(request, ChatCompletionRequest) + and request.tools + and request.tool_choice != "none" + ): + # Don't skip special tokens — <|tool_call> etc. are needed + request.skip_special_tokens = False + return request + + # ------------------------------------------------------------------ + # Delta buffering for multi-token special sequences + # ------------------------------------------------------------------ + + def _buffer_delta_text(self, delta_text: str) -> str: + """Buffer incoming delta text to handle multi-token special sequences. + + Accumulates partial tokens that could be the start of + ``<|tool_call>`` or ```` and only flushes them + when the complete sequence is recognized or the sequence breaks. + + This prevents partial special tokens (e.g., ``<|tool``) from being + emitted prematurely as content text. + """ + combined = self.buffered_delta_text + delta_text + + # Check if combined ends with a complete special token + if combined.endswith(TOOL_CALL_START) or combined.endswith(TOOL_CALL_END): + self.buffered_delta_text = "" + return combined + + # Check if combined ends with a partial prefix of a special token + for tag in [TOOL_CALL_START, TOOL_CALL_END]: + for i in range(1, len(tag)): + if combined.endswith(tag[:i]): + self.buffered_delta_text = combined[-i:] + return combined[:-i] + + # No partial match — flush everything + self.buffered_delta_text = "" + return combined + + # ------------------------------------------------------------------ + # Non-streaming extraction + # ------------------------------------------------------------------ + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + matches = self.tool_call_regex.findall(model_output) + if not matches: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + tool_calls: list[ToolCall] = [] + for func_name, args_str in matches: + arguments = _parse_gemma4_args(args_str) + tool_calls.append( + ToolCall( + type="function", + function=FunctionCall( + name=func_name, + arguments=json.dumps(arguments, ensure_ascii=False), + ), + ) + ) + + # Content = text before first tool call (if any) + content_end = model_output.find(self.tool_call_start_token) + content = model_output[:content_end].strip() if content_end > 0 else None + + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None, + ) + + except Exception: + logger.exception("Error extracting tool calls from Gemma4 response") + return ExtractedToolCallInformation( + tools_called=False, 
tool_calls=[], content=model_output + ) + + # ------------------------------------------------------------------ + # Streaming extraction — accumulate-then-parse-then-diff + # ------------------------------------------------------------------ + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> DeltaMessage | None: + # Buffer delta text to handle multi-token special sequences + delta_text = self._buffer_delta_text(delta_text) + # Reconstruct current_text after buffering to stay in sync + current_text = previous_text + delta_text + + # If no tool call token seen yet, emit as content + if self.tool_call_start_token not in current_text: + if delta_text: + return DeltaMessage(content=delta_text) + return None + + try: + return self._extract_streaming( + previous_text=previous_text, + current_text=current_text, + delta_text=delta_text, + ) + except Exception: + logger.exception("Error in Gemma4 streaming tool call extraction") + return None + + def _extract_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + ) -> DeltaMessage | None: + """Tag-counting streaming parser. + + Uses the proven approach from FunctionGemma/Hermes: count start/end + tags in previous vs current text to determine phase, then + accumulate-parse-diff for arguments. + + Format: ``<|tool_call>call:name{args}`` + """ + start_count = current_text.count(self.tool_call_start_token) + end_count = current_text.count(self.tool_call_end_token) + prev_start_count = previous_text.count(self.tool_call_start_token) + prev_end_count = previous_text.count(self.tool_call_end_token) + + # Case 1: Not inside any tool call — emit as content + if ( + start_count == end_count + and prev_end_count == end_count + and self.tool_call_end_token not in delta_text + ): + if delta_text: + return DeltaMessage(content=delta_text) + return None + + # Case 2: Starting a new tool call + if start_count > prev_start_count and start_count > end_count: + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + self.prev_tool_call_arr.append({}) + logger.debug("Starting new tool call %d", self.current_tool_id) + # Don't return yet — fall through to try parsing if there's + # content after <|tool_call> in this same delta + # (but usually it's just the token itself, so return None) + if len(delta_text) <= len(self.tool_call_start_token): + return None + + # Case 3: Tool call just ended + if end_count > prev_end_count: + return self._handle_tool_call_end(current_text) + + # Case 4: In the middle of a tool call — parse partial content + if start_count > end_count: + return self._handle_tool_call_middle(current_text) + + # Default: generate text outside tool calls + if delta_text: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + if text: + return DeltaMessage(content=text) + return None + + def _extract_partial_call(self, current_text: str) -> tuple[str | None, str]: + """Extract function name and raw argument string from partial text. + + Returns (func_name, raw_args_str) or (None, "") if not parseable yet. 
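+
+        Example (hypothetical partial stream)::
+
+            >>> parser._extract_partial_call(
+            ...     '<|tool_call>call:get_weather{location:<|"|>Tok'
+            ... )
+            ('get_weather', 'location:<|"|>Tok')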
+ """ + # Get the text after the last <|tool_call> token + last_start = current_text.rfind(self.tool_call_start_token) + if last_start == -1: + return None, "" + + partial_call = current_text[last_start + len(self.tool_call_start_token) :] + + # Strip end token if present + if self.tool_call_end_token in partial_call: + partial_call = partial_call.split(self.tool_call_end_token)[0] + + # Expect "call:name{args...}" or "call:name{args...}" + if not partial_call.startswith("call:"): + return None, "" + + func_part = partial_call[5:] # skip "call:" + + if "{" not in func_part: + # Still accumulating function name, not ready yet + return None, "" + + func_name, _, args_part = func_part.partition("{") + func_name = func_name.strip() + + # Strip trailing '}' if present (Gemma4 structural brace) + if args_part.endswith("}"): + args_part = args_part[:-1] + + return func_name, args_part + + def _handle_tool_call_middle(self, current_text: str) -> DeltaMessage | None: + """Handle streaming when we're inside an active tool call. + + Accumulates the raw Gemma4 arguments, parses them into JSON, and + diffs against the previously-streamed JSON to emit only the new + fragment. + """ + func_name, args_part = self._extract_partial_call(current_text) + + if func_name is None: + return None + + # Step 1: Send function name (once) + if not self.current_tool_name_sent and func_name: + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = { + "name": func_name, + "arguments": {}, + } + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + type="function", + id=make_tool_call_id(), + function=DeltaFunctionCall( + name=func_name, + arguments="", + ).model_dump(exclude_none=True), + ) + ] + ) + + # Step 2: Parse and diff arguments + if self.current_tool_name_sent and args_part: + return self._emit_argument_diff(args_part) + + return None + + def _handle_tool_call_end(self, current_text: str) -> DeltaMessage | None: + """Handle streaming when a tool call has just completed. + + Performs a final parse of the complete tool call and flushes + any remaining un-streamed argument fragments. + """ + if self.current_tool_id < 0 or self.current_tool_id >= len( + self.prev_tool_call_arr + ): + logger.debug( + "Tool call end detected but no active tool call (current_tool_id=%d)", + self.current_tool_id, + ) + return None + + # Parse the complete tool call using regex for accuracy + all_matches = self.tool_call_regex.findall(current_text) + if self.current_tool_id < len(all_matches): + _, args_str = all_matches[self.current_tool_id] + final_args = _parse_gemma4_args(args_str) + final_args_json = json.dumps(final_args, ensure_ascii=False) + + prev_streamed = self.streamed_args_for_tool[self.current_tool_id] + if len(final_args_json) > len(prev_streamed): + diff = final_args_json[len(prev_streamed) :] + self.streamed_args_for_tool[self.current_tool_id] = final_args_json + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = final_args + + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) + + return None + + def _emit_argument_diff(self, raw_args_str: str) -> DeltaMessage | None: + """Parse raw Gemma4 arguments, convert to JSON, diff, and emit. + + This is the core of the accumulate-then-parse-then-diff strategy: + 1. Parse ``raw_args_str`` with ``_parse_gemma4_args()`` + 2. Convert to JSON string with ``json.dumps()`` + 3. 
Withhold trailing closing characters (``"}``) that may move + as more tokens arrive + 4. Diff against previously streamed JSON and emit only new chars + + **Why withholding is necessary:** + + Gemma4's custom format produces *structurally incomplete* JSON + during streaming. For example, when ``<|"|>Paris`` arrives + without a closing delimiter, ``_parse_gemma4_args`` treats it + as a complete value and produces ``{"location": "Paris"}``. But + when ``, France<|"|>`` arrives next, the JSON becomes + ``{"location": "Paris, France"}``. If we had sent the closing + ``"}`` from the first parse, the concatenated client output + would be ``{"location": "Paris"}France"}``, which is garbage. + + The solution: **never send trailing closing chars during + streaming**. They get flushed by ``_handle_tool_call_end()`` + when the ```` end marker arrives. + + Args: + raw_args_str: The raw Gemma4 argument text accumulated so far + (without the surrounding ``{`` ``}``). + + Returns: + DeltaMessage with the argument diff, or None if no new content. + """ + try: + current_args = _parse_gemma4_args(raw_args_str) + except Exception: + logger.debug( + "Could not parse partial Gemma4 args yet: %s", + raw_args_str[:100], + ) + return None + + if not current_args: + return None + + current_args_json = json.dumps(current_args, ensure_ascii=False) + + # Withhold trailing closing characters that may shift as more + # tokens arrive. Strip trailing '}', '"', and ']' sequences + # to get the "safe prefix". + safe_json = current_args_json + while safe_json and safe_json[-1] in ("}", '"', "]"): + safe_json = safe_json[:-1] + + prev_streamed = self.streamed_args_for_tool[self.current_tool_id] + + if not safe_json or safe_json == prev_streamed: + return None + + # Use find_common_prefix to handle cases where the value changed + # structurally (e.g., a string grew). + if prev_streamed: + prefix = find_common_prefix(prev_streamed, safe_json) + sent_len = len(prev_streamed) + prefix_len = len(prefix) + + if prefix_len < sent_len: + # Structure changed — we sent too much. Truncate our + # tracking to the common prefix and wait for the final + # flush in _handle_tool_call_end. + self.streamed_args_for_tool[self.current_tool_id] = prefix + return None + + # Stream the new stable portion + diff = safe_json[sent_len:] + else: + # First emission + diff = safe_json + + if diff: + self.streamed_args_for_tool[self.current_tool_id] = safe_json + self.prev_tool_call_arr[self.current_tool_id]["arguments"] = current_args + + return DeltaMessage( + tool_calls=[ + DeltaToolCall( + index=self.current_tool_id, + function=DeltaFunctionCall(arguments=diff).model_dump( + exclude_none=True + ), + ) + ] + ) + + return None diff --git a/vllm/tool_parsers/gemma4_utils.py b/vllm/tool_parsers/gemma4_utils.py new file mode 100644 index 000000000..439ad1125 --- /dev/null +++ b/vllm/tool_parsers/gemma4_utils.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved. + +"""Gemma4 tool call parsing utilities for offline inference. + +Standalone functions that parse decoded model text to extract tool calls +from Gemma4 models. These are pure-Python utilities with zero heavy +dependencies — they work on raw decoded strings from any inference +backend (vLLM, HuggingFace, TGI, etc.). + +For the OpenAI-compatible API server tool parser (streaming + +non-streaming), see ``vllm.tool_parsers.gemma4_tool_parser``. 
+For thinking/reasoning output parsing, see
+``vllm.reasoning.gemma4_utils``.
+
+Usage with vLLM offline inference::
+
+    from vllm import LLM, SamplingParams
+    from vllm.tool_parsers.gemma4_utils import (
+        parse_tool_calls,
+        has_tool_response_tag,
+    )
+
+    llm = LLM(model="google/gemma-4-it")
+    outputs = llm.generate(prompt, SamplingParams(...))
+    text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
+
+    # Extract tool calls
+    tool_calls = parse_tool_calls(text)
+    for tc in tool_calls:
+        print(f"{tc['name']}({tc['arguments']})")
+
+Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
+do not need a transformers dependency for output parsing.
+"""
+
+import json
+
+import regex as re
+
+# Tool call delimiter tokens as they appear in decoded text.
+# Standard format: <|tool_call>call:name{args}<|tool_call_end>
+_TOOL_CALL_START_TAG = "<|tool_call>"
+_TOOL_CALL_END_TAG = "<|tool_call_end>"
+_TOOL_RESPONSE_START_TAG = "<|tool_response>"
+
+# Gemma4 escape token as it appears in decoded text.
+_ESCAPE_TOKEN = '<|"|>'
+
+
+def _parse_tool_arguments(args_str: str) -> dict[str, str]:
+    """Parse tool call arguments from the Gemma4 compact format.
+
+    Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
+    to heuristic key-value extraction. Also tolerates the slightly different
+    ``key: "value"`` format (space + plain quotes) that some chat templates
+    produce.
+
+    Args:
+        args_str: Raw argument string from inside ``call:name{...}``.
+
+    Returns:
+        Dictionary of argument name → value.
+    """
+    if not args_str or not args_str.strip():
+        return {}
+
+    # Replace Gemma4 escape tokens with standard quotes.
+    cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
+
+    # Try JSON parsing first (handles nested values, arrays, etc.).
+    try:
+        parsed = json.loads("{" + cleaned + "}")
+        # Ensure all values are strings for consistency.
+        return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
+    except (json.JSONDecodeError, ValueError):
+        pass
+
+    # Fallback: extract key:"value" pairs (allow optional space after colon).
+    arguments = {}
+    for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
+        arguments[key] = value
+
+    if not arguments:
+        # Last resort: extract key:value pairs (unquoted).
+        for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
+            arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
+
+    return arguments
+
+
+def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
+    """Parse tool calls from decoded Gemma4 model output.
+
+    Uses a tiered parsing strategy to handle known output variations in
+    Gemma4 models, which may emit non-standard tool call formats.
+
+    Parsing tiers:
+    1. **Standard**: ``<|tool_call>call:name{args}<|tool_call_end>``
+       (special token IDs 48/49 in decoded text)
+    2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
+       patterns, including ``name{args}`` (fragmented tokens from
+       multimodal inputs)
+
+    Args:
+        text: Decoded model output text (from ``tokenizer.decode(...,
+            skip_special_tokens=False)``).
+        strict: If ``True``, only match the standard ``<|tool_call>`` format.
+            If ``False`` (default), also try fallback patterns for
+            known Gemma4 output variations.
+
+    Returns:
+        A list of dicts, each with keys:
+        - ``"name"``: The tool function name (e.g. ``"get_weather"``).
+        - ``"arguments"``: A dict of argument name → value.
+
+    Example::
+
+        >>> from vllm.tool_parsers.gemma4_utils import parse_tool_calls
+        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
+        >>> tool_calls = parse_tool_calls(output)
+        >>> for tc in tool_calls:
+        ...     print(f"Call: {tc['name']}({tc['arguments']})")
+    """
+    results = []
+
+    # Tier 1: Standard format with special tokens.
+    # <|tool_call>call:name{args}<|tool_call_end>
+    # Note: Some Gemma4 models emit <|turn_end> instead of <|tool_call_end>.
+    standard_pattern = r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<\|tool_call_end>|<\|turn_end>)?"
+    for match in re.finditer(standard_pattern, text, re.DOTALL):
+        name, args_str = match.group(1), match.group(2)
+        results.append(
+            {
+                "name": name,
+                "arguments": _parse_tool_arguments(args_str),
+            }
+        )
+
+    if results or strict:
+        return results
+
+    # Tier 2: Fallback for known Gemma4 output variations.
+    # Matches: <|tool_call>name{args}, <|tool_call>call:name{args}, or
+    # bare call:name{args}
+    fallback_pattern = r"(?:<\|tool_call\>|(?:^|\s)call:)(\w+)\{(.*?)\}"
+    for match in re.finditer(fallback_pattern, text, re.DOTALL):
+        name, args_str = match.group(1), match.group(2)
+        results.append(
+            {
+                "name": name,
+                "arguments": _parse_tool_arguments(args_str),
+            }
+        )
+
+    return results
+
+
+def has_tool_response_tag(text: str) -> bool:
+    """Check if model output properly ends with a tool response tag.
+
+    Some Gemma4 models emit ``<|turn_end>`` instead of
+    ``<|tool_response>`` after a tool call. This helper detects
+    whether the model used the proper termination, so callers can
+    decide whether to inject ``<|tool_response>`` into the next prompt.
+
+    Args:
+        text: Decoded model output text.
+
+    Returns:
+        ``True`` if the output ends with ``<|tool_response>``
+        (proper behavior), ``False`` otherwise.
+
+    Example::
+
+        >>> from vllm.tool_parsers.gemma4_utils import has_tool_response_tag
+        >>> if not has_tool_response_tag(model_output):
+        ...     # Model used <|turn_end> instead — inject <|tool_response>
+        ...     # manually
+        ...     next_prompt = "<|tool_response>" + tool_result
+    """
+    stripped = text.rstrip()
+    return stripped.endswith(_TOOL_RESPONSE_START_TAG)
diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py
index 3229539e3..4529baf29 100644
--- a/vllm/transformers_utils/model_arch_config_convertor.py
+++ b/vllm/transformers_utils/model_arch_config_convertor.py
@@ -448,6 +448,16 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
         return getattr(self.hf_text_config, "num_nextn_predict_layers", 1)
 
 
+class Gemma4ModelArchConfigConvertor(ModelArchConfigConvertorBase):
+    def get_head_size(self) -> int:
+        # Gemma4 uses dual head dimensions: head_dim (sliding attention)
+        # and global_head_dim (full attention). Return the larger so that
+        # attention backends allocate buffers large enough for both.
+        head_dim = getattr(self.hf_text_config, "head_dim", 0)
+        global_head_dim = getattr(self.hf_text_config, "global_head_dim", 0)
+        return max(head_dim, global_head_dim) or super().get_head_size()
+
+
 # hf_config.model_type -> convertor
 MODEL_ARCH_CONFIG_CONVERTORS = {
     "cohere_asr": CohereAsrModelArchConfigConvertor,
@@ -471,4 +481,6 @@ MODEL_ARCH_CONFIG_CONVERTORS = {
     "ernie_mtp": ErnieMTPModelArchConfigConvertor,
     "pangu_ultra_moe_mtp": PanguUltraMoeMTPModelArchConfigConvertor,
     "longcat_flash_mtp": LongCatFlashMTPModelArchConfigConvertor,
+    "gemma4": Gemma4ModelArchConfigConvertor,
+    "gemma4_text": Gemma4ModelArchConfigConvertor,
 }
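End-to-end flow for the new Gemma4 parsers above (a minimal sketch, not part of
the patch itself; it assumes the google/gemma-4-E2B-it checkpoint is available
and uses only APIs shown in this diff plus ``LLM.get_tokenizer()``)::

    from vllm import LLM, SamplingParams
    from vllm.reasoning.gemma4_utils import parse_thinking_output
    from vllm.tool_parsers.gemma4_utils import (
        has_tool_response_tag,
        parse_tool_calls,
    )

    llm = LLM(model="google/gemma-4-E2B-it")
    outputs = llm.generate(
        ["What's the weather in Tokyo?"],
        SamplingParams(max_tokens=256),
    )

    # Keep special tokens: the parsers key off <|channel>, <|tool_call>, etc.
    tokenizer = llm.get_tokenizer()
    text = tokenizer.decode(
        outputs[0].outputs[0].token_ids, skip_special_tokens=False
    )

    result = parse_thinking_output(text)
    print(result["thinking"])  # chain-of-thought or None
    print(result["answer"])    # final answer with sentinels stripped

    tool_calls = parse_tool_calls(text)
    for tc in tool_calls:
        print(f"{tc['name']}({tc['arguments']})")

    if tool_calls and not has_tool_response_tag(text):
        # The model skipped <|tool_response|>-style termination; inject the
        # tag manually before feeding the tool result back, as the utils
        # docstrings suggest.
        next_prompt = text + "<|tool_response>"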