feat(models): implement Google Gemma 4 architecture support (MoE, Multimodal, Reasoning, Tool-Use) (#38826)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Luciano Martins <lucianomartins@google.com>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
@@ -9,6 +9,7 @@ from vllm.utils.math_utils import round_up

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig


logger = init_logger(__name__)
@@ -52,6 +53,58 @@ class Gemma3TextModelConfig(VerifyAndUpdateConfig):
        hf_config.is_causal = not hf_config.use_bidirectional_attention


class Gemma4Config(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        """Force unified attention backend for models with heterogeneous
        head dimensions.

        Some Gemma4 variants use different head dimensions for
        sliding window (head_dim) vs full attention (global_head_dim) layers.
        When global_head_dim > 256, FlashAttention rejects those layers
        (head_size <= 256 kernel limit), causing vLLM to select a different
        backend for each layer type. This mixed-backend execution produces
        numerical divergence and output corruption.

        The fix detects heterogeneous head dimensions from the model config
        and forces TRITON_ATTN (which has no head_size ceiling) for all
        layers when the user hasn't explicitly chosen a backend.

        TODO: Heterogeneous head_sizes (head_dim != global_head_dim)
        require NixlConnector changes to support per-layer KV transfer
        with different head dimensions for prefill-decode disaggregation.
        """
        hf_text_config = vllm_config.model_config.hf_text_config
        head_dim = getattr(hf_text_config, "head_dim", None)
        global_head_dim = getattr(hf_text_config, "global_head_dim", None)

        # Only force Triton when head dimensions actually differ AND the
        # larger one exceeds FlashAttention's kernel limit (head_size <= 256).
        # This avoids unnecessary backend forcing on smaller models where
        # the config carries global_head_dim but all layers can still use
        # the same FA backend.
        max_head_dim = max(head_dim or 0, global_head_dim or 0)
        if (
            head_dim is not None
            and global_head_dim is not None
            and head_dim != global_head_dim
            and max_head_dim > 256
            and vllm_config.attention_config.backend is None
        ):
            from vllm.v1.attention.backends.registry import (
                AttentionBackendEnum,
            )

            vllm_config.attention_config.backend = AttentionBackendEnum.TRITON_ATTN
            logger.info(
                "Gemma4 model has heterogeneous head dimensions "
                "(head_dim=%d, global_head_dim=%d). Forcing TRITON_ATTN "
                "backend to prevent mixed-backend numerical divergence.",
                head_dim,
                global_head_dim,
            )
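A minimal sketch of when this guard fires, with the condition factored into a standalone predicate; the helper name and the example head dimensions below are illustrative, not part of the diff::

    def needs_triton_attn(head_dim, global_head_dim, user_backend):
        # Mirrors the guard above: both dims present, they differ, the larger
        # one exceeds FlashAttention's 256 head_size limit, and the user has
        # not already pinned a backend explicitly.
        if head_dim is None or global_head_dim is None:
            return False
        return (
            head_dim != global_head_dim
            and max(head_dim, global_head_dim) > 256
            and user_backend is None
        )

    assert needs_triton_attn(128, 288, None)               # heterogeneous, > 256
    assert not needs_triton_attn(128, 256, None)           # within the FA limit
    assert not needs_triton_attn(128, 288, "TRITON_ATTN")  # explicit user choice wins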


class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
@@ -533,6 +586,8 @@ MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig,  # noqa: E501
    "FalconMambaForCausalLM": MambaModelConfig,
    "Gemma3TextModel": Gemma3TextModelConfig,
    "Gemma4ForCausalLM": Gemma4Config,
    "Gemma4ForConditionalGeneration": Gemma4Config,
    "GptOssForCausalLM": GptOssForCausalLMConfig,
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewForSequenceClassification": GteNewModelConfig,
vllm/model_executor/models/gemma4.py (new file, 1239 lines): diff suppressed because it is too large
vllm/model_executor/models/gemma4_mm.py (new file, 1341 lines): diff suppressed because it is too large

vllm/model_executor/models/gemma4_utils.py (new file, 292 lines)
@@ -0,0 +1,292 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.

"""Gemma4 output parsing utilities for offline inference.

Standalone functions that parse decoded model text to extract structured
thinking content and tool calls from Gemma4 models. These are pure-Python
utilities with zero heavy dependencies — they work on raw decoded strings
from any inference backend (vLLM, HuggingFace, TGI, etc.).

Usage with vLLM offline inference::

    from vllm import LLM, SamplingParams
    from vllm.model_executor.models.gemma4_utils import (
        parse_thinking_output,
        parse_tool_calls,
    )

    llm = LLM(model="google/gemma-4-it")
    tokenizer = llm.get_tokenizer()
    outputs = llm.generate(prompt, SamplingParams(...))
    text = tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=False)

    # Extract thinking / answer (works with or without enable_thinking)
    result = parse_thinking_output(text)
    print(result["thinking"])  # chain-of-thought or None
    print(result["answer"])  # final answer

    # Extract tool calls
    tool_calls = parse_tool_calls(text)
    for tc in tool_calls:
        print(f"{tc['name']}({tc['arguments']})")

Ported from ``transformers.models.gemma4.utils_gemma4`` so that vLLM users
do not need a transformers dependency for output parsing.
"""

import json

import regex as re

# ---- Thinking Mode Utility ----

# Thinking delimiter tokens as they appear in decoded text.
# Gemma4 uses <|channel> (start) and <channel|> (end) as thinking delimiters.
_THINKING_START_TAG = "<|channel>"
_THINKING_END_TAG = "<channel|>"

# Sentinel tokens that may appear in decoded output.
_TURN_END_TAG = "<turn|>"


def parse_thinking_output(text: str) -> dict[str, str | None]:
    """Parse decoded Gemma4 model output.

    Use this on **all** Gemma4 output regardless of whether thinking mode
    was enabled. It handles three cases:

    1. **Thinking enabled, tags present** — splits on ``<|channel>``/
       ``<channel|>`` to separate chain-of-thought from the answer and
       strips the ``thought\\n`` role label.
    2. **Thinking disabled, spurious label** — strips the bare
       ``thought\\n`` prefix that some Gemma4 models emit even
       without thinking mode.
    3. **Clean output** — returns the text unchanged.

    The answer text is always cleaned of trailing sentinel tokens
    (``<turn|>``, ``<eos>``, etc.).

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...)``).

    Returns:
        A dict with keys:
        - ``"thinking"``: The chain-of-thought text, or ``None`` if no
          thinking delimiters were found.
        - ``"answer"``: The final answer text.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import parse_thinking_output
        >>> output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> result = parse_thinking_output(output_text)
        >>> print(result["thinking"])  # chain-of-thought reasoning or None
        >>> print(result["answer"])  # final answer
    """
    if _THINKING_END_TAG in text:
        parts = text.split(_THINKING_END_TAG, 1)
        thinking_block = parts[0]
        answer = _clean_answer(parts[1])

        # Extract thinking content: strip the start tag if present
        if _THINKING_START_TAG in thinking_block:
            thinking = thinking_block.split(_THINKING_START_TAG, 1)[1]
        else:
            thinking = thinking_block

        # Strip the "thought\n" channel role label the model emits inside
        # <|channel>thought\n...<channel|> (analogous to "user\n" in
        # <|turn>user\n...<turn|>).
        thinking = _strip_thought_label(thinking.strip())
        thinking = thinking.strip()

        return {"thinking": thinking, "answer": answer}

    # No thinking delimiters found.
    # Strip the spurious "thought\n" role label that some Gemma4 models emit
    # even without thinking mode enabled, then clean trailing tokens.
    answer = _strip_thought_label(text)
    answer = _clean_answer(answer)
    return {"thinking": None, "answer": answer}


def _strip_thought_label(text: str) -> str:
    """Strip the spurious ``thought\\n`` label from the start of text.

    Only strips when ``thought`` appears as the very first word followed by
    a newline — preserving the word ``thought`` in any other context.
    """
    if text.startswith("thought\n"):
        return text[len("thought\n") :]
    return text


def _clean_answer(text: str) -> str:
    """Clean trailing sentinel tokens from the answer text.

    Strips ``<turn|>``, ``<eos>``, and surrounding whitespace that the
    model appends at the end of its response.
    """
    text = text.strip()
    # Strip trailing <turn|> (Gemma4 turn-end marker)
    if text.endswith(_TURN_END_TAG):
        text = text[: -len(_TURN_END_TAG)].rstrip()
    # Strip trailing <eos> if present
    if text.endswith("<eos>"):
        text = text[:-5].rstrip()
    return text


# ---- Tool Call Parsing Utility ----
#
# NOTE: For the OpenAI-compatible API server tool parser (streaming +
# non-streaming), see vllm/tool_parsers/gemma4_tool_parser.py.
# This module provides offline inference utilities for direct user import.

# Tool call delimiter tokens as they appear in decoded text.
# Standard format: <|tool_call>call:name{args}<tool_call|>
_TOOL_CALL_START_TAG = "<|tool_call>"
_TOOL_CALL_END_TAG = "<tool_call|>"
_TOOL_RESPONSE_START_TAG = "<|tool_response>"

# Gemma4 escape token as it appears in decoded text.
_ESCAPE_TOKEN = '<|"|>'


def _parse_tool_arguments(args_str: str) -> dict[str, str]:
    """Parse tool call arguments from the Gemma4 compact format.

    Handles the ``key:<|"|>value<|"|>`` format used by Gemma4, with fallback
    to heuristic key-value extraction. Also tolerates the slightly different
    ``key: "value"`` format (space + plain quotes) that some chat templates
    produce.

    Args:
        args_str: Raw argument string from inside ``call:name{...}``.

    Returns:
        Dictionary of argument name → value.
    """
    if not args_str or not args_str.strip():
        return {}

    # Replace Gemma4 escape tokens with standard quotes.
    cleaned = args_str.replace(_ESCAPE_TOKEN, '"')

    # Try JSON parsing first (handles nested values, arrays, etc.).
    try:
        parsed = json.loads("{" + cleaned + "}")
        # Ensure all values are strings for consistency.
        return {k: str(v) if not isinstance(v, str) else v for k, v in parsed.items()}
    except (json.JSONDecodeError, ValueError):
        pass

    # Fallback: extract key:"value" pairs (allow optional space after colon).
    arguments = {}
    for key, value in re.findall(r'(\w+):\s*"([^"]*)"', cleaned):
        arguments[key] = value

    if not arguments:
        # Last resort: extract key:value pairs (unquoted).
        for key, value in re.findall(r"(\w+):\s*([^,}]+)", args_str):
            arguments[key] = value.strip().strip('"').replace(_ESCAPE_TOKEN, "")

    return arguments
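A minimal sketch of how the three parsing tiers behave on hypothetical argument strings (the inputs are made up; the behavior follows the function above)::

    # Escape tokens become plain quotes, then JSON parsing succeeds.
    assert _parse_tool_arguments('"city": <|"|>Paris<|"|>') == {"city": "Paris"}

    # Unquoted keys fail JSON parsing; the quoted key:"value" fallback applies.
    assert _parse_tool_arguments('city: "Paris", unit: "C"') == {
        "city": "Paris",
        "unit": "C",
    }

    # Last resort: bare key:value pairs.
    assert _parse_tool_arguments("city: Paris") == {"city": "Paris"}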


def parse_tool_calls(text: str, *, strict: bool = False) -> list[dict]:
    """Parse tool calls from decoded Gemma4 model output.

    Uses a tiered parsing strategy to handle known output variations in
    Gemma4 models, which may emit non-standard tool call formats.

    Parsing tiers:
    1. **Standard**: ``<|tool_call>call:name{args}<tool_call|>``
       (special token IDs 48/49 in decoded text)
    2. **Fallback** (when ``strict=False``): bare ``call:name{args}``
       patterns, including ``<call>name{args}`` (fragmented tokens from
       multimodal inputs)

    Args:
        text: Decoded model output text (from ``tokenizer.decode(...,
            skip_special_tokens=False)``).
        strict: If ``True``, only match the standard ``<|tool_call>`` format.
            If ``False`` (default), also try fallback patterns for
            known Gemma4 output variations.

    Returns:
        A list of dicts, each with keys:
        - ``"name"``: The tool function name (e.g. ``"get_weather"``).
        - ``"arguments"``: A dict of argument name → value.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     parse_tool_calls
        ... )
        >>> output = tokenizer.decode(outputs[0], skip_special_tokens=False)
        >>> tool_calls = parse_tool_calls(output)
        >>> for tc in tool_calls:
        ...     print(f"Call: {tc['name']}({tc['arguments']})")
    """
    results = []

    # Tier 1: Standard format with special tokens.
    # <|tool_call>call:name{args}<tool_call|>
    # Note: Some Gemma4 models emit <turn|> instead of <tool_call|>.
    standard_pattern = r"<\|tool_call\>call:(\w+)\{(.*?)\}(?:<tool_call\|>|<turn\|>)"
    for match in re.finditer(standard_pattern, text, re.DOTALL):
        name, args_str = match.group(1), match.group(2)
        results.append(
            {
                "name": name,
                "arguments": _parse_tool_arguments(args_str),
            }
        )

    if results or strict:
        return results

    # Tier 2: Fallback for known Gemma4 output variations.
    # Matches: <call>name{args}, call:name{args}, or bare call:name{args}<eos>
    fallback_pattern = r"(?:<call>|(?:^|\s)call:)(\w+)\{(.*?)\}"
    for match in re.finditer(fallback_pattern, text, re.DOTALL):
        name, args_str = match.group(1), match.group(2)
        results.append(
            {
                "name": name,
                "arguments": _parse_tool_arguments(args_str),
            }
        )

    return results
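A short usage sketch on hand-written decoded strings (the tool names and arguments are hypothetical; the tag formats follow the patterns above)::

    # Tier 1: standard special-token format.
    calls = parse_tool_calls(
        '<|tool_call>call:get_weather{"city": <|"|>Paris<|"|>}<tool_call|>'
    )
    assert calls == [{"name": "get_weather", "arguments": {"city": "Paris"}}]

    # Tier 2: bare call:name{args} fallback (skipped when strict=True).
    assert parse_tool_calls("call:lookup{id: 42}<eos>") == [
        {"name": "lookup", "arguments": {"id": "42"}}
    ]
    assert parse_tool_calls("call:lookup{id: 42}<eos>", strict=True) == []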


def has_tool_response_tag(text: str) -> bool:
    """Check if model output properly ends with a tool response tag.

    Some Gemma4 models emit ``<eos>`` instead of
    ``<|tool_response>`` after a tool call. This helper detects
    whether the model used the proper termination, so callers can
    decide whether to inject ``<|tool_response>`` into the next prompt.

    Args:
        text: Decoded model output text.

    Returns:
        ``True`` if the output ends with ``<|tool_response>``
        (proper behavior), ``False`` otherwise.

    Example::

        >>> from vllm.model_executor.models.gemma4_utils import (
        ...     has_tool_response_tag
        ... )
        >>> if not has_tool_response_tag(model_output):
        ...     # Model used <eos> instead — inject <|tool_response> manually
        ...     next_prompt = "<|tool_response>" + tool_result
    """
    stripped = text.rstrip()
    return stripped.endswith(_TOOL_RESPONSE_START_TAG)
@@ -111,6 +111,7 @@ _TEXT_GENERATION_MODELS = {
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "Gemma3ForCausalLM": ("gemma3", "Gemma3ForCausalLM"),
    "Gemma3nForCausalLM": ("gemma3n", "Gemma3nForCausalLM"),
    "Gemma4ForCausalLM": ("gemma4", "Gemma4ForCausalLM"),
    "Qwen3NextForCausalLM": ("qwen3_next", "Qwen3NextForCausalLM"),
    "GlmForCausalLM": ("glm", "GlmForCausalLM"),
    "Glm4ForCausalLM": ("glm4", "Glm4ForCausalLM"),
@@ -383,6 +384,7 @@ _MULTIMODAL_MODELS = {
        "gemma3n_mm",
        "Gemma3nForConditionalGeneration",
    ),
    "Gemma4ForConditionalGeneration": ("gemma4_mm", "Gemma4ForConditionalGeneration"),
    "GlmAsrForConditionalGeneration": ("glmasr", "GlmAsrForConditionalGeneration"),
    "GLM4VForCausalLM": ("glm4v", "GLM4VForCausalLM"),
    "Glm4vForConditionalGeneration": ("glm4_1v", "Glm4vForConditionalGeneration"),
@@ -233,8 +233,15 @@ class AutoWeightsLoader:
    ):
        """
        Add tensor names that are not in the model params that may be in the
-       safetensors, e.g., batch normalization stats.
+       safetensors, e.g., batch normalization stats and registered buffers.
        """
        # Add persistent registered buffers.
        # Non-persistent buffers are excluded, matching PyTorch state_dict().
        non_persistent = getattr(module, "_non_persistent_buffers_set", set())
        for buf_name, buf in module.named_buffers(recurse=False):
            if buf_name not in child_params and buf_name not in non_persistent:
                child_params[buf_name] = buf

        if isinstance(
            module,
            (
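A minimal sketch of the PyTorch buffer behavior the new lines rely on (the module and buffer names below are made up; only standard torch.nn APIs are used)::

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("running_mean", torch.zeros(4))             # persistent
            self.register_buffer("cache", torch.zeros(4), persistent=False)  # excluded

    m = Toy()
    non_persistent = getattr(m, "_non_persistent_buffers_set", set())
    kept = [
        name
        for name, _ in m.named_buffers(recurse=False)
        if name not in non_persistent
    ]
    assert kept == ["running_mean"]  # matches what state_dict() would keep
    assert "cache" in non_persistent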