[Model] Support SigLIP encoder and alternative decoders for LLaVA models (#7153)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Author: Cyrus Leung
Date: 2024-08-06 16:55:31 +08:00 (committed by GitHub)
Parent: 9118217f58
Commit: 1f26efbb3a
14 changed files with 453 additions and 267 deletions


@@ -1,22 +1,70 @@
-from typing import Dict, List, Protocol, Tuple
+from typing import Dict, Iterable, List, Optional, Protocol, Tuple
 
 import torch
 import torch.nn as nn
 from torch.func import functional_call
 from transformers import PretrainedConfig
 
+from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
+                         SchedulerConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.model_loader.loader import build_model
+from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal import BatchedTensors
 from vllm.utils import is_pin_memory_available
 
 
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+    """
+    Helper function to load weights for inner vLLM models.
+
+    See also:
+        :ref:`init_vllm_registered_model`
+    """
+    for name, loaded_weight in weights:
+        name = name.split(".")
+        if prefix == name.pop(0):
+            name = ".".join(name)
+            yield name, loaded_weight
+
+
+def init_vllm_registered_model(
+    hf_config: PretrainedConfig,
+    cache_config: Optional[CacheConfig],
+    quant_config: Optional[QuantizationConfig],
+    *,
+    lora_config: Optional[LoRAConfig] = None,
+    multimodal_config: Optional[MultiModalConfig] = None,
+    scheduler_config: Optional[SchedulerConfig] = None,
+) -> nn.Module:
+    """
+    Helper function to initialize an inner model registered to vLLM,
+    based on the arguments passed to the outer vLLM model.
+    """
+    model_class, _ = ModelRegistry.resolve_model_cls(hf_config.architectures)
+
+    return build_model(
+        model_class,
+        hf_config,
+        cache_config,
+        quant_config,
+        lora_config=lora_config,
+        multimodal_config=multimodal_config,
+        scheduler_config=scheduler_config,
+    )
+
+
 def merge_vision_embeddings(input_ids: torch.Tensor,
                             inputs_embeds: torch.Tensor,
                             vision_embeddings: BatchedTensors,
                             image_token_id: int) -> torch.Tensor:
     """
-    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
-    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
+    Merge ``vision_embeddings`` into ``inputs_embeds`` by overwriting the
+    positions in ``inputs_embeds`` corresponding to placeholder image tokens in
+    ``input_ids``.
 
     Note:
-        This updates `inputs_embeds` in place.
+        This updates ``inputs_embeds`` in place.
     """
     mask = (input_ids == image_token_id)
     num_expected_tokens = mask.sum()
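The hunk above introduces two loading helpers. As a hedged illustration of the intended call pattern, the sketch below shows an outer LLaVA-style model building its text decoder through init_vllm_registered_model rather than hard-coding a single decoder class, which is what lets this commit support alternative decoders. The import path vllm.model_executor.models.utils and every name prefixed with "My" are assumptions for illustration; the diff view does not show the file path.

```python
# Minimal sketch (not code from this commit): resolve the decoder named by
# config.text_config.architectures through the vLLM model registry.
from typing import Optional

import torch.nn as nn
from transformers import LlavaConfig

from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.utils import init_vllm_registered_model  # assumed path


class MyLlavaForConditionalGeneration(nn.Module):

    def __init__(self,
                 config: LlavaConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        # Vision encoder (CLIP or SigLIP) elided; this sketch focuses on
        # the registry-based decoder construction.
        self.language_model = init_vllm_registered_model(
            config.text_config, cache_config, quant_config)
```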
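filter_weights is the companion piece on the checkpoint side: it yields only the entries whose first dotted name component equals the given prefix, with that prefix stripped, so each inner model sees weight names it recognizes. Below is a hedged sketch of a load_weights built on it, shown as an unbound method for brevity; the prefixes are illustrative, and the list() materialization is just one way to let two passes consume the same weights iterable (the commit itself may route weights differently).

```python
# Sketch: route a flat outer checkpoint to two sub-models by name prefix.
from typing import Iterable, Tuple

import torch

from vllm.model_executor.models.utils import filter_weights  # assumed path


def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    # Each filter_weights() pass consumes its input, so materialize the
    # iterable once to feed both sub-models from the same checkpoint.
    weights = list(weights)

    # ("vision_tower.encoder.layers.0.w", t) -> ("encoder.layers.0.w", t)
    vit_weights = filter_weights(weights, "vision_tower")
    self.vision_tower.load_weights(vit_weights)

    # ("language_model.model.layers.0.w", t) -> ("model.layers.0.w", t)
    llm_weights = filter_weights(weights, "language_model")
    self.language_model.load_weights(llm_weights)
```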
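merge_vision_embeddings is cut off after num_expected_tokens = mask.sum() in this view, but the overwrite semantics its docstring describes can be reproduced in a self-contained toy example; the token id and tensor shapes below are made up for illustration.

```python
# Toy demonstration of the placeholder-overwrite semantics (plain PyTorch).
import torch

image_token_id = 32000
input_ids = torch.tensor([1, 32000, 32000, 42])  # two image placeholders
inputs_embeds = torch.zeros(4, 8)                # text embeddings, one row per token
vision_embeddings = torch.ones(2, 8)             # one row per placeholder token

mask = (input_ids == image_token_id)
assert mask.sum() == vision_embeddings.shape[0]  # counts must match

# Boolean-mask assignment overwrites exactly the placeholder rows, in place,
# which is why the docstring warns that inputs_embeds is mutated.
inputs_embeds[mask] = vision_embeddings
```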