feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Luciano Martins
2026-04-02 15:13:28 -03:00
committed by GitHub
parent ecd5443dbc
commit 08ed2b9688
20 changed files with 5051 additions and 1 deletions

View File

@@ -9,6 +9,7 @@ from vllm.utils.math_utils import round_up
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
logger = init_logger(__name__)
@@ -52,6 +53,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
hf_config.is_causal = not hf_config.use_bidirectional_attention
class Gemma4Config(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
"""Force unified attention backend for models with heterogeneous
head dimensions.
Some Gemma4 variants use different head dimensions for
sliding window (head_dim) vs full attention (global_head_dim) layers.
When global_head_dim > 256, FlashAttention rejects those layers
(head_size <= 256 kernel limit), causing vLLM to select a different
backend for each layer type. This mixed-backend execution produces
numerical divergence and output corruption.
The fix detects heterogeneous head dimensions from the model config
and forces TRITON_ATTN (which has no head_size ceiling) for all
layers when the user hasn't explicitly chosen a backend.
TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
require NixlConnector changes to support per-layer KV transfer
with different head dimensions for prefill-decode disaggregation.
"""
hf_text_config = vllm_config.model_config.hf_text_config
head_dim = getattr(hf_text_config, "head_dim", None)
global_head_dim = getattr(hf_text_config, "global_head_dim", None)
# Only force Triton when head dimensions actually differ AND the
# larger one exceeds FlashAttention's kernel limit (head_size <= 256).
# This avoids unnecessary backend forcing on smaller models where
# the config carries global_head_dim but all layers can still use
# the same FA backend.
max_head_dim = max(head_dim or 0, global_head_dim or 0)
if (
head_dim is not None
and global_head_dim is not None
and head_dim != global_head_dim
and max_head_dim > 256
and vllm_config.attention_config.backend is None
):
from vllm.v1.attention.backends.registry import (
AttentionBackendEnum,
)
vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
logger.info(
"Gemma4 model has heterogeneous head dimensions "
"(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
"backend to prevent mixed-backend numerical divergence.",
head_dim,
global_head_dim,
)
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -533,6 +586,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig,
"Gemma3TextModel": Gemma3TextModelConfig,
"Gemma4ForCausalLM": Gemma4Config,
"Gemma4ForConditionalGeneration": Gemma4Config,
"GptOssForCausalLM": GptOssForCausalLMConfig,
"GteModel": SnowflakeGteNewModelConfig,
"GteNewForSequenceClassification": GteNewModelConfig,

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
"""Gemma4 output parsing utilities for offline inference.
Standalone functions that parse decoded model text to extract structured
thinking content and tool calls from Gemma4 models. These are pure-Python
utilities with zero heavy dependencies — they work on raw decoded strings
from any inference backend (vLLM, HuggingFace, TGI, etc.).
Usage with vLLM offline inference::
from vllm import LLM, SamplingParams
from vllm.model_executor.models.gemma4_utils import (
parse_output,
parse_tool_calls,
)
llm = LLM(model="google/gemma-4-it")
outputs = llm.generate(prompt, SamplingParams(...))
text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)
# Extract thinking / answer (works with or without enable_thinking)
result = parse_output(text)
print(result["thinking"]) # chain-of-thought or None
print(result["answer"]) # final answer
# Extract tool calls
tool_calls = parse_tool_calls(text)
for tc in tool_calls:
print(f"{tc['name']}({tc['arguments']})")
Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""
import json
import regex as re
# ---- Thinking Mode Utility ----
# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"
# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"
def parse_thinking_output(text: str) -> dict[str, str | None]:
"""Parse decoded Gemma4 model output.
Use this on **all** Gemma4 output regardless of whether thinking mode
was enabled. It handles three cases:
1. **Thinking enabled, tags present** — splits on ``<|channel>``/
``<channel|>`` to separate chain-of-thought from the answer and
strips the ``thought\\n`` role label.
2. **Thinking disabled, spurious label** — strips the bare
``thought\\n`` prefix that some Gemma4 models emit even
without thinking mode.
3. **Clean output** — returns the text unchanged.
The answer text is always cleaned of trailing sentinel tokens
(``<turn|>``, ``<eos>``, etc.).
Args:
text: Decoded model output text (from ``tokenizer.decode(...)``).
Returns:
A dict with keys:
- ``"thinking"``: The chain-of-thought text, or ``None`` if no
thinking delimiters were found.
- ``"answer"``: The final answer text.
Example::
>>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
>>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
>>> result = parse_thinking_output(output_text)
>>> print(result["thinking"]) # chain-of-thought reasoning or None
>>> print(result["answer"]) # final answer
"""
if _THINKING_END_TAG in text:
parts = text.split(_THINKING_END_TAG, 1)
thinking_block = parts[0]
answer = _clean_answer(parts[1])
# Extract thinking content: strip the start tag if present
if _THINKING_START_TAG in thinking_block:
thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
else:
thinking = thinking_block
# Strip the "thought\n" channel role label the model emits inside
# <|channel>thought\n...<channel|> (analogous to "user\n" in
# <|turn>user\n...<turn|>).
thinking = _strip_thought_label(thinking.strip())
thinking = thinking.strip()
return {"thinking": thinking, "answer": answer}
# No thinking delimiters found.
# Strip spurious "thought\n" role label that some Gemma4 models sometimes
# emit even without thinking mode enabled, then clean trailing tokens.
answer = _strip_thought_label(text)
answer = _clean_answer(answer)
return {"thinking": None, "answer": answer}
def _strip_thought_label(text: str) -> str:
"""Strip the spurious ``thought\\n`` label from the start of text.
Only strips when ``thought`` appears as the very first word followed by
a newline — preserving the word ``thought`` in any other context.
"""
if text.startswith("thought\n"):
return text[len("thought\n") :]
return text
def _clean_answer(text: str) -> str:
"""Clean trailing sentinel tokens from the answer text.
Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
model appends at the end of its response.
"""
text = text.strip()
# Strip trailing <turn|> (Gemma4 turn-end marker)
if text.endswith(_TURN_END_TAG):
text = text[: -len(_TURN_END_TAG)].rstrip()
# Strip trailing <eos> if present
if text.endswith("<eos>"):
text = text[:-5].rstrip()
return text
# ---- Tool Call Parsing Utility ----
#
# NOTE: For the OpenAI-compatible API server tool parser (streaming +
# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
# This module provides offline inference utilities for direct user import.
# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
_TOOL_RESPONSE_START_TAG = "<|tool_response>"
# Gemma4 escape token as it appears in decoded text.
_ESCAPE_TOKEN = '<|"|>'
def _parse_tool_arguments(args_str: str) -> dict[str, str]:
"""Parse tool call arguments from the Gemma4 compact format.
Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
to heuristic key-value extraction. Also tolerates the slightly different
``key: "value"`` format (space + plain quotes) that some chat templates
produce.
Args:
args_str: Raw argument string from inside ``call:name{...}``.
Returns:
Dictionary of argument name → value.
"""
if not args_str or not args_str.strip():
return {}
# Replace Gemma4 escape tokens with standard quotes.
cleaned = args_str.replace(_ESCAPE_TOKEN, '"')
# Try JSON parsing first (handles nested values, arrays, etc.).
try:
parsed = json.loads("{" + cleaned + "}")
# Ensure all values are strings for consistency.
return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
except (json.JSONDecodeError, ValueError):
pass
# Fallback: extract key:"value" pairs (allow optional space after colon).
arguments = {}
for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
arguments[key] = value
if not arguments:
# Last resort: extract key:value pairs (unquoted).
for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")
return arguments
def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
"""Parse tool calls from decoded Gemma4 model output.
Uses a tiered parsing strategy to handle known output variations in
Gemma4 models, which may emit
non-standard tool call formats.
Parsing tiers:
1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
(special token IDs 48/49 in decoded text)
2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
patterns, including ``<call>name{args}`` (fragmented tokens from
multimodal inputs)
Args:
text: Decoded model output text (from ``tokenizer.decode(...,
skip_special_tokens=False)``).
strict: If ``True``, only match the standard ``<|tool_call>`` format.
If ``False`` (default), also try fallback patterns for
known Gemma4 output variations.
Returns:
A list of dicts, each with keys:
- ``"name"``: The tool function name (e.g. ``"get_weather"``).
- ``"arguments"``: A dict of argument name → value.
Example::
>>> from vllm.model_executor.models.gemma4_utils import (
... parse_tool_calls
... )
>>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
>>> tool_calls = parse_tool_calls(output)
>>> for tc in tool_calls:
... print(f"Call: {tc['name']}({tc['arguments']})")
"""
results = []
# Tier 1: Standard format with special tokens.
# <|tool_call>call:name{args}<tool_call|>
# Note: Some Gemma4 models emit <turn|> instead of <tool_call|>.
standard_pattern = r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
for match in re.finditer(standard_pattern, text, re.DOTALL):
name, args_str = match.group(1), match.group(2)
results.append(
{
"name": name,
"arguments": _parse_tool_arguments(args_str),
}
)
if results or strict:
return results
# Tier 2: Fallback for known Gemma4 output variations.
# Matches: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
fallback_pattern = r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}"
for match in re.finditer(fallback_pattern, text, re.DOTALL):
name, args_str = match.group(1), match.group(2)
results.append(
{
"name": name,
"arguments": _parse_tool_arguments(args_str),
}
)
return results
def has_tool_response_tag(text: str) -> bool:
"""Check if model output properly ends with a tool response tag.
Some Gemma4 models sometimes emit ``<eos>`` instead of
``<|tool_response>`` after a tool call. This helper detects
whether the model used the proper termination, so callers can
decide whether to inject ``<|tool_response>`` into the next prompt.
Args:
text: Decoded model output text.
Returns:
``True`` if the output ends with ``<|tool_response>``
(proper behavior), ``False`` otherwise.
Example::
>>> from vllm.model_executor.models.gemma4_utils import (
... has_tool_response_tag
... )
>>> if not has_tool_response_tag(model_output):
... # Model used <eos> instead — inject <|tool_response> manually
... next_prompt = "<|tool_response>" + tool_result
"""
stripped = text.rstrip()
return stripped.endswith(_TOOL_RESPONSE_START_TAG)

View File

@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
"Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
"Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
"Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
"Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
@@ -383,6 +384,7 @@ _MULTIMODAL_MODELS = {
"gemma3n_mm",
"Gemma3nForConditionalGeneration",
),
"Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
"GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
"GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
"Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),

View File

@@ -233,8 +233,15 @@ class AutoWeightsLoader:
):
"""
Add tensor names that are not in the model params that may be in the
safetensors, e.g., batch normalization stats.
safetensors, e.g., batch normalization stats and registered buffers.
"""
# Add persistent registered buffers.
# Non-persistent buffers are excluded, matching PyTorch state_dict().
non_persistent = getattr(module, "_non_persistent_buffers_set", set())
for buf_name, buf in module.named_buffers(recurse=False):
if buf_name not in child_params and buf_name not in non_persistent:
child_params[buf_name] = buf
if isinstance(
module,
(