diff --git a/requirements/common.txt b/requirements/common.txt index ad92ba3ad..1058ab91a 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -30,7 +30,7 @@ filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/31 partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec -gguf >= 0.13.0 +gguf >= 0.17.0 mistral_common[image] >= 1.8.5 opencv-python-headless >= 4.11.0 # required for video IO pyyaml diff --git a/tests/models/multimodal/generation/test_multimodal_gguf.py b/tests/models/multimodal/generation/test_multimodal_gguf.py new file mode 100644 index 000000000..e596b20c6 --- /dev/null +++ b/tests/models/multimodal/generation/test_multimodal_gguf.py @@ -0,0 +1,115 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Literal, NamedTuple + +import pytest +from huggingface_hub import hf_hub_download +from pytest import MarkDecorator + +from tests.quantization.utils import is_quant_method_supported +from vllm.assets.image import ImageAsset +from vllm.utils.torch_utils import set_default_torch_num_threads + +from ....conftest import PromptImageInput, VllmRunner +from ...utils import check_logprobs_close + + +class GGUFMMTestConfig(NamedTuple): + original_model: str + gguf_repo: str + gguf_backbone: str + gguf_mmproj: str + prompt: list[str] + mm_data: dict[Literal["images"], PromptImageInput] + max_model_len: int = 4096 + marks: list[MarkDecorator] = [] + + @property + def gguf_model(self): + hf_hub_download(self.gguf_repo, filename=self.gguf_mmproj) + return hf_hub_download(self.gguf_repo, filename=self.gguf_backbone) + + +GEMMA3_CONFIG = GGUFMMTestConfig( + original_model="google/gemma-3-4b-it", + gguf_repo="google/gemma-3-4b-it-qat-q4_0-gguf", + gguf_backbone="gemma-3-4b-it-q4_0.gguf", + gguf_mmproj="mmproj-model-f16-4B.gguf", + prompt=["Describe this image in detail:"], + mm_data={"images": 
def run_multimodal_gguf_test(
    vllm_runner: type[VllmRunner],
    model: GGUFMMTestConfig,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
):
    """Check that a multimodal GGUF checkpoint tracks its original model.

    Both models generate greedily for the same prompts and image inputs, and
    the resulting top-``num_logprobs`` log-probabilities are compared with
    ``check_logprobs_close``.

    Args:
        vllm_runner: Runner class used as a context manager to host a model.
        model: Test configuration naming the GGUF files, the original model,
            the prompts and the multimodal inputs.
        dtype: dtype string forwarded to both runners.
        max_tokens: Number of tokens to generate per prompt.
        num_logprobs: Number of log-probabilities requested per token.
    """
    # Settings shared by the quantized and the reference engine.
    engine_kwargs = {
        "enforce_eager": True,  # faster tests
        "dtype": dtype,
        "max_model_len": model.max_model_len,
    }
    # Identical generation request for both engines.
    request_kwargs = {
        "prompts": model.prompt,
        "max_tokens": max_tokens,
        "num_logprobs": num_logprobs,
        **model.mm_data,
    }

    # Run gguf model. The GGUF file carries no usable tokenizer here, so the
    # tokenizer of the original model is passed explicitly.
    with (
        set_default_torch_num_threads(1),
        vllm_runner(
            model_name=model.gguf_model,
            tokenizer_name=model.original_model,
            **engine_kwargs,
        ) as quantized_runner,
    ):
        gguf_outputs = quantized_runner.generate_greedy_logprobs(**request_kwargs)

    # Run unquantized model.
    with vllm_runner(
        model_name=model.original_model,
        **engine_kwargs,
    ) as reference_runner:
        original_outputs = reference_runner.generate_greedy_logprobs(
            **request_kwargs
        )

    check_logprobs_close(
        outputs_0_lst=original_outputs,
        outputs_1_lst=gguf_outputs,
        name_0="original",
        name_1="gguf",
    )
b/tests/models/quantization/test_gguf.py @@ -78,6 +78,12 @@ DOLPHIN_CONFIG = GGUFTestConfig( gguf_filename="tinydolphin-2.8-1.1b.Q6_K.gguf", ) +GEMMA3_CONFIG = GGUFTestConfig( + original_model="google/gemma-3-270m-it", + gguf_repo="ggml-org/gemma-3-270m-it-qat-GGUF", + gguf_filename="gemma-3-270m-it-qat-Q4_0.gguf", +) + MODELS = [ # LLAMA_CONFIG, # broken: https://github.com/vllm-project/vllm/issues/19458 QWEN2_CONFIG, @@ -85,6 +91,7 @@ MODELS = [ GPT2_CONFIG, STABLELM_CONFIG, DOLPHIN_CONFIG, + GEMMA3_CONFIG, # STARCODER_CONFIG, # broken ] @@ -148,7 +155,7 @@ def check_model_outputs( "model", [pytest.param(test_config, marks=test_config.marks) for test_config in MODELS], ) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tp_size", [1]) diff --git a/vllm/config/model.py b/vllm/config/model.py index b3a28af6d..49fe0bcd9 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -33,10 +33,14 @@ from vllm.transformers_utils.config import ( try_get_generation_config, try_get_safetensors_metadata, try_get_tokenizer_config, + uses_custom_attention_masks, uses_mrope, ) +from vllm.transformers_utils.gguf_utils import ( + maybe_patch_hf_config_from_gguf, +) from vllm.transformers_utils.runai_utils import ObjectStorageModel, is_runai_obj_uri -from vllm.transformers_utils.utils import maybe_model_redirect +from vllm.transformers_utils.utils import check_gguf_file, maybe_model_redirect from vllm.utils.import_utils import LazyLoader from vllm.utils.torch_utils import common_broadcastable_dtype @@ -450,6 +454,12 @@ class ModelConfig: self.model = maybe_model_redirect(self.model) # The tokenizer is consistent with the model by default. if self.tokenizer is None: + if check_gguf_file(self.model): + raise ValueError( + "Using a tokenizer is mandatory when loading a GGUF model. 
" + "Please specify the tokenizer path or name using the " + "--tokenizer argument." + ) self.tokenizer = self.model if self.tokenizer_revision is None: self.tokenizer_revision = self.revision @@ -508,6 +518,10 @@ class ModelConfig: hf_overrides_kw=hf_overrides_kw, hf_overrides_fn=hf_overrides_fn, ) + hf_config = maybe_patch_hf_config_from_gguf( + self.model, + hf_config, + ) self.hf_config = hf_config if dict_overrides: @@ -1605,6 +1619,10 @@ class ModelConfig: def uses_mrope(self) -> bool: return uses_mrope(self.hf_config) + @property + def uses_custom_attention_masks(self) -> bool: + return uses_custom_attention_masks(self.hf_config) + @property def is_multimodal_model(self) -> bool: return self.multimodal_config is not None diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index caabcd0ca..42d7a6737 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -1,7 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections.abc import Callable +from collections.abc import Callable, Mapping +from types import MappingProxyType from typing import Any, Optional import gguf @@ -26,7 +27,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.vocab_parallel_embedding import ( + UnquantizedEmbeddingMethod, + VocabParallelEmbedding, +) +from vllm.model_executor.models.utils import WeightsMapper from vllm.model_executor.utils import set_weight_attrs from vllm.utils.torch_utils import direct_register_custom_op @@ -65,18 +70,70 @@ class GGUFConfig(QuantizationConfig): self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): - if 
def is_layer_skipped_gguf(
    prefix: str,
    unquantized_modules: list[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    """Return whether the layer at ``prefix`` must stay unquantized.

    Fused layers such as ``qkv_proj`` or ``gate_up_proj`` are stored as
    separate shards in the checkpoint. For those, the fused name is expanded
    to its shard names and every shard must agree on being skipped or not.

    Args:
        prefix: Dotted module path of the layer being constructed.
        unquantized_modules: Module names known to hold unquantized weights.
        fused_mapping: Maps a fused projection name to its shard names.

    Returns:
        ``True`` if the layer (or all of its shards) is unquantized.

    Raises:
        ValueError: If only some shards of a fused layer are unquantized.
    """
    proj_name = prefix.split(".")[-1]
    shard_proj_names = fused_mapping.get(proj_name)

    # Plain layer, or a fused entry with no shards listed: fall back to the
    # direct check instead of tripping over an undecided result.
    if not shard_proj_names:
        return any(module_name in prefix for module_name in unquantized_modules)

    # Swap only the trailing path component; str.replace would also rewrite
    # earlier components that happen to contain the projection name.
    stem = prefix[: len(prefix) - len(proj_name)]
    shard_verdicts = [
        any(stem + shard_proj_name in module_name for module_name in unquantized_modules)
        for shard_proj_name in shard_proj_names
    ]

    if any(shard_verdicts) != all(shard_verdicts):
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. All shards of fused layers "
            "must have the same precision."
        )
    return shard_verdicts[0]
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. """ config = model_config.hf_config + # Get text config to handle both nested (multimodal) and flat + # (text-only) config structures. For multimodal models like + # Gemma3Config, this returns config.text_config. For text-only + # models, this returns config itself. + text_config = config.get_text_config() model_type = config.model_type + is_multimodal = ( + hasattr(config, "vision_config") and config.vision_config is not None + ) gguf_to_hf_name_map = {} # hack: ggufs have a different name than transformers if model_type == "cohere": @@ -115,24 +127,167 @@ class GGUFModelLoader(BaseModelLoader): break if arch is None: raise RuntimeError(f"Unknown gguf model_type: {model_type}") - num_layers = config.num_hidden_layers - name_map = gguf.get_tensor_name_map(arch, num_layers) + text_num_layers = text_config.num_hidden_layers + text_name_map = gguf.get_tensor_name_map(arch, text_num_layers) + + if is_multimodal: + mm_proj_arch = gguf.MODEL_ARCH.MMPROJ + vision_num_layers = config.vision_config.num_hidden_layers + vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers) + else: + vision_name_map = None + + # Create dummy model to extract parameter names + # For multimodal: use AutoModelForImageTextToText to get + # language + vision + projector params + # For text-only: use AutoModelForCausalLM to get language model params + auto_cls = ( + AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM + ) with torch.device("meta"): - dummy_model = AutoModelForCausalLM.from_config( + dummy_model = auto_cls.from_config( config, trust_remote_code=model_config.trust_remote_code ) - state_dict = dummy_model.state_dict() + state_dict = dummy_model.state_dict() + if hf_checkpoint_map := getattr( + dummy_model, "_checkpoint_conversion_mapping", None + ): + + def revert_hf_rename(name: str) -> str: + for original_name, hf_name in hf_checkpoint_map.items(): + if hf_name in name: + 
name = name.replace(hf_name, original_name).lstrip("^") + return name + + state_dict = { + revert_hf_rename(name): tensor for name, tensor in state_dict.items() + } + + def find_hf_name_in_tensor_map(hf_name: str) -> str | None: + """ + Map HuggingFace parameter name to GGUF tensor name. + + This function handles the mismatch between HF parameter naming + conventions and gguf-py's expected format: + 1. Strips 'model.' prefix (common in multimodal models) + 2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility) + 3. Searches vision_name_map for multimodal parameters + 4. Falls back to text_name_map for language model parameters + + Args: + hf_name: Full HuggingFace parameter name (e.g., + 'model.multi_modal_projector.mm_soft_emb_norm.weight') + + Returns: + GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight') + or None if no mapping found + """ + # Strip 'language_model.' prefix for multimodal models - gguf-py + # tensor mappings expect parameter names without this prefix. + # Note: 'model.' prefix should be KEPT for text-only models as + # gguf-py expects it. + if hf_name.startswith("language_model."): + hf_name = hf_name[15:] # Remove 'language_model.' + + # Parse parameter name and suffix + if hf_name.endswith((".weight", ".bias")): + base_name, suffix = hf_name.rsplit(".", 1) + else: + base_name, suffix = hf_name, "" + # Handle '_weight' suffix (Gemma3 naming: parameter ends with + # '_weight' instead of '.weight') + if base_name.endswith("_weight"): + base_name = base_name[:-7] # Remove '_weight' + suffix = "weight" + + gguf_name = None + # Priority 1: Search vision/projector parameters for multimodal models + if vision_name_map is not None: + gguf_name = vision_name_map.get_name(base_name) + + # Priority 2: Search text backbone parameters + if gguf_name is None: + gguf_name = text_name_map.get_name(base_name) + + if gguf_name is None: + return None + + return gguf_name + "." 
def _find_mmproj_file(self, model_config: ModelConfig, model_name_or_path: str):
    """Locate the mmproj GGUF file that accompanies a multimodal model.

    A model counts as multimodal only when its HF config carries a
    non-``None`` ``vision_config``, matching the check used when the
    GGUF-to-HF name map is built.

    Returns:
        The path reported by ``detect_gguf_multimodal`` for multimodal
        models, or ``None`` for text-only models.

    Raises:
        RuntimeError: If the model is multimodal but no mmproj file is found.
    """
    if getattr(model_config.hf_config, "vision_config", None) is None:
        return None

    mmproj_file = detect_gguf_multimodal(model_name_or_path)
    # Explicit raise rather than assert: asserts vanish under ``python -O``.
    if mmproj_file is None:
        raise RuntimeError(
            "Could not find mm_proj file for multimodal GGUF model "
            f"next to {model_name_or_path}"
        )
    return mmproj_file

def _get_gguf_weight_type(
    self,
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]:
    """Collect the GGUF weight type of every mapped tensor.

    Args:
        model_config: Model configuration; its HF config decides whether an
            mmproj file is expected.
        model_name_or_path: Local path of the backbone GGUF file.
        gguf_to_hf_name_map: Mapping from GGUF tensor names to HF names.

    Returns:
        Weight types from the backbone file, merged with those from the
        mmproj file when the model is multimodal.

    Raises:
        RuntimeError: If a multimodal model has no mmproj file.
    """
    # Read from the resolved local path, the same file the weights iterator
    # opens, rather than the possibly-remote ``model_config.model`` string.
    weight_type_map = get_gguf_weight_type_map(
        model_name_or_path, gguf_to_hf_name_map
    )

    mmproj_file = self._find_mmproj_file(model_config, model_name_or_path)
    if mmproj_file is not None:
        logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
        weight_type_map.update(
            get_gguf_weight_type_map(mmproj_file, gguf_to_hf_name_map)
        )
    return weight_type_map

def _get_weights_iterator(
    self,
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Iterate over GGUF weights, including mmproj weights when present.

    For multimodal models the mmproj file (vision tower and projector) is
    read first, followed by the backbone file. Text-only models read the
    backbone file only.

    Args:
        model_config: Model configuration; its HF config decides whether an
            mmproj file is expected.
        model_name_or_path: Local path of the backbone GGUF file.
        gguf_to_hf_name_map: Mapping from GGUF tensor names to HF names.

    Yields:
        ``(parameter_name, tensor)`` pairs for all model weights.

    Raises:
        RuntimeError: If a multimodal model has no mmproj file.
    """
    mmproj_file = self._find_mmproj_file(model_config, model_name_or_path)
    if mmproj_file is not None:
        # Load mm_proj (mm_encoder + projector) for multimodal weights
        yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)

    yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
target_device = torch.device(device_config.device) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 93986e5f2..89634cbf4 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -836,7 +836,11 @@ def gguf_quant_weights_iterator( ) -> Generator[tuple[str, torch.Tensor], None, None]: """ Iterate over the quant weights in the model gguf files and convert - them to torch tensors + them to torch tensors. + Be careful of the order of yielding weight types and weights data, + we have to yield all weight types first before yielding any weights. + Otherwise it would cause issue when loading weights with for packed + layer with different quant types. """ reader = gguf.GGUFReader(gguf_file) @@ -846,7 +850,7 @@ def gguf_quant_weights_iterator( weight_type = tensor.tensor_type name = gguf_to_hf_name_map[tensor.name] - if weight_type.name != "F32": + if weight_type.name not in ("F32", "BF16", "F16"): weight_type_name = name.replace("weight", "qweight_type") weight_type = torch.tensor(weight_type) yield weight_type_name, weight_type @@ -856,7 +860,7 @@ def gguf_quant_weights_iterator( weight = tensor.data weight_type = tensor.tensor_type name = gguf_to_hf_name_map[tensor.name] - if weight_type.name != "F32": + if weight_type.name not in ("F32", "BF16", "F16"): name = name.replace("weight", "qweight") param = torch.tensor(weight) yield name, param diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py index 8e2bbe8f7..fe83c8b63 100644 --- a/vllm/model_executor/models/gemma3_mm.py +++ b/vllm/model_executor/models/gemma3_mm.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math from collections.abc import Iterable, Mapping, Sequence -from typing import Annotated, Any, Literal, TypeAlias +from typing import Annotated, Any, Literal import torch from torch import nn @@ 
-20,12 +20,7 @@ from vllm.multimodal.inputs import ( MultiModalFieldConfig, MultiModalKwargsItems, ) -from vllm.multimodal.parse import ( - ImageEmbeddingItems, - ImageProcessorItems, - ImageSize, - MultiModalDataItems, -) +from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems from vllm.multimodal.processing import ( BaseMultiModalProcessor, BaseProcessingInfo, @@ -76,15 +71,7 @@ class Gemma3ImagePixelInputs(TensorSchema): num_patches: Annotated[torch.Tensor, TensorShape("bn")] -class Gemma3ImageEmbeddingInputs(TensorSchema): - type: Literal["image_embeds"] = "image_embeds" - image_embeds: Annotated[ - torch.Tensor, - TensorShape("ni", "nf", "hs"), - ] - - -Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs +Gemma3ImageInputs = Gemma3ImagePixelInputs class Gemma3ProcessingInfo(BaseProcessingInfo): @@ -191,9 +178,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): def get_image_repl( self, *, - image_width: int | None, - image_height: int | None, - num_crops: int | None = None, + image_width: int, + image_height: int, processor: Gemma3Processor | None, ) -> PromptUpdateDetails[str]: if processor is None: @@ -201,13 +187,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo): boi_token = processor.boi_token - if num_crops is None: - assert image_width is not None and image_height is not None - num_crops = self.get_num_crops( - image_width=image_width, - image_height=image_height, - processor=processor, - ) + num_crops = self.get_num_crops( + image_width=image_width, + image_height=image_height, + processor=processor, + ) if num_crops == 0: image_text = boi_token @@ -337,7 +321,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): return dict( pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches), num_patches=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), ) def _get_prompt_updates( @@ -350,19 +333,7 @@ class 
Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]): image_token = hf_processor.boi_token def get_replacement_gemma3(item_idx: int): - images = mm_items.get_items( - "image", (ImageEmbeddingItems, ImageProcessorItems) - ) - - if isinstance(images, ImageEmbeddingItems): - # For image embedding inputs, only support no crops cases - # since it's not supported in hf processor anyway - return self.info.get_image_repl( - image_width=None, - image_height=None, - num_crops=0, - processor=hf_processor, - ) + images = mm_items.get_items("image", ImageProcessorItems) image_size = images.get_image_size(item_idx) return self.info.get_image_repl( @@ -586,19 +557,17 @@ class Gemma3ForConditionalGeneration( pixel_values = kwargs.pop("pixel_values", None) num_patches = kwargs.pop("num_patches", None) image_embeds = kwargs.pop("image_embeds", None) + assert image_embeds is None, "Gemma3 does not support image_embeds." + if pixel_values is None: + return None - if pixel_values is not None: - image_size = self.config.vision_config.image_size - return Gemma3ImagePixelInputs( - pixel_values=pixel_values, - num_patches=num_patches, - resolve_bindings={"h": image_size, "w": image_size}, - ) - elif image_embeds is not None: - return Gemma3ImageEmbeddingInputs( - image_embeds=image_embeds, - type="image_embeds", - ) + image_size = self.config.vision_config.image_size + + return Gemma3ImagePixelInputs( + pixel_values=pixel_values, + num_patches=num_patches, + resolve_bindings={"h": image_size, "w": image_size}, + ) def _image_pixels_to_features( self, @@ -610,9 +579,7 @@ class Gemma3ForConditionalGeneration( def _process_image_input( self, image_input: Gemma3ImageInputs, - ) -> torch.Tensor | list[torch.Tensor]: - if image_input["type"] == "image_embeds": - return image_input["image_embeds"] + ) -> list[torch.Tensor]: assert self.vision_tower is not None pixel_values = image_input["pixel_values"] @@ -629,13 +596,33 @@ class Gemma3ForConditionalGeneration( def 
def generate_attention_masks(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    mask_dtype: torch.dtype,
) -> dict[str, Any]:
    """Generate custom attention masks for Gemma3 multimodal inputs.

    Each sequence gets a causal mask in which image tokens may additionally
    attend to each other in both directions. When the text config defines a
    sliding window, a second mask that also blocks keys outside the window
    is produced.

    Args:
        input_ids: Flattened token ids of all sequences in the batch.
        positions: Position ids aligned with ``input_ids``.
            Assumes a 1-D tensor — TODO confirm against the caller.
        mask_dtype: dtype of the generated masks.

    Returns:
        A dict with ``has_images`` (always ``True``), ``seq_lens`` (list of
        Python ints), ``global_attn_masks`` and ``local_attn_masks`` (one
        ``(1, 1, seq_len, seq_len)`` tensor per sequence; the local list is
        empty when no sliding window is configured).
    """
    # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
    # This is a HACK. Fix this.
    seq_starts = (positions == 0).cpu().nonzero().flatten().tolist()
    # Plain ints keep the lengths usable as shape arguments and slice bounds.
    seq_bounds = [*seq_starts, len(input_ids)]
    seq_lens = [end - start for start, end in zip(seq_bounds, seq_bounds[1:])]

    # GGUF compatibility: config might be Gemma3TextConfig directly
    text_config = getattr(self.config, "text_config", self.config)
    sliding_window = text_config.sliding_window
    image_token_index = self.config.image_token_index

    global_attn_masks = []
    local_attn_masks = []
    offset = 0
    for seq_len in seq_lens:
        # Sequences are consumed back to back from the start of input_ids.
        seq_token_ids = input_ids[offset : offset + seq_len]
        offset += seq_len

        # Find image token positions
        img_pos = seq_token_ids == image_token_index

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        global_attn_mask = torch.full(
            (1, 1, seq_len, seq_len),
            float("-inf"),
            dtype=mask_dtype,
            device=input_ids.device,
        ).triu(diagonal=1)

        # Enable bidirectional attention between image tokens: unmask every
        # (query, key) pair where both tokens are image tokens.
        img_pairs = img_pos.unsqueeze(1) & img_pos.unsqueeze(0)
        global_attn_mask = global_attn_mask.masked_fill(img_pairs, 0)
        global_attn_masks.append(global_attn_mask)

        if sliding_window is not None:
            # Local mask: additionally block keys that lie sliding_window
            # or more positions behind the query.
            outside_window = torch.tril(
                torch.ones_like(global_attn_mask), diagonal=-sliding_window
            )
            local_attn_mask = torch.where(
                outside_window == 0, global_attn_mask, float("-inf")
            )
            local_attn_masks.append(local_attn_mask)

    return {
        "has_images": True,
        "seq_lens": seq_lens,
        "global_attn_masks": global_attn_masks,
        "local_attn_masks": local_attn_masks,
    }
index b175dd60c..42d906d08 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -827,6 +827,7 @@ class SiglipVisionModel(nn.Module): ) -> None: super().__init__() + self.quant_config = quant_config self.vision_model = SiglipVisionTransformer( config, quant_config, @@ -911,12 +912,38 @@ class SiglipVisionModel(nn.Module): break else: param = params_dict[name] + param = maybe_swap_ffn_param( + name, param, loaded_weight, params_dict, self.quant_config + ) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params +def maybe_swap_ffn_param( + name: str, + param: torch.Tensor, + loaded_weight: torch.Tensor, + params_dict: dict[str, torch.Tensor], + quant_config: QuantizationConfig, +) -> torch.Tensor: + if not (quant_config and quant_config.get_name() == "gguf") or ".fc" not in name: + return param + # Some GGUF models have fc1 and fc2 weights swapped + tp_size = get_tensor_model_parallel_world_size() + output_dim = getattr(param, "output_dim", 0) + output_size = param.size(output_dim) * tp_size + weight_out_size = loaded_weight.size(output_dim) + if ".fc1." in name and output_size != weight_out_size: + new_name = name.replace(".fc1.", ".fc2.") + param = params_dict[new_name] + elif ".fc2." 
in name and output_size != weight_out_size: + new_name = name.replace(".fc2.", ".fc1.") + param = params_dict[new_name] + return param + + # Adapted from: https://github.com/huggingface/transformers/blob/v4.54.1/src/transformers/models/siglip/modeling_siglip.py#L200 class SiglipTextEmbeddings(nn.Module): def __init__(self, config: SiglipTextConfig): diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 49250e071..ac4a71648 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -477,6 +477,17 @@ def is_interleaved(config: PretrainedConfig) -> bool: return False +def uses_custom_attention_masks(config: PretrainedConfig) -> bool: + """Detect if model uses custom attention mask generation for multimodal. + + Some multimodal models require custom attention masks that enable + bidirectional attention between image tokens while maintaining causal + attention for text tokens. Currently applies to Gemma3 multimodal models. + """ + architectures = getattr(config, "architectures", []) + return "Gemma3ForConditionalGeneration" in architectures + + def _maybe_update_auto_config_kwargs(kwargs: dict[str, Any], model_type: str): """ Update kwargs for AutoConfig initialization based on model_type diff --git a/vllm/transformers_utils/gguf_utils.py b/vllm/transformers_utils/gguf_utils.py new file mode 100644 index 000000000..2bf59c91a --- /dev/null +++ b/vllm/transformers_utils/gguf_utils.py @@ -0,0 +1,166 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""GGUF utility functions.""" + +from pathlib import Path + +import gguf +from gguf.constants import Keys, VisionProjectorType +from transformers import Gemma3Config, PretrainedConfig, SiglipVisionConfig + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def detect_gguf_multimodal(model: str) -> Path | None: + """Check if GGUF model has multimodal projector file. 
def extract_vision_config_from_gguf(mmproj_path: str) -> "SiglipVisionConfig | None":
    """Extract vision config parameters from mmproj.gguf metadata.

    Reads the vision encoder hyper-parameters from the GGUF metadata using
    the key constants shipped with gguf-py and builds a
    ``SiglipVisionConfig`` from them. If the file declares a projector type
    that can be decoded, it is used to apply model-specific parameters
    (currently only Gemma3 is handled); otherwise only the extracted fields
    are passed and every other setting keeps its ``SiglipVisionConfig``
    default.

    Args:
        mmproj_path: Path to the mmproj GGUF file. The value is passed
            through ``str()`` before being handed to ``gguf.GGUFReader``.

    Returns:
        A ``SiglipVisionConfig`` built from the metadata, or ``None`` if any
        required field is missing from the GGUF metadata.

    Raises:
        Exception: Errors raised by ``gguf.GGUFReader`` while opening or
            parsing the file are not caught here and propagate to the caller.
    """
    reader = gguf.GGUFReader(str(mmproj_path))

    # Detect projector type to apply model-specific parameters.
    projector_type = None
    projector_type_field = reader.get_field(Keys.Clip.PROJECTOR_TYPE)
    if projector_type_field is not None:
        try:
            projector_type = bytes(projector_type_field.parts[-1]).decode("utf-8")
        except (AttributeError, UnicodeDecodeError) as e:
            # Best effort: an undecodable projector type only disables the
            # model-specific tweaks below, it does not abort extraction.
            logger.warning("Failed to decode projector type from GGUF: %s", e)

    # Map GGUF field constants to SiglipVisionConfig parameters.
    # Format: {gguf_constant: (param_name, python_type)}
    vision_config_fields = {
        Keys.ClipVision.EMBEDDING_LENGTH: ("hidden_size", int),
        Keys.ClipVision.FEED_FORWARD_LENGTH: ("intermediate_size", int),
        Keys.ClipVision.BLOCK_COUNT: ("num_hidden_layers", int),
        Keys.ClipVision.Attention.HEAD_COUNT: ("num_attention_heads", int),
        Keys.ClipVision.IMAGE_SIZE: ("image_size", int),
        Keys.ClipVision.PATCH_SIZE: ("patch_size", int),
        Keys.ClipVision.Attention.LAYERNORM_EPS: ("layer_norm_eps", float),
    }

    # Extract and validate all required fields; bail out on the first miss.
    config_params = {}
    for gguf_key, (param_name, dtype) in vision_config_fields.items():
        field = reader.get_field(gguf_key)
        if field is None:
            logger.warning(
                "Missing required vision config field '%s' in mmproj.gguf",
                gguf_key,
            )
            return None
        # The last part of a scalar field is a NumPy array holding the value.
        # Pull the single element out with .item() before converting:
        # calling int()/float() directly on a 1-D array is deprecated since
        # NumPy 1.25 and is slated to become an error.
        config_params[param_name] = dtype(field.parts[-1].item())

    # Apply model-specific parameters based on projector type.
    if projector_type == VisionProjectorType.GEMMA3:
        # Gemma3 doesn't use the vision pooling head (multihead attention).
        # This is a vLLM-specific parameter used in SiglipVisionTransformer.
        config_params["vision_use_head"] = False
        logger.info("Detected Gemma3 projector, disabling vision pooling head")
    # Add other projector-type-specific customizations here as needed.

    # num_channels and attention_dropout are not read from the file; the
    # SiglipVisionConfig defaults apply to them.
    config = SiglipVisionConfig(**config_params)

    if projector_type:
        logger.info(
            "Extracted vision config from mmproj.gguf (projector_type: %s)",
            projector_type,
        )
    else:
        logger.info("Extracted vision config from mmproj.gguf metadata")

    return config
def cached_image_processor_from_config(
    model_config: "ModelConfig",
    **kwargs: Any,
):
    """Return the cached image processor for ``model_config``.

    When ``model_config.model`` is a GGUF file, the processor is loaded from
    ``model_config.tokenizer`` / ``model_config.tokenizer_revision`` instead
    of the model path; the tokenizer itself must then not be a GGUF file.
    Otherwise ``model_config.model`` / ``model_config.revision`` are used.
    """
    # Default to the model location; switched to the tokenizer for GGUF.
    source = model_config.model
    source_revision = model_config.revision

    if check_gguf_file(source):
        tokenizer_source = model_config.tokenizer
        assert not check_gguf_file(tokenizer_source), (
            "For multimodal GGUF models, the original tokenizer "
            "should be used to correctly load image processor."
        )
        source = tokenizer_source
        source_revision = model_config.tokenizer_revision

    merged_kwargs = _merge_mm_kwargs(model_config, AutoImageProcessor, **kwargs)
    return cached_get_image_processor(
        source,
        revision=source_revision,
        trust_remote_code=model_config.trust_remote_code,
        **merged_kwargs,
    )
+ has_mm_features = any( + req_state.mm_features for req_state in self.requests.values() + ) + if ( + self.uses_custom_attention_masks + and has_mm_features + and hasattr(self.model, "generate_attention_masks") + ): + mask_kwargs = self.model.generate_attention_masks( + self.input_ids.gpu[:num_scheduled_tokens], + self.positions.gpu[:num_scheduled_tokens], + mask_dtype=self.model.dtype, + ) + model_kwargs.update(mask_kwargs) elif self.enable_prompt_embeds and is_first_rank: # Get the input embeddings for the tokens that are not input embeds, # then put them into the appropriate positions.