[Model] Add Gemma3 GGUF multimodal support (#27772)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
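For context, a minimal offline-inference sketch of what this change enables (not part of the diff; the checkpoint path, tokenizer repo, and image-token prompt format are illustrative placeholders):

    # Load a Gemma3 GGUF checkpoint and run one image+text prompt with vLLM.
    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="/path/to/gemma-3-4b-it-Q4_K_M.gguf",  # local GGUF file (placeholder)
        tokenizer="google/gemma-3-4b-it",  # GGUF ships no HF tokenizer, so supply one
    )
    outputs = llm.generate(
        {
            "prompt": "<start_of_image>Describe this image.",
            "multi_modal_data": {"image": Image.open("example.jpg")},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)
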
@@ -1,7 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from collections.abc import Callable
+from collections.abc import Callable, Mapping
+from types import MappingProxyType
 from typing import Any, Optional

 import gguf
@@ -26,7 +27,11 @@ from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.models.utils import WeightsMapper
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils.torch_utils import direct_register_custom_op

@@ -65,18 +70,70 @@ class GGUFConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_gguf(prefix, self.unquantized_modules):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
                 return UnquantizedLinearMethod()
             return GGUFLinearMethod(self)
         elif isinstance(layer, VocabParallelEmbedding):
+            if is_layer_skipped_gguf(
+                prefix, self.unquantized_modules, self.packed_modules_mapping
+            ):
+                return UnquantizedEmbeddingMethod()
             return GGUFEmbeddingMethod(self)
         elif isinstance(layer, FusedMoE):
             return GGUFMoEMethod(self, layer.moe_config)
         return None

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure

-def is_layer_skipped_gguf(prefix: str, unquantized_modules: list[str]):
-    return any(module_name in prefix for module_name in unquantized_modules)
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        if self.unquantized_modules is not None:
+            self.unquantized_modules = hf_to_vllm_mapper.apply_list(
+                self.unquantized_modules
+            )
+
+
+def is_layer_skipped_gguf(
+    prefix: str,
+    unquantized_modules: list[str],
+    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+):
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the GGUF checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    proj_name = prefix.split(".")[-1]
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = any(
+                shard_prefix in module_name for module_name in unquantized_modules
+            )
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+                )
+    else:
+        is_skipped = any(module_name in prefix for module_name in unquantized_modules)
+
+    assert is_skipped is not None
+    return is_skipped


 UNQUANTIZED_TYPES = {WeightType.F32, WeightType.F16, WeightType.BF16}
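
The fused-layer check is easiest to see on a concrete input. A self-contained sketch mirroring the logic above (the module names and the qkv_proj mapping are made-up examples, not Gemma3's actual packed_modules_mapping):

    from types import MappingProxyType

    def is_layer_skipped_gguf(prefix, unquantized_modules,
                              fused_mapping=MappingProxyType({})):
        # Trimmed restatement of the function in this diff.
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            # Expand the fused name into per-shard prefixes; every shard
            # must agree on whether it is quantized.
            shards = [prefix.replace(proj_name, s) for s in fused_mapping[proj_name]]
            flags = {any(p in m for m in unquantized_modules) for p in shards}
            if len(flags) > 1:
                raise ValueError(f"Shards of {prefix} disagree on quantization")
            return flags.pop()
        return any(m in prefix for m in unquantized_modules)

    fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    unquantized = [
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ]
    # All three shards are unquantized, so the fused qkv_proj is skipped:
    print(is_layer_skipped_gguf("model.layers.0.self_attn.qkv_proj",
                                unquantized, fused))  # True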
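
And a sketch of what apply_vllm_mapper does to the qconfig's module names, assuming an illustrative prefix mapping (Gemma3's real mapper lives in the model definition, which this hunk does not show):

    from vllm.model_executor.models.utils import WeightsMapper

    # Hypothetical mapper: HF names under "model." live under
    # "language_model.model." in the vLLM multimodal wrapper.
    mapper = WeightsMapper(orig_to_new_prefix={"model.": "language_model.model."})

    unquantized_modules = ["model.layers.0.mlp.gate_proj"]
    print(mapper.apply_list(unquantized_modules))
    # ['language_model.model.layers.0.mlp.gate_proj']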