[Model] Add Gemma3 GGUF multimodal support (#27772)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author:    Luciano Martins
Date:      2025-11-18 13:56:29 -03:00
Committer: GitHub
Parent:    49a986ecd4
Commit:    c2612371ad

14 changed files with 752 additions and 86 deletions


@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
+from types import MappingProxyType
 from typing import Any, Optional
 import gguf
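
The new MappingProxyType import gives the reworked is_layer_skipped_gguf (third hunk below) a read-only default for its fused-layer mapping, avoiding Python's shared-mutable-default-argument pitfall. A minimal illustration of the semantics, independent of vLLM:

    from types import MappingProxyType

    # A read-only view: lookups behave like a normal dict...
    frozen = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
    print(frozen["qkv_proj"])  # ['q_proj', 'k_proj', 'v_proj']

    # ...but any mutation raises TypeError, so the default can be
    # shared safely across calls.
    try:
        frozen["gate_up_proj"] = ["gate_proj", "up_proj"]
    except TypeError as exc:
        print(exc)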
@@ -26,7 +27,11 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils.torch_utils import direct_register_custom_op
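
The newly imported WeightsMapper backs the apply_vllm_mapper hook added in the next hunk, which rewrites unquantized-module names from HF layout to vLLM layout. A minimal sketch of that remapping — the prefix mapping below is hypothetical, not Gemma3's actual mapper:

    from vllm.model_executor.models.utils import WeightsMapper

    # Hypothetical mapping for illustration; real models define their own.
    mapper = WeightsMapper(orig_to_new_prefix={"model.": "language_model.model."})

    # apply_list is what apply_vllm_mapper calls on unquantized_modules.
    print(mapper.apply_list(["model.layers.0.self_attn.q_proj"]))
    # ['language_model.model.layers.0.self_attn.q_proj']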
@@ -65,18 +70,70 @@ class GGUFConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
                 return UnquantizedLinearMethod()
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
+                return UnquantizedEmbeddingMethod()
             return GGUFEmbeddingMethod(self)
         elif isinstance(layer, FusedMoE):
             return GGUFMoEMethod(self, layer.moe_config)
         return None
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vLLM model structure.
+
+        :param hf_to_vllm_mapper: maps from the HF model structure (the
+            assumed structure of the qconfig) to the vLLM model structure
+        """
+        if self.unquantized_modules is not None:
+            self.unquantized_modules = hf_to_vllm_mapper.apply_list(
+                self.unquantized_modules
+            )
+
 
-def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
-    return any(module_name in prefix for module_name in unquantized_modules)
+def is_layer_skipped_gguf(
+    prefix: str,
+    unquantized_modules: list[str],
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+):
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    proj_name = prefix.split(".")[-1]
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = any(
+                shard_prefix in module_name for module_name in unquantized_modules
+            )
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+                )
+    else:
+        is_skipped = any(module_name in prefix for module_name in unquantized_modules)
+
+    assert is_skipped is not None
+    return is_skipped
 
 
 UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
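
A usage sketch of the new shard-consistency check, assuming a vLLM build that contains this commit (the function appears to live in vllm/model_executor/layers/quantization/gguf.py, the GGUF quantization module this hunk edits); the module names below are made up:

    from types import MappingProxyType

    from vllm.model_executor.layers.quantization.gguf import is_layer_skipped_gguf

    # Fused qkv_proj maps to its three unfused shards, mirroring
    # packed_modules_mapping.
    fused = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})

    # All three shards listed as unquantized -> the fused prefix is skipped.
    unquantized = [
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ]
    assert is_layer_skipped_gguf(
        "model.layers.0.self_attn.qkv_proj", unquantized, fused
    )

    # Only one shard unquantized -> mixed precision within a fused layer,
    # which the check rejects with a ValueError.
    try:
        is_layer_skipped_gguf(
            "model.layers.0.self_attn.qkv_proj",
            ["model.layers.0.self_attn.q_proj"],
            fused,
        )
    except ValueError as exc:
        print(exc)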